From 98158056d91b1e93c5ad335711fd7a2b55150cd4 Mon Sep 17 00:00:00 2001 From: Philipp Sinitsin Date: Sat, 13 Jun 2026 12:46:05 +0100 Subject: [PATCH] fix cached_partial for raw array attributes nnx.cached_partial crashed with an AttributeError when cached graph nodes contained raw jax or numpy array attributes not wrapped in nnx.Variable, because StaticCache assumed all leaves were Variables. Raw array leaves are now passed through as-is when flattening and skipped during update propagation when merging, i.e. they are treated as read-only. The read-only semantics are documented in the cached_partial docstring and covered by a regression test. Fixes #5109 --- flax/nnx/graphlib.py | 21 ++++++++++++++++++--- tests/nnx/transforms_test.py | 26 ++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index a3a0aa8fb..9840e4e9e 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -1517,7 +1517,8 @@ class StaticCache(tp.NamedTuple): graphdef: GraphDef[tp.Any] final_graphdef: GraphDef[tp.Any] paths: tuple[PathParts, ...] - variables: list[Variable[tp.Any]] + # Variables or raw array leaves (e.g. jax.Array attributes) + variables: list[Variable[tp.Any] | tp.Any] new_ref_index: RefMap new_index_ref: IndexMap @@ -1525,7 +1526,7 @@ class StaticCache(tp.NamedTuple): def create( graphdef: GraphDef[tp.Any], paths: tuple[PathParts, ...], - variables: list[Variable[tp.Any]], + variables: list[Variable[tp.Any] | tp.Any], new_ref_index: RefMap, ): new_index_ref = IndexMap.from_refmap(new_ref_index) @@ -1630,6 +1631,12 @@ def _cached_partial( mutations are allowed (e.g. the use of ``Module.sow``) as long as they are cleaned up before the function returns (e.g. via ``nnx.pop``). + Raw array attributes (e.g. ``jax.Array`` or ``numpy.ndarray`` values not wrapped in a + ``Variable``) found in the cached graph nodes are treated as read-only: their values are + visible inside the function but any updates to them are not propagated back to the + original graph nodes. Attributes that must change across cached calls should be + wrapped in an ``nnx.Variable``. + Args: f: A function to cache. *cached_args: A subset of the input arguments containing the graph nodes to cache. @@ -1832,8 +1839,12 @@ def flatten( # type: ignore[invalid-annotation] leaves = node_static_cache.variables else: paths = None + # non-Variable leaves (e.g. raw arrays) are passed through as-is leaves = [ - variable.get_raw_value() for variable in node_static_cache.variables + variable.get_raw_value() + if isinstance(variable, Variable) + else variable + for variable in node_static_cache.variables ] else: graphdef, flat_state = flatten( @@ -1967,6 +1978,10 @@ def unflatten( # type: ignore[invalid-annotation] f'leaves in the state, got {len(leaves)}' ) for variable, leaf in zip(static_cache_node.variables, leaves): + if not isinstance(variable, Variable): + # raw array attributes are read-only under cached_partial, + # updates are not propagated back to the original graph nodes + continue if isinstance(leaf, Variable): variable.update_from_state(leaf) else: diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index a1795caf0..99c1b5a4c 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -608,6 +608,32 @@ def f(cached_m: nnx.Linear, m: nnx.Linear): cached_m2 = cached_f(m) self.assertIs(cached_m, cached_m2) + def test_cache_args_raw_arrays(self): + class Foo(nnx.Module): + def __init__(self): + self.a = jnp.arange(3, dtype=jnp.float32) # raw jax.Array + self.b = np.arange(3, dtype=np.float32) # raw np.ndarray + self.count = nnx.Variable(jnp.array(0)) + + def __call__(self, x): + self.count[...] += 1 + return self.a * x + self.b + + m = Foo() + + @nnx.jit(graph=True, graph_updates=True) + def f(m: Foo, x): + return m(x) + 1 + + x = jnp.ones(3) + expected = f(m, x) + + cached_f = nnx.compat.cached_partial(f, m) + np.testing.assert_allclose(cached_f(x), expected) + np.testing.assert_allclose(cached_f(x), expected) + # Variable updates are propagated back to the original node + self.assertEqual(m.count[...], 3) + @parameterized.parameters(True, False) def test_cache_args_functional(self, graph_mode): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))