Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,15 +1517,16 @@ class StaticCache(tp.NamedTuple):
graphdef: GraphDef[tp.Any]
final_graphdef: GraphDef[tp.Any]
paths: tuple[PathParts, ...]
variables: list[Variable[tp.Any]]
# Variables or raw array leaves (e.g. jax.Array attributes)
variables: list[Variable[tp.Any] | tp.Any]
new_ref_index: RefMap
new_index_ref: IndexMap

@staticmethod
def create(
graphdef: GraphDef[tp.Any],
paths: tuple[PathParts, ...],
variables: list[Variable[tp.Any]],
variables: list[Variable[tp.Any] | tp.Any],
new_ref_index: RefMap,
):
new_index_ref = IndexMap.from_refmap(new_ref_index)
Expand Down Expand Up @@ -1630,6 +1631,12 @@ def _cached_partial(
mutations are allowed (e.g. the use of ``Module.sow``) as long as they are cleaned up before
the function returns (e.g. via ``nnx.pop``).

Raw array attributes (e.g. ``jax.Array`` or ``numpy.ndarray`` values not wrapped in a
``Variable``) found in the cached graph nodes are treated as read-only: their values are
visible inside the function but any updates to them are not propagated back to the
original graph nodes. Attributes that must change across cached calls should be
wrapped in an ``nnx.Variable``.

Args:
f: A function to cache.
*cached_args: A subset of the input arguments containing the graph nodes to cache.
Expand Down Expand Up @@ -1832,8 +1839,12 @@ def flatten( # type: ignore[invalid-annotation]
leaves = node_static_cache.variables
else:
paths = None
# non-Variable leaves (e.g. raw arrays) are passed through as-is
leaves = [
variable.get_raw_value() for variable in node_static_cache.variables
variable.get_raw_value()
if isinstance(variable, Variable)
else variable
for variable in node_static_cache.variables
]
else:
graphdef, flat_state = flatten(
Expand Down Expand Up @@ -1967,6 +1978,10 @@ def unflatten( # type: ignore[invalid-annotation]
f'leaves in the state, got {len(leaves)}'
)
for variable, leaf in zip(static_cache_node.variables, leaves):
if not isinstance(variable, Variable):
# raw array attributes are read-only under cached_partial,
# updates are not propagated back to the original graph nodes
continue
if isinstance(leaf, Variable):
variable.update_from_state(leaf)
else:
Expand Down
26 changes: 26 additions & 0 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,32 @@ def f(cached_m: nnx.Linear, m: nnx.Linear):
cached_m2 = cached_f(m)
self.assertIs(cached_m, cached_m2)

def test_cache_args_raw_arrays(self):
class Foo(nnx.Module):
def __init__(self):
self.a = jnp.arange(3, dtype=jnp.float32) # raw jax.Array
self.b = np.arange(3, dtype=np.float32) # raw np.ndarray
self.count = nnx.Variable(jnp.array(0))

def __call__(self, x):
self.count[...] += 1
return self.a * x + self.b

m = Foo()

@nnx.jit(graph=True, graph_updates=True)
def f(m: Foo, x):
return m(x) + 1

x = jnp.ones(3)
expected = f(m, x)

cached_f = nnx.compat.cached_partial(f, m)
np.testing.assert_allclose(cached_f(x), expected)
np.testing.assert_allclose(cached_f(x), expected)
# Variable updates are propagated back to the original node
self.assertEqual(m.count[...], 3)

@parameterized.parameters(True, False)
def test_cache_args_functional(self, graph_mode):
m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
Expand Down