Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions docs_nnx/guides/flax_gspmd.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,8 @@
"\n",
"class MultiDotReluDot(nnx.Module):\n",
" def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs):\n",
" # Annotate the additional axis with sharding=None, meaning it will be\n",
" # replicated across all devices.\n",
" @nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None})\n",
" # The additional axis is unsharded by default.\n",
" @nnx.vmap\n",
" def create_sublayers(r):\n",
" return DotReluDot(depth, r)\n",
" self.layers = create_sublayers(rngs.fork(split=num_layers))\n",
Expand Down Expand Up @@ -428,7 +427,7 @@
"\n",
"class LogicalMultiDotReluDot(nnx.Module):\n",
" def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs):\n",
" @nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None})\n",
" @nnx.vmap\n",
" def create_sublayers(r):\n",
" return LogicalDotReluDot(depth, r)\n",
" self.layers = create_sublayers(rngs.fork(split=num_layers))\n",
Expand Down
7 changes: 3 additions & 4 deletions docs_nnx/guides/flax_gspmd.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,8 @@ class DotReluDot(nnx.Module):

class MultiDotReluDot(nnx.Module):
def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs):
# Annotate the additional axis with sharding=None, meaning it will be
# replicated across all devices.
@nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None})
# The additional axis is unsharded by default.
@nnx.vmap
def create_sublayers(r):
return DotReluDot(depth, r)
self.layers = create_sublayers(rngs.fork(split=num_layers))
Expand Down Expand Up @@ -276,7 +275,7 @@ class LogicalDotReluDot(nnx.Module):

class LogicalMultiDotReluDot(nnx.Module):
def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs):
@nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None})
@nnx.vmap
def create_sublayers(r):
return LogicalDotReluDot(depth, r)
self.layers = create_sublayers(rngs.fork(split=num_layers))
Expand Down
18 changes: 14 additions & 4 deletions docs_nnx/guides/transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,15 @@
"\n",
"Flax NNX `Variable`s can hold arbitrary metadata, which can be added by simply passing it as keyword arguments to its constructor. This is often used to store `sharding` information, as used by the `nnx.spmd` APIs (like `nnx.get_partition_spec` and `nnx.get_named_sharding`).\n",
"\n",
"However, it is often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved. For example, if you vectorize a variable on axis `1`, you should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes are temporarily removed.\n",
"However, it is often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved. For example, if you vectorize a variable on axis `1`, you should remove the `out_sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes are temporarily removed.\n",
"\n",
"To achieve this, Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument. And when the `nnx.PARTITION_NAME` key is present, the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`.\n",
"In graph mode with graph updates enabled, Flax NNX transforms automatically\n",
"keep `out_sharding` metadata aligned with `in_axes` and `out_axes`. By default,\n",
"a transform-added axis is annotated with `None`, meaning that it is unsharded.\n",
"You can use the non-standard `transform_metadata` dictionary argument with the\n",
"`nnx.PARTITION_NAME` key to give the transformed axis an explicit logical name\n",
"instead. Other tuple-valued metadata can be transformed by adding it to the\n",
"same dictionary.\n",
"\n",
"Let's see an example of this in action:"
]
Expand Down Expand Up @@ -846,9 +852,13 @@
"id": "a23bda09",
"metadata": {},
"source": [
"Here, you added a `sharding` metadata to the `nnx.Param` variables, and used `transform_metadata` to update the `sharding` metadata to reflect the axis changes. Specifically, you can see that the first axis `b` was removed from the `sharding` metadata when inside of `nnx.vmap`, and then added back when outside of `nnx.vmap`.\n",
"Here, you added `out_sharding` metadata to the `nnx.Param` variables and used\n",
"`transform_metadata` to explicitly name the transformed axis `b`. Specifically,\n",
"you can see that `b` was removed from `out_sharding` when inside `nnx.vmap`,\n",
"and then added back when outside `nnx.vmap`. If `transform_metadata` were\n",
"omitted, the same axis would be represented by `None`.\n",
"\n",
"You can verify that this also works when `nnx.Module`s are created inside the transformation - the new `sharding` axes will be added to the `nnx.Module` `nnx.Variable`s outside the transformation, matching the axes of the transformed `nnx.Variable`s."
"You can verify that this also works when `nnx.Module`s are created inside the transformation - the new `out_sharding` axes will be added to the `nnx.Module` `nnx.Variable`s outside the transformation, matching the axes of the transformed `nnx.Variable`s."
]
},
{
Expand Down
18 changes: 14 additions & 4 deletions docs_nnx/guides/transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,15 @@ except ValueError as e:

Flax NNX `Variable`s can hold arbitrary metadata, which can be added by simply passing it as keyword arguments to its constructor. This is often used to store `sharding` information, as used by the `nnx.spmd` APIs (like `nnx.get_partition_spec` and `nnx.get_named_sharding`).

However, it is often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved. For example, if you vectorize a variable on axis `1`, you should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes are temporarily removed.
However, it is often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved. For example, if you vectorize a variable on axis `1`, you should remove the `out_sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes are temporarily removed.

To achieve this, Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument. And when the `nnx.PARTITION_NAME` key is present, the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`.
In graph mode with graph updates enabled, Flax NNX transforms automatically
keep `out_sharding` metadata aligned with `in_axes` and `out_axes`. By default,
a transform-added axis is annotated with `None`, meaning that it is unsharded.
You can use the non-standard `transform_metadata` dictionary argument with the
`nnx.PARTITION_NAME` key to give the transformed axis an explicit logical name
instead. Other tuple-valued metadata can be transformed by adding it to the
same dictionary.

Let's see an example of this in action:

Expand All @@ -407,9 +413,13 @@ print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.out_sharding = }')
```

Here, you added a `sharding` metadata to the `nnx.Param` variables, and used `transform_metadata` to update the `sharding` metadata to reflect the axis changes. Specifically, you can see that the first axis `b` was removed from the `sharding` metadata when inside of `nnx.vmap`, and then added back when outside of `nnx.vmap`.
Here, you added `out_sharding` metadata to the `nnx.Param` variables and used
`transform_metadata` to explicitly name the transformed axis `b`. Specifically,
you can see that `b` was removed from `out_sharding` when inside `nnx.vmap`,
and then added back when outside `nnx.vmap`. If `transform_metadata` were
omitted, the same axis would be represented by `None`.

You can verify that this also works when `nnx.Module`s are created inside the transformation - the new `sharding` axes will be added to the `nnx.Module` `nnx.Variable`s outside the transformation, matching the axes of the transformed `nnx.Variable`s.
You can verify that this also works when `nnx.Module`s are created inside the transformation - the new `out_sharding` axes will be added to the `nnx.Module` `nnx.Variable`s outside the transformation, matching the axes of the transformed `nnx.Variable`s.

```{code-cell} ipython3
@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})
Expand Down
2 changes: 2 additions & 0 deletions flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def add_axis(tree: A, index: int, transform_metadata: tp.Mapping) -> A:

def insert_field(fields, index, value):
iterable = list(fields)
if index < 0:
index += len(iterable) + 1
while len(iterable) < index:
iterable.append(None)
iterable.insert(index, value)
Expand Down
21 changes: 20 additions & 1 deletion flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@
Index = int


def _default_transform_metadata(
transform_metadata: tp.Mapping[str, tp.Any],
) -> tp.Mapping[str, tp.Any]:
if spmd.PARTITION_NAME in transform_metadata:
return transform_metadata
return FrozenDict({spmd.PARTITION_NAME: None, **transform_metadata})


class Carry:
"""Helper class for :func:`flax.nnx.scan` function to mark input and output axis as carry.
"""
Expand Down Expand Up @@ -424,6 +432,9 @@ def vmap(
axis so that parallel collectives can be applied.
axis_size: Optional, an integer indicating the size of the axis to be
mapped. If not provided, the mapped axis size is inferred from arguments.
transform_metadata: Optional mapping of tuple-valued axis metadata. If
``nnx.PARTITION_NAME`` is omitted, the mapped axis is added to
``out_sharding`` as an unsharded axis when graph updates are enabled.
graph: If ``True`` (default), uses graph-mode which supports the full
NNX feature set including shared references and reference semantics.
If ``False``, uses tree-mode which treats Modules as regular JAX
Expand Down Expand Up @@ -495,6 +506,7 @@ def vmap(
[0, 2, 4, 6],
[0, 3, 6, 9]], dtype=int32)
"""
transform_metadata = _default_transform_metadata(transform_metadata)
if graph is None:
graph = graphlib.set_graph_mode.current_value()
if graph_updates is None:
Expand Down Expand Up @@ -733,7 +745,9 @@ def pmap(
result. You should not reuse buffers that you donate to a computation,
JAX will raise an error if you try to. Note that donate_argnums only
work for positional arguments, and keyword arguments will not be donated.
transform_metadata: Optional mapping of metadata for the transform.
transform_metadata: Optional mapping of tuple-valued axis metadata. If
``nnx.PARTITION_NAME`` is omitted, the mapped axis is added to
``out_sharding`` as an unsharded axis when graph updates are enabled.
graph: if True, use graph-mode (default). If False, use tree-mode.
If None, uses the value of ``nnx_graph_mode`` config.
graph_updates: If ``True``, propagates updates on graph structure
Expand All @@ -746,6 +760,7 @@ def pmap(
``f`` but with extra array axes at positions indicated by ``in_axes`` and
with output that has an additional leading array axis (with the same size).
"""
transform_metadata = _default_transform_metadata(transform_metadata)
if graph is None:
graph = graphlib.set_graph_mode.current_value()
if graph_updates is None:
Expand Down Expand Up @@ -1596,13 +1611,17 @@ def forward(x, model):
out_axes: integer, None, :class:`flax.nnx.Carry` or sequence of values specifying
the kind of output args. See ``in_axes`` for details. Note that If ``in_axes``
contains :class:`flax.nnx.Carry` then ``out_axes`` must also contain :class:`flax.nnx.Carry`.
transform_metadata: Optional mapping of tuple-valued axis metadata. If
``nnx.PARTITION_NAME`` is omitted, the scanned axis is added to
``out_sharding`` as an unsharded axis when graph updates are enabled.
graph_updates: If ``True``, propagates updates on graph structure
that happen inside the transform to the input graphs, has no
effect when ``graph=False``. When ``False``, using ``StateAxes``
is not supported.

.. _jax.lax.scan: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html>
"""
transform_metadata = _default_transform_metadata(transform_metadata)
if f is Missing:
return functools.partial(
scan,
Expand Down
157 changes: 157 additions & 0 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,163 @@ def test_get_abstract_no_sharding_metadata(self):
getattr(abs_model.kernel.get_value(), 'sharding', None)
)

def test_vmap_default_transform_metadata(self):
mesh = jax.make_mesh((2, 2), ('a', 'b'))
rules = (('A', 'a'), ('B', 'b'))

class Model(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
@nnx.split_rngs(splits=1)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_linear(rngs: nnx.Rngs):
return nnx.Param(
jnp.ones((4, 4)),
out_sharding=('A', 'B'),
mesh=mesh,
sharding_rules=rules,
)

self.w = create_linear(rngs=rngs)

@nnx.jit
def init():
model = Model(rngs=nnx.Rngs(params=0))
optimizer = nnx.Optimizer(
model, optax.adam(1e-3), wrt=nnx.Param
)
return model, optimizer

with jax.set_mesh(mesh):
model, optimizer = init()

self.assertEqual(model.w.shape, (1, 4, 4))
self.assertEqual(model.w.out_sharding, (None, 'A', 'B'))
self.assertEqual(
optimizer.opt_state[0].mu['w'].out_sharding, (None, 'A', 'B')
)
self.assertEqual(
optimizer.opt_state[0].nu['w'].out_sharding, (None, 'A', 'B')
)

@parameterized.parameters(None, 'layers')
def test_vmap_explicit_transform_metadata(self, partition_name):
@nnx.vmap(
in_axes=None,
out_axes=0,
axis_size=2,
transform_metadata={nnx.PARTITION_NAME: partition_name},
)
def create_param():
return nnx.Param(
jnp.ones((4, 4)),
out_sharding=('din', 'dout'),
eager_sharding=False,
)

param = create_param()

self.assertEqual(
param.out_sharding, (partition_name, 'din', 'dout')
)

def test_vmap_merges_default_with_other_transform_metadata(self):
@nnx.vmap(
in_axes=None,
out_axes=0,
axis_size=2,
transform_metadata={'nickname': 'batch'},
)
def create_param():
return nnx.Param(
jnp.ones((3, 4)),
out_sharding=('din', 'dout'),
nickname=('in', 'out'),
eager_sharding=False,
)

param = create_param()

self.assertEqual(param.out_sharding, (None, 'din', 'dout'))
self.assertEqual(param.nickname, ('batch', 'in', 'out'))

def test_vmap_with_partitioning_default_transform_metadata(self):
@nnx.vmap(in_axes=None, out_axes=0, axis_size=2)
def create_param():
return nnx.Param(
nnx.with_partitioning(
lambda: jnp.ones((3, 4)), ('din', 'dout')
)(),
eager_sharding=False,
)

param = create_param()

self.assertEqual(param.shape, (2, 3, 4))
self.assertEqual(param.out_sharding, (None, 'din', 'dout'))

def test_vmap_default_transform_metadata_negative_out_axis(self):
@nnx.vmap(in_axes=None, out_axes=-1, axis_size=2)
def create_param():
return nnx.Param(
jnp.ones((3, 4)),
out_sharding=('din', 'dout'),
eager_sharding=False,
)

param = create_param()

self.assertEqual(param[...].shape, (3, 4, 2))
self.assertEqual(param.out_sharding, ('din', 'dout', None))

def test_scan_default_transform_metadata(self):
@nnx.split_rngs(splits=3)
@nnx.scan(
in_axes=(nnx.Carry, 0),
out_axes=(nnx.Carry, 0),
length=3,
)
def create_param(_, rngs: nnx.Rngs):
return None, nnx.Param(
jnp.ones((4, 4)),
out_sharding=('din', 'dout'),
eager_sharding=False,
)

_, param = create_param(None, nnx.Rngs(0))

self.assertEqual(param.shape, (3, 4, 4))
self.assertEqual(param.out_sharding, (None, 'din', 'dout'))

def test_pmap_default_transform_metadata(self):
@nnx.pmap(
in_axes=0,
out_axes=0,
axis_size=1,
devices=jax.devices()[:1],
)
def create_param(_):
return nnx.Param(
jnp.ones((3, 4)),
out_sharding=('din', 'dout'),
eager_sharding=False,
)

param = create_param(jnp.zeros(1))

self.assertEqual(param.shape, (1, 3, 4))
self.assertEqual(param.out_sharding, (None, 'din', 'dout'))

def test_vmap_no_sharding_metadata_unaffected(self):
@nnx.vmap(in_axes=None, out_axes=0, axis_size=2)
def create_param():
return nnx.Param(jnp.ones((4, 4)))

param = create_param()

self.assertEqual(param.shape, (2, 4, 4))
self.assertFalse(param.has_metadata('out_sharding'))


def has_sharding_spec(array):
sharding = array.sharding
if hasattr(sharding, 'spec'):
Expand Down