From 7310f798e1cafdce54ff93ca5088c11ccdf40e2a Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 15 Jun 2026 17:03:56 -0700 Subject: [PATCH] add nnx:v1 PiperOrigin-RevId: 932747913 --- benchmarks/nnx_graph_overhead.py | 33 ++- benchmarks/nnx_simple_training.py | 84 +++++- flax/nnx/__init__.py | 1 + flax/nnx/bridge/wrappers.py | 2 +- flax/nnx/compat.py | 2 +- flax/nnx/extract.py | 74 ++--- flax/nnx/nn/activations.py | 2 +- flax/nnx/nn/attention.py | 2 +- flax/nnx/nn/linear.py | 2 +- flax/nnx/nn/normalization.py | 2 +- flax/nnx/nn/recurrent.py | 2 +- flax/nnx/nn/stochastic.py | 2 +- flax/nnx/pytreelib.py | 11 +- flax/nnx/spmd.py | 2 +- flax/nnx/summary.py | 2 +- flax/nnx/training/optimizer.py | 2 +- flax/nnx/transforms/compilation.py | 430 +++++++++++++++++++++++------ flax/nnx/transforms/iteration.py | 13 +- tests/nnx/transforms_test.py | 11 +- 19 files changed, 514 insertions(+), 165 deletions(-) diff --git a/benchmarks/nnx_graph_overhead.py b/benchmarks/nnx_graph_overhead.py index fd20fc5a8..92ffc7c4e 100644 --- a/benchmarks/nnx_graph_overhead.py +++ b/benchmarks/nnx_graph_overhead.py @@ -25,7 +25,7 @@ FLAGS = flags.FLAGS flags.DEFINE_enum( - 'mode', 'nnx', ['all', 'nnx', 'jax'], 'Mode to run the script in' + 'mode', 'nnx', ['all', 'nnx', 'jax', 'jit_partial'], 'Mode to run the script in' ) flags.DEFINE_integer('total_steps', 100, 'Total number of training steps') flags.DEFINE_integer('width', 32, 'Hidden layer size') @@ -91,7 +91,6 @@ def main(argv): model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) - t0 = time() @nnx.jit def step_nnx(model: MLP, optimizer: nnx.Optimizer): @@ -108,6 +107,35 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer): print('total time:', total_time) print(f'time per step: {time_per_step * 1e6:.2f} µs') print(f'time per layer: {time_per_layer * 1e6:.2f} µs') + print() + + # ------------------------------------------------------------ + # JIT Partial + # ------------------------------------------------------------ + if mode in ['all', 'jit_partial']: + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + + def step_partial(model: MLP, optimizer: nnx.Optimizer): + pass + + step_partial_jit = nnx.jit_partial( + step_partial, model, optimizer, graph=False + ) + + t0 = time() + for _ in range(total_steps): + step_partial_jit() + + total_time = time() - t0 + time_per_step = total_time / total_steps + time_per_layer = time_per_step / depth + print('### JIT PARTIAL ###') + print('total time:', total_time) + print(f'time per step: {time_per_step * 1e6:.2f} µs') + print(f'time per layer: {time_per_layer * 1e6:.2f} µs') + print() # ------------------------------------------------------------ # JAX @@ -117,7 +145,6 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer): model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) - t0 = time() @jax.jit def step_jax(graphdef, state): diff --git a/benchmarks/nnx_simple_training.py b/benchmarks/nnx_simple_training.py index e134fa317..2f543b1d8 100644 --- a/benchmarks/nnx_simple_training.py +++ b/benchmarks/nnx_simple_training.py @@ -13,6 +13,9 @@ # limitations under the License. # %% +import cProfile +import pstats +import io from functools import partial import jax import jax.numpy as jnp @@ -27,12 +30,13 @@ FLAGS = flags.FLAGS flags.DEFINE_enum( - 'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in' + 'mode', 'all', ['all', 'nnx', 'jax', 'jit_partial'], 'Mode to run the script in' ) flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') flags.DEFINE_integer('batch_size', 32, 'Batch size') flags.DEFINE_integer('width', 32, 'Hidden layer size') flags.DEFINE_integer('depth', 5, 'Depth of the model') +flags.DEFINE_bool('profile', False, 'Enable cProfile profiling') def dataset(X, Y, batch_size): @@ -67,13 +71,13 @@ class MLP(nnx.Module): def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) self.linear_in = Block(din, dhidden, rngs=rngs) - self.intermediates = [ + self.intermediates = nnx.List([ Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) - ] + ]) self.linear_out = Block(dhidden, dout, rngs=rngs) def __call__(self, x): - self.count.value += 1 + self.count[...] += 1 x = nnx.relu(self.linear_in(x)) for layer in self.intermediates: x = nnx.relu(layer(x)) @@ -118,6 +122,7 @@ def test_step_nnx(model: MLP, batch): loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} + logs = {'loss': jnp.array(0.0)} for step, batch in enumerate(dataset(X, Y, batch_size)): train_step_nnx(model, optimizer, batch) @@ -132,7 +137,73 @@ def test_step_nnx(model: MLP, batch): total_time = time() - t0 print('total time:', total_time) print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') - print('times called:', model.count.value) + print('times called:', model.count[...]) + print() + + if mode == 'jit_partial' or mode == 'all': + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + t0 = time() + + def train_step(model: MLP, optimizer: nnx.Optimizer, batch): + x, y = batch + + def loss_fn(model: MLP): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + grads = nnx.grad(loss_fn)(model) + optimizer.update(model, grads) + + def test_step(model: MLP, batch): + x, y = batch + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + return {'loss': loss} + + train_step_fn = nnx.jit_partial( + train_step, model, optimizer, graph=False + ) + test_step_fn = nnx.jit_partial(test_step, model, graph=False) + + logs = {'loss': jnp.array(0.0)} + # Warmup + for step, batch in enumerate(dataset(X, Y, batch_size)): + train_step_fn(batch) + if step >= 10: + break + + pr = None + if FLAGS.profile: + pr = cProfile.Profile() + pr.enable() + + for step, batch in enumerate(dataset(X, Y, batch_size)): + train_step_fn(batch) + + if step % 1000 == 0: + logs = test_step_fn((X, Y)) + + if step >= total_steps - 1: + break + + if pr is not None: + pr.disable() + for sort_key in ('cumulative', 'tottime'): + s = io.StringIO() + ps = pstats.Stats(pr, stream=s) + ps.sort_stats(sort_key) + ps.print_stats(40) + print(s.getvalue()) + + print('### JIT PARTIAL ###') + print(f'final loss: {logs["loss"]}') + total_time = time() - t0 + print('total time:', total_time) + print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') + print('times called:', model.count[...]) + print() if mode == 'jax' or mode == 'all': model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) @@ -165,6 +236,7 @@ def test_step_jax(state, batch): graphdef, state = nnx.split((model, optimizer)) + logs = {'loss': jnp.array(0.0)} for step, batch in enumerate(dataset(X, Y, batch_size)): state = train_step_jax(state, batch) @@ -181,7 +253,7 @@ def test_step_jax(state, batch): total_time = time() - t0 print('total time:', total_time) print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') - print('times called:', model.count.value) + print('times called:', model.count[...]) if __name__ == '__main__': diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 78a550411..d43e05794 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from flax.core.spmd import logical_axis_rules as logical_axis_rules from flax.linen.pooling import avg_pool as avg_pool from flax.linen.pooling import max_pool as max_pool diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index a68dfa141..d47c343c1 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -20,7 +20,7 @@ from flax import linen from flax import core -from flax import nnx +from flax.nnx import src as nnx from flax.core import FrozenDict from flax.core import meta from flax.nnx import graphlib diff --git a/flax/nnx/compat.py b/flax/nnx/compat.py index 0a07f3b72..59eb7f9db 100644 --- a/flax/nnx/compat.py +++ b/flax/nnx/compat.py @@ -20,7 +20,7 @@ Example:: - from flax import nnx + from flax.nnx import src as nnx graphdef, state = nnx.compat.split(model) # graph=True by default diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index cc816075b..303432788 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -732,7 +732,7 @@ def check_prefix( graph_updates: bool, none_leaf: bool = True, ): - unique_prefixes: OrderedDict[tp.Any, tp.Any] = OrderedDict() + unique_prefixes: set[tp.Any] = set() def _check_prefix(path, leaf): if isinstance(leaf, variablelib.Variable): @@ -798,12 +798,12 @@ def _check_prefix(path, leaf): ) def _collect_prefix(_, leaf): - unique_prefixes[leaf] = leaf + unique_prefixes.add(leaf) jax.tree.map_with_path( _collect_prefix, prefix, is_leaf=lambda x: x is None and none_leaf ) - return unique_prefixes + return list(unique_prefixes) def variable_changed(post: variablelib.Variable, pre: variablelib.Variable) -> bool: @@ -818,27 +818,28 @@ def variable_changed(post: variablelib.Variable, pre: variablelib.Variable) -> b ] +@dataclasses.dataclass(slots=True) class Updates( tp.Sequence[tuple[jax.tree_util.KeyPath, variablelib.Variable]], reprlib.Representable, ): - __slots__ = ('_keys', '_values') + _keys: list[tp.Any] = dataclasses.field(default_factory=list) + _values: list[variablelib.Variable] = dataclasses.field(default_factory=list) - _keys: list[jax.tree_util.KeyPath] - _values: list[variablelib.Variable] - - def __init__( - self, + @classmethod + def create( + cls, items: tp.Iterable[ tuple[jax.tree_util.KeyPath, variablelib.Variable] ] = (), - ): - self._keys, self._values = [], [] + ) -> 'Updates': + keys, values = [], [] for key, value in items: - self._keys.append(key) - self._values.append(value) + keys.append(key) + values.append(value) + return cls(_keys=keys, _values=values) - def append(self, key: jax.tree_util.KeyPath, value: variablelib.Variable): + def append(self, key: tp.Any, value: variablelib.Variable): self._keys.append(key) self._values.append(value) @@ -880,7 +881,7 @@ def __len__(self): return len(self._keys) def __iter__(self): - return iter(zip(self._keys, self._values)) + return zip(self._keys, self._values) def __nnx_repr__(self): yield reprlib.Object(type=type(self), kv_sep=': ', start='({', end='})') @@ -892,30 +893,10 @@ def __nnx_repr__(self): ) -def _updates_flatten_with_keys(x: Updates): - key_children = [ - (jax.tree_util.FlattenedIndexKey(i), v) - for i, v in enumerate(x._values) - ] - return key_children, x._keys - - -def _updates_flatten(x: Updates): - return x._values, x._keys - - -def _updates_unflatten(keys, values) -> Updates: - updates = object.__new__(Updates) - updates._keys = keys - updates._values = list(values) - return updates - - -jax.tree_util.register_pytree_with_keys( +jax.tree_util.register_dataclass( Updates, - _updates_flatten_with_keys, - _updates_unflatten, - flatten_func=_updates_flatten, + data_fields=['_values'], + meta_fields=['_keys'], ) def get_updates( @@ -929,7 +910,7 @@ def get_updates( if keep_fn is None: keep_fn = lambda _, _pfx, cur, snap: variable_changed(cur, snap) - updates = OrderedDict((pfx, Updates()) for pfx in known_prefixes) + updates = {pfx: Updates.create() for pfx in known_prefixes} def _mask_updates(path, prefix_leaf, current, snapshot): if isinstance(current, variablelib.Variable): @@ -944,14 +925,14 @@ def _mask_updates(path, prefix_leaf, current, snapshot): _mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf, prefix_leaf=prefix_leaf, ) - return updates + return list(updates.values()) def apply_updates( variables: dict[jax.tree_util.KeyPath, variablelib.Variable], - updates: OrderedDict[tp.Any, Updates], + updates: list[Updates], ): - for _, flat_state in updates.items(): + for flat_state in updates: for path, update in flat_state: if path in variables: variable = variables[path] @@ -965,6 +946,7 @@ def apply_updates( ) + def treemap_copy_args(f: F) -> F: @functools.wraps(f) def wrapper(*args, **kwargs): @@ -1017,7 +999,7 @@ def prefix( Example usage:: - from flax import nnx + from flax.nnx import src as nnx import jax.numpy as jnp d = {'a': nnx.Param(jnp.array(2)), 'b': nnx.BatchStat(jnp.arange(5))} @@ -1110,9 +1092,9 @@ def _apply_prefix(jax_path, leaf): return jax.tree.map_with_path(_apply_prefix, node, is_leaf=is_leaf) -def to_masked(tree, all_updates: OrderedDict[tp.Any, Updates]): - combined: OrderedDict[tp.Any, tp.Any] = OrderedDict() - for updates in all_updates.values(): +def to_masked(tree, all_updates: list[Updates]): + combined: dict[tp.Any, tp.Any] = {} + for updates in all_updates: combined.update(updates) return jax.tree.map_with_path( lambda path, _: combined.get(path, None), tree, diff --git a/flax/nnx/nn/activations.py b/flax/nnx/nn/activations.py index ae5f43ef8..2424e29db 100644 --- a/flax/nnx/nn/activations.py +++ b/flax/nnx/nn/activations.py @@ -43,7 +43,7 @@ import jax.numpy as jnp from jax.numpy import tanh -from flax import nnx +from flax.nnx import src as nnx from flax.nnx.nn import dtypes from flax.typing import Array, Dtype, PromoteDtypeFn diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py index 4f0c4f0cd..6740153e8 100644 --- a/flax/nnx/nn/attention.py +++ b/flax/nnx/nn/attention.py @@ -27,7 +27,7 @@ import jax.numpy as jnp from jax import lax, random -from flax import nnx +from flax.nnx import src as nnx from flax.nnx import rnglib from flax.nnx.module import Module, first_from from flax.nnx.nn import initializers diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py index d8e49023b..a23c85097 100644 --- a/flax/nnx/nn/linear.py +++ b/flax/nnx/nn/linear.py @@ -23,7 +23,7 @@ import opt_einsum from flax.core.frozen_dict import FrozenDict -from flax import nnx +from flax.nnx import src as nnx from flax.nnx import rnglib, variablelib from flax.nnx.module import Module, first_from from flax.nnx.nn import dtypes, initializers diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index e06ab9fcd..49886a13c 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -19,7 +19,7 @@ import jax.numpy as jnp from jax import lax -from flax import nnx +from flax.nnx import src as nnx from flax.nnx import rnglib from flax.nnx.module import Module, first_from from flax.nnx.nn import dtypes, initializers diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py index 9df3f065e..60033d5db 100644 --- a/flax/nnx/nn/recurrent.py +++ b/flax/nnx/nn/recurrent.py @@ -26,7 +26,7 @@ import jax import jax.numpy as jnp -from flax import nnx +from flax.nnx import src as nnx from flax.nnx import filterlib, rnglib from flax.nnx.module import Module from flax.nnx.nn import initializers, dtypes diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py index 6d03e7353..d429aa766 100644 --- a/flax/nnx/nn/stochastic.py +++ b/flax/nnx/nn/stochastic.py @@ -21,7 +21,7 @@ from flax.nnx import rnglib from flax.nnx.module import Module, first_from -from flax import nnx +from flax.nnx import src as nnx class Dropout(Module): diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index 201b8026c..382854bda 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -29,7 +29,8 @@ import treescope # type: ignore[import-untyped] from treescope import rendering_parts -from flax import errors, nnx +from flax import errors +from flax.nnx import src as nnx from flax.nnx import ( graphlib, reprlib, @@ -78,7 +79,7 @@ def data(value: tp.Any = MISSING, /, **kwargs) -> tp.Any: Example:: - from flax import nnx + from flax.nnx import src as nnx import jax class Foo(nnx.Pytree): @@ -123,7 +124,7 @@ def register_data_type(type_: T, /) -> T: Example:: - from flax import nnx + from flax.nnx import src as nnx from dataclasses import dataclass @dataclass(frozen=True) @@ -240,7 +241,7 @@ def static(value: tp.Any = MISSING, /, **kwargs) -> tp.Any: Example:: - from flax import nnx + from flax.nnx import src as nnx class Foo(nnx.Pytree): def __init__(self, a, b): @@ -837,7 +838,7 @@ def __nnx_repr__(self): OBJECT_CONTEXT.node_stats = None def __treescope_repr__(self, path, subtree_renderer): - from flax import nnx + from flax.nnx import src as nnx if OBJECT_CONTEXT.node_stats is None or id(self) not in OBJECT_CONTEXT.node_stats: node_stats: dict[int, dict[type[Variable], SizeBytes]] = {} diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index a80eb8134..72ce4eaf9 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -199,7 +199,7 @@ def as_abstract( Example usage:: - from flax import nnx + from flax.nnx import src as nnx import jax mesh = jax.make_mesh((2, 2), ('a', 'b'), diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py index c6b29d99b..3b2ae7174 100644 --- a/flax/nnx/summary.py +++ b/flax/nnx/summary.py @@ -30,7 +30,7 @@ import yaml import jax.numpy as jnp -from flax import nnx +from flax.nnx import src as nnx from flax import typing from flax.nnx import graphlib, statelib, variablelib diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index 1ca14ac1e..3bdf4c827 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -20,7 +20,7 @@ import jax.numpy as jnp import optax -from flax import nnx +from flax.nnx import src as nnx from flax.nnx import filterlib from flax.nnx.pytreelib import Pytree from flax.nnx.variablelib import Param, Variable diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 52ec0205a..2bf5e3db9 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -30,6 +30,8 @@ statelib, variablelib, ) +from flax import errors +from flax.nnx import tracers from flax.nnx.extract import labeled from flax.nnx.transforms.transforms import ( _resolve_bound_callable, @@ -383,7 +385,8 @@ def jit( update_shardings = extract.check_prefix( in_shardings, 'in_shardings', 'jit', graph, graph_updates ) - update_shardings[None] = None # kwargs sharding + if None not in update_shardings: + update_shardings.append(None) # kwargs sharding extract.check_prefix( out_shardings, 'out_shardings', 'jit', graph, graph_updates ) @@ -413,39 +416,48 @@ def jit( ) -@dataclasses.dataclass(frozen=True, slots=True) -class PartialState: - """Container for a pre-flattened partial argument. - - Stores the pytree structure (``treedef``) as static metadata and the - flattened leaves as dynamic data. Variables within the original argument - are kept as leaves so their values can change between calls without - triggering recompilation. - """ - treedef: jax.tree_util.PyTreeDef - leaves: list[tp.Any] -jax.tree_util.register_dataclass( - PartialState, - data_fields=['leaves'], - meta_fields=['treedef'], -) +@dataclasses.dataclass(eq=False) +class SimpleJitFn: + f: tp.Callable[..., tp.Any] + in_shardings: tp.Any + out_shardings: tp.Any + donate_argnums: frozenset[int] + donate_argnames: frozenset[str] + graph: bool + update_shardings: tuple[tp.Any, ...] + def __post_init__(self): + functools.update_wrapper(self, self.f, updated=()) -def _flatten_to_partial_state( - arg: tp.Any, - ref_index: graphlib.RefMap | None, -) -> PartialState: - if ref_index is not None: - graphdef, flat_state = graphlib.flatten(arg, ref_index=ref_index, graph=True) - return PartialState(treedef=graphdef, leaves=flat_state.leaves) - is_leaf = lambda x: isinstance(x, variablelib.Variable) - leaves, treedef = jax.tree.flatten(arg, is_leaf=is_leaf) - return PartialState(treedef=treedef, leaves=leaves) + @extract.treemap_copy_args + def __call__(self, *args, **kwargs): + current, snapshot = extract.snapshot( + labeled(args=args, kwargs=kwargs) + ) + if self.graph: + args, kwargs = extract.from_tree2((args, kwargs)) + out = self.f(*args, **kwargs) + if self.graph: + out = extract.to_tree2(out, prefix=self.out_shardings) + extract.check_no_aliases('jit', **current, out=out, check=['out']) + def keep_fn(jax_path, prefix, c, s): + if extract.variable_changed(c, s): + return True + arg_type, arg_key, *_ = graphlib.jax_to_nnx_path(jax_path) + if arg_type == 'args': + return arg_key in self.donate_argnums + else: # arg_type == 'kwargs': + return arg_key in self.donate_argnames + updates = extract.get_updates( + current, snapshot, prefix=labeled(args=self.in_shardings, kwargs=None), + known_prefixes=self.update_shardings, keep_fn=keep_fn + ) + return out, updates @dataclasses.dataclass(eq=False) -class SimpleJitFn: +class SimpleJitPartialFn: f: tp.Callable[..., tp.Any] in_shardings: tp.Any out_shardings: tp.Any @@ -468,8 +480,14 @@ def __call__(self, *args, **kwargs): if self.graph: out = extract.to_tree2(out, prefix=self.out_shardings) extract.check_no_aliases('jit', **current, out=out, check=['out']) - def keep_fn(jax_path, prefix, c, s): + def keep_fn(jax_path, prefix, c: variablelib.Variable, s: variablelib.Variable): if extract.variable_changed(c, s): + if c.get_metadata() != s.get_metadata(): + path_str = jax.tree_util.keystr(jax_path) + raise ValueError( + f'Variable metadata changed inside jit at path {path_str}. ' + f'Changing Variable metadata inside jit is not supported.' + ) return True arg_type, arg_key, *_ = graphlib.jax_to_nnx_path(jax_path) if arg_type == 'args': @@ -480,6 +498,13 @@ def keep_fn(jax_path, prefix, c, s): current, snapshot, prefix=labeled(args=self.in_shardings, kwargs=None), known_prefixes=self.update_shardings, keep_fn=keep_fn ) + for update_group in updates: + update_group._keys = [ + k[-1].idx for k in update_group._keys + ] + update_group._values = [ + v.get_raw_value() for v in update_group._values + ] return out, updates @@ -498,9 +523,9 @@ def __init__( device: tp.Optional[jax.Device], backend: tp.Optional[str], inline: bool, - partial_args: tuple[PartialState, ...], + partial_args: tuple[tp.Any, ...], graph: bool, - update_shardings: extract.OrderedDict, + update_shardings: tuple[tp.Any, ...], ): functools.update_wrapper(self, fun) self.fun: tp.Callable[P, R] = fun @@ -549,12 +574,17 @@ def __init__( def _maybe_to_tree(self, args, kwargs): if self.graph: + if self.in_shardings is not None and isinstance(self.in_shardings, (tuple, list)): + runtime_prefix = self.in_shardings[len(self.partial_args):] + else: + runtime_prefix = self.in_shardings + args, kwargs = extract.to_tree2( (args, kwargs), - prefix=(self.in_shardings, None) - if self.in_shardings is not None + prefix=(runtime_prefix, None) + if runtime_prefix is not None else None, - check_aliasing=self.in_shardings is not None, + check_aliasing=runtime_prefix is not None, ) return args, kwargs @@ -581,7 +611,7 @@ def eval_shape(self, *args, **kwargs): args, kwargs = self._maybe_to_tree(args, kwargs) if not self.graph: extract.check_no_aliases('jit', args=args, kwargs=kwargs) - out, updates = self.jitted_fn.eval_shape(*args, **kwargs) + out, _ = self.jitted_fn.eval_shape(*args, **kwargs) return self._maybe_from_tree(out) def trace(self, *args, **kwargs): @@ -599,6 +629,109 @@ def lower(self, *args, **kwargs): extract.check_no_aliases('jit', args=args, kwargs=kwargs) lowered = self.jitted_fn.lower(*args, **kwargs) return SimpleLowered(lowered, self) + + +def _apply_raw_updates( + partial_args: list[tp.Any], + updates: list[extract.Updates], +): + """Apply updates containing raw values using integer indices into partial_args.""" + trace = tracers.current_jax_trace() + for flat_state in updates: + for index, raw_value in flat_state: + var = partial_args[index] + if var._trace_state._jax_trace != trace: + raise errors.TraceContextError( + f'Cannot mutate {type(var).__name__} from a different trace level' + ) + object.__setattr__(var, '_raw_value', raw_value) + + +class SimpleJitPartialWrapped(tp.Generic[P, R]): + + def __init__( + self, + fun: tp.Callable[P, R], + in_shardings: tp.Any, + out_shardings: tp.Any, + static_argnums: int | tp.Sequence[int] | None, + static_argnames: str | tp.Iterable[str] | None, + donate_argnums: int | tp.Sequence[int] | None, + donate_argnames: str | tp.Iterable[str] | None, + keep_unused: bool, + device: tp.Optional[jax.Device], + backend: tp.Optional[str], + inline: bool, + partial_args: list[tp.Any], + graph: bool, + update_shardings: list[tp.Any], + ): + functools.update_wrapper(self, fun) + self.fun: tp.Callable[P, R] = fun + self.in_shardings = in_shardings + self.out_shardings = out_shardings + self.partial_args = partial_args + self.graph = graph + + donate_argnums_set = frozenset( + (donate_argnums,) if isinstance(donate_argnums, int) + else donate_argnums or () + ) + donate_argnames_set = frozenset( + (donate_argnames,) if isinstance(donate_argnames, str) + else donate_argnames or () + ) + self.jitted_fn = jax.jit( + SimpleJitPartialFn( + fun, + in_shardings, + out_shardings, + donate_argnums_set, + donate_argnames_set, + graph, + tuple(update_shardings), + ), + in_shardings=in_shardings, + out_shardings=(out_shardings, update_shardings), + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + ) + + def _maybe_from_tree(self, out): + if self.graph: + out = extract.from_tree2(out) + return out + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + out, updates = self.jitted_fn(self.partial_args, *args, **kwargs) + _apply_raw_updates(self.partial_args, updates) + return self._maybe_from_tree(out) + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return functools.partial(self, obj) + + def eval_shape(self, *args, **kwargs): + out, _ = self.jitted_fn.eval_shape( + self.partial_args, *args, **kwargs + ) + return self._maybe_from_tree(out) + + def trace(self, *args, **kwargs): + traced = self.jitted_fn.trace(self.partial_args, *args, **kwargs) + return SimplePartialTraced(traced, self) + + def lower(self, *args, **kwargs): + lowered = self.jitted_fn.lower(self.partial_args, *args, **kwargs) + return SimplePartialLowered(lowered, self) + def jit_partial( fun: tp.Callable[..., R], *partial_args: tp.Any, @@ -614,12 +747,12 @@ def jit_partial( inline: bool = False, graph: bool | None = None, graph_updates: bool | None = None, -) -> SimpleJitWrapped[..., R]: +) -> SimpleJitPartialWrapped[..., R]: """JIT-compile ``fun`` with pre-flattened partial arguments. Similar to ``nnx.cached_partial`` but designed for tree-mode (``graph=False``). Each ``partial_arg`` is flattened into a - ``PartialState`` whose pytree structure is fixed at construction time. + list of Variables and Arrays whose pytree structure is fixed at construction time. Variable values inside partial arguments can still change between calls without triggering recompilation, and any mutations to Variables are propagated back to the originals after each call. @@ -684,7 +817,8 @@ def jit_partial( update_shardings = extract.check_prefix( in_shardings, 'in_shardings', 'jit_partial', graph, graph_updates ) - update_shardings[None] = None # kwargs sharding + if None not in update_shardings: + update_shardings.append(None) # kwargs sharding if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_shardings)): raise ValueError( '`in_shardings` cannot contain `StateSharding` objects ' @@ -697,65 +831,89 @@ def jit_partial( ) is_variable = lambda x: isinstance(x, variablelib.Variable) - ref_index = graphlib.RefMap() if graph else None - flat_partial_args = tuple( - _flatten_to_partial_state(arg, ref_index=ref_index) - for arg in partial_args + + # 1. Graph->tree conversion and alias check beforehand + if graph: + if in_shardings is not None and isinstance(in_shardings, (tuple, list)): + partial_in_axes = in_shardings[:len(partial_args)] + else: + partial_in_axes = in_shardings + tree_partial_args = extract.to_tree2( + partial_args, + prefix=partial_in_axes, + check_aliasing=partial_in_axes is not None, + ) + else: + tree_partial_args = partial_args + + # Check no aliases beforehand + extract.check_no_aliases('jit_partial', args=tree_partial_args) + + # 2. Flatten the partial_args to a single list of Variables and Arrays + flat_partial_args, partial_treedef = jax.tree.flatten( + tree_partial_args, is_leaf=is_variable ) + + # 4. Sharding calculation + # partial_args is passed as a single list argument, so in_shardings + # for that argument is a list matching its pytree structure. jit_in_shardings: tp.Any = None - if in_shardings is not None and isinstance(in_shardings, (tuple, list)) and not graph: + if in_shardings is not None and isinstance(in_shardings, (tuple, list)): num_partial = len(partial_args) partial_shardings = in_shardings[:num_partial] runtime_shardings = in_shardings[num_partial:] - flat_partial_shardings = [] - for flat_arg, orig_arg, sharding in zip( - flat_partial_args, partial_args, partial_shardings): - broadcasted = extract.broadcast_prefix( - sharding, orig_arg, + broadcasted = extract.broadcast_prefix( + partial_shardings, tree_partial_args, prefix_is_leaf=lambda x: x is None or isinstance(x, variablelib.Variable), tree_is_leaf=is_variable, - ) - flat_partial_shardings.append( - PartialState(treedef=flat_arg.treedef, leaves=broadcasted) - ) - jit_in_shardings = (*flat_partial_shardings, *runtime_shardings) + ) + flat_partial_shardings = jax.tree.leaves(broadcasted, is_leaf=lambda x: x is None) + jit_in_shardings = (flat_partial_shardings, *runtime_shardings) else: jit_in_shardings = in_shardings + # 5. wrapped_fun accepts partial_args as its first argument (a list) @functools.wraps(fun) - def wrapped_fun(*args, **kwargs): - index_ref = graphlib.IndexMap() if graph else None - def _unflatten(arg): - if not isinstance(arg, PartialState): - return arg - elif graph: - return graphlib.unflatten( - arg.treedef, arg.leaves, index_ref=index_ref, - copy_variables=False, - ) - else: - return jax.tree.unflatten(arg.treedef, arg.leaves) - args = (_unflatten(a) for a in args) - return fun(*args, **kwargs) + def wrapped_fun(flat_partial_list, *args, **kwargs): + # Check no Variables in runtime args/kwargs + runtime_leaves = jax.tree.leaves((args, kwargs), is_leaf=is_variable) + if any(is_variable(x) for x in runtime_leaves): + raise ValueError( + 'Found Variable in non-partial arguments. ' + 'jit_partial only supports Variables in partial arguments.' + ) - return SimpleJitWrapped( - wrapped_fun, - in_shardings=jit_in_shardings, - out_shardings=out_shardings, - static_argnums=static_argnums, - static_argnames=static_argnames, - donate_argnums=donate_argnums, - donate_argnames=donate_argnames, - keep_unused=keep_unused, - device=device, - backend=backend, - inline=inline, - partial_args=flat_partial_args, - graph=graph, - update_shardings=update_shardings, + # Unflatten to tree_partial_args (which contains TreeState if graph=True) + tree_partial_args = jax.tree.unflatten(partial_treedef, flat_partial_list) + + # Convert TreeState back to Modules if graph=True, preserving Variable identity + if graph: + reconstructed_partial_args = extract.from_tree2( + tree_partial_args, recreate_variables=False + ) + else: + reconstructed_partial_args = tree_partial_args + + return fun(*reconstructed_partial_args, *args, **kwargs) + + return SimpleJitPartialWrapped( + wrapped_fun, + in_shardings=jit_in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + partial_args=flat_partial_args, + graph=graph, + update_shardings=update_shardings, ) @@ -1265,6 +1423,112 @@ def lower( ) -> SimpleLowered: lowered = self.traced.lower(lowering_platforms=lowering_platforms) return SimpleLowered(lowered, self.jit_wrapped) + + +@dataclasses.dataclass(frozen=True, slots=True) +class SimplePartialCompiled(Stage): + compiled: jax.stages.Compiled + jit_wrapped: SimpleJitPartialWrapped + + @property + def _inner_obj(self): + return self.compiled + + @property + def args_info(self) -> tp.Any: + raise self.compiled.args_info + + @staticmethod + def call(*args, **kwargs): + raise NotImplementedError + + def __call__(self, *args, **kwargs): + out, updates = self.compiled(self.jit_wrapped.partial_args, *args, **kwargs) + _apply_raw_updates(self.jit_wrapped.partial_args, updates) + return self.jit_wrapped._maybe_from_tree(out) + + @property + def out_tree(self) -> jax.tree_util.PyTreeDef: + return self.compiled.out_tree + + def as_text(self) -> str | None: + return self.compiled.as_text() + + def cost_analysis(self) -> tp.Any | None: + return self.compiled.cost_analysis() + + def memory_analysis(self) -> tp.Any | None: + return self.compiled.memory_analysis() + + def runtime_executable(self) -> tp.Any | None: + return self.compiled.runtime_executable() + + @property + def input_shardings(self): + return self.compiled.input_shardings + + @property + def output_shardings(self): + return self.compiled.output_shardings + + @property + def input_layouts(self): + return self.compiled.input_formats + + +@dataclasses.dataclass(frozen=True, slots=True) +class SimplePartialLowered(Stage): + lowered: jax.stages.Lowered + jit_wrapped: SimpleJitPartialWrapped + + @property + def _inner_obj(self): + return self.lowered + + @property + def args_info(self) -> tp.Any: + return self.lowered.args_info + + @property + def out_tree(self): + return self.lowered.out_tree + + def compile( + self, compiler_options: jax.stages.CompilerOptions | None = None + ) -> SimplePartialCompiled: + compiled = self.lowered.compile(compiler_options) + return SimplePartialCompiled(compiled, self.jit_wrapped) + + def as_text( + self, dialect: str | None = None, *, debug_info: bool = False + ) -> str: + return self.lowered.as_text(dialect=dialect, debug_info=debug_info) + + def compiler_ir(self, dialect: str | None = None) -> tp.Any | None: + return self.lowered.compiler_ir(dialect=dialect) + + def cost_analysis(self) -> tp.Any | None: + return self.lowered.cost_analysis() + + +@dataclasses.dataclass(frozen=True, slots=True) +class SimplePartialTraced(Stage): + traced: jax.stages.Traced + jit_wrapped: SimpleJitPartialWrapped + + @property + def _inner_obj(self): + return self.traced + + @property + def out_info(self): + return self.traced.out_info + + def lower( + self, *, lowering_platforms: tuple[str, ...] | None = None + ) -> SimplePartialLowered: + lowered = self.traced.lower(lowering_platforms=lowering_platforms) + return SimplePartialLowered(lowered, self.jit_wrapped) # ------------------------------- # shard_map # ------------------------------- @@ -1382,7 +1646,7 @@ def shard_map( import jax import jax.numpy as jnp - from flax import nnx + from flax.nnx import src as nnx from jax.sharding import PartitionSpec as P mesh = jax.sharding.Mesh(jax.local_devices(), ('data',)) diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index a13516600..f49cfaf04 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -517,7 +517,8 @@ def vmap( _raise_bound_method_error('vmap') update_axes = extract.check_prefix(in_axes, 'in_axes', 'vmap', graph, graph_updates) - update_axes[0] = 0 # kwargs axes + if 0 not in update_axes: + update_axes.append(0) # kwargs axes extract.check_prefix(out_axes, 'out_axes', 'vmap', graph, graph_updates) if not (graph and graph_updates): @@ -771,7 +772,8 @@ def pmap( _raise_bound_method_error('pmap') update_axes = extract.check_prefix(in_axes, 'in_axes', 'pmap', graph, graph_updates) - update_axes[0] = 0 # kwargs axes + if 0 not in update_axes: + update_axes.append(0) # kwargs axes extract.check_prefix(out_axes, 'out_axes', 'pmap', graph, graph_updates) if not (graph and graph_updates): @@ -1512,7 +1514,7 @@ def scan( Example:: import jax - from flax import nnx + from flax.nnx import src as nnx class Block(nnx.Module): def __init__(self, input_dim, features, *, rngs): @@ -1662,12 +1664,11 @@ def _simple_scan( f, f_unbound, *, graph, in_axes, out_axes, length, reverse, unroll, _split_transpose, - updates_axes: extract.OrderedDict, + updates_axes: list[tp.Any], ): _validate_scan_axes(in_axes, out_axes) # None and Carry aren't valid update axes - updates_axes.pop(None, None) - updates_axes.pop(Carry, None) + updates_axes = [ax for ax in updates_axes if ax is not None and ax is not Carry] out_is_tuple = isinstance(out_axes, tuple) was_carry = in_axes is Carry diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index a1795caf0..b1d75bc73 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -1087,7 +1087,7 @@ def test_jit_partial_no_partial_args(self, graph_mode): y = f_partial(jnp.array(3.0)) np.testing.assert_allclose(y, 6.0) - @parameterized.parameters((False,)) + @parameterized.parameters(True, False) def test_jit_partial_in_shardings_none_broadcast(self, graph_mode): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) @@ -1103,7 +1103,7 @@ def f(m, x): y = f_jit(x) self.assertEqual(y.shape, (n_devices, 3)) - @parameterized.parameters((False,)) + @parameterized.parameters(True, False) def test_jit_partial_in_shardings_named(self, graph_mode): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) @@ -1124,7 +1124,7 @@ def f(v, x): y = f_jit(x) self.assertEqual(y.shape, (n_devices, 4)) - @parameterized.parameters((False,)) + @parameterized.parameters(True, False) def test_jit_partial_mixed_shardings(self, graph_mode): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) @@ -1196,12 +1196,13 @@ def f(c1, c2, x): c1.v[...] += x return c1.v[...] + c2.v[...] - f_jit = nnx.jit_partial(f, c1, c2, graph=graph, graph_updates=False) if not graph: with self.assertRaisesRegex(ValueError, 'Duplicate Param'): - f_jit(jnp.array(1.0)) + nnx.jit_partial(f, c1, c2, graph=graph, graph_updates=False) return + f_jit = nnx.jit_partial(f, c1, c2, graph=graph, graph_updates=False) + y = f_jit(jnp.array(1.0)) np.testing.assert_allclose(y, 4.0) np.testing.assert_allclose(v[...], 2.0)