diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py
index a3a0aa8fb..9840e4e9e 100644
--- a/flax/nnx/graphlib.py
+++ b/flax/nnx/graphlib.py
@@ -1517,7 +1517,8 @@ class StaticCache(tp.NamedTuple):
   graphdef: GraphDef[tp.Any]
   final_graphdef: GraphDef[tp.Any]
   paths: tuple[PathParts, ...]
-  variables: list[Variable[tp.Any]]
+  # Variables, or raw array leaves such as jax.Array / np.ndarray attributes
+  variables: list[Variable[tp.Any] | tp.Any]
   new_ref_index: RefMap
   new_index_ref: IndexMap
 
@@ -1525,7 +1526,7 @@ class StaticCache(tp.NamedTuple):
   def create(
     graphdef: GraphDef[tp.Any],
     paths: tuple[PathParts, ...],
-    variables: list[Variable[tp.Any]],
+    variables: list[Variable[tp.Any] | tp.Any],
     new_ref_index: RefMap,
   ):
     new_index_ref = IndexMap.from_refmap(new_ref_index)
@@ -1630,6 +1631,12 @@ def _cached_partial(
   mutations are allowed (e.g. the use of ``Module.sow``) as long as they are
   cleaned up before the function returns (e.g. via ``nnx.pop``).
 
+  Raw array attributes (e.g. ``jax.Array`` or ``numpy.ndarray`` values not wrapped in a
+  ``Variable``) found in the cached graph nodes are treated as read-only: their values are
+  visible inside the function but any updates to them are not propagated back to the
+  original graph nodes. Attributes that must change across cached calls should be
+  wrapped in an ``nnx.Variable``.
+
   Args:
     f: A function to cache.
     *cached_args: A subset of the input arguments containing the graph nodes to cache.
@@ -1832,8 +1839,12 @@ def flatten( # type: ignore[invalid-annotation]
         leaves = node_static_cache.variables
       else:
         paths = None
+        # Variables yield their raw value; raw array leaves are passed as-is
         leaves = [
-          variable.get_raw_value() for variable in node_static_cache.variables
+          variable.get_raw_value()
+          if isinstance(variable, Variable)
+          else variable
+          for variable in node_static_cache.variables
         ]
     else:
       graphdef, flat_state = flatten(
@@ -1967,6 +1978,10 @@ def unflatten( # type: ignore[invalid-annotation]
             f'leaves in the state, got {len(leaves)}'
           )
         for variable, leaf in zip(static_cache_node.variables, leaves):
+          if not isinstance(variable, Variable):
+            # raw array leaves are read-only under cached_partial: skip them so
+            # their updates are not written back to the original graph nodes
+            continue
           if isinstance(leaf, Variable):
             variable.update_from_state(leaf)
           else:
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index a1795caf0..99c1b5a4c 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -608,6 +608,32 @@ def f(cached_m: nnx.Linear, m: nnx.Linear):
     cached_m2 = cached_f(m)
     self.assertIs(cached_m, cached_m2)
 
+  def test_cache_args_raw_arrays(self):
+    class Foo(nnx.Module):
+      def __init__(self):
+        self.a = jnp.arange(3, dtype=jnp.float32)  # raw jax.Array, not a Variable
+        self.b = np.arange(3, dtype=np.float32)  # raw np.ndarray, not a Variable
+        self.count = nnx.Variable(jnp.array(0))
+
+      def __call__(self, x):
+        self.count[...] += 1
+        return self.a * x + self.b
+
+    m = Foo()
+
+    @nnx.jit(graph=True, graph_updates=True)
+    def f(m: Foo, x):
+      return m(x) + 1
+
+    x = jnp.ones(3)
+    expected = f(m, x)
+
+    cached_f = nnx.compat.cached_partial(f, m)
+    np.testing.assert_allclose(cached_f(x), expected)
+    np.testing.assert_allclose(cached_f(x), expected)
+    # Variable updates reach the original node: 1 direct + 2 cached calls
+    self.assertEqual(m.count[...], 3)
+
   @parameterized.parameters(True, False)
   def test_cache_args_functional(self, graph_mode):
     m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))