Skip to content

cache cond/switch branch wrappers to avoid tracing cache misses#5517

Open
mohsinm-dev wants to merge 1 commit into
google:mainfrom
mohsinm-dev:fix-cond-tracing-cache-misses
Open

cache cond/switch branch wrappers to avoid tracing cache misses#5517
mohsinm-dev wants to merge 1 commit into
google:mainfrom
mohsinm-dev:fix-cond-tracing-cache-misses

Conversation

@mohsinm-dev

Copy link
Copy Markdown
Contributor

Fixes #5512

nnx.cond and nnx.switch create fresh SimpleCondFn wrappers and merge_inputs closures on every call. JAX keys its tracing cache on callable identity, so new wrappers cause repeated cache misses even when the underlying branch functions are stable.

Cache wrappers in a WeakKeyDictionary keyed by the original callable:

  • SimpleCondFn wrappers via _get_simple_cond_fn (graph_updates=False path)
  • merge_inputs wrappers cached inline (graph_updates=True path)

Non-weakly-referenceable callables fall back to fresh wrappers (same behavior as before).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

flax.nnx.cond causes tracing cache misses

1 participant