Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs_nnx/examples/image_segmentation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1632,8 +1632,8 @@
" self.count = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))\n",
"\n",
" def reset(self):\n",
" self.confusion_matrix.value = jnp.zeros((self.num_classes, self.num_classes), dtype=jnp.int32)\n",
" self.count.value = jnp.array(0, dtype=jnp.int32)\n",
" self.confusion_matrix[...] = jnp.zeros((self.num_classes, self.num_classes), dtype=jnp.int32)\n",
" self.count[...] = jnp.array(0, dtype=jnp.int32)\n",
"\n",
" def _check_shape(self, y_pred: jax.Array, y: jax.Array):\n",
" if y_pred.shape[-1] != self.num_classes:\n",
Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/examples/image_segmentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -1147,8 +1147,8 @@ class ConfusionMatrix(nnx.Metric):
self.count = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

def reset(self):
self.confusion_matrix.value = jnp.zeros((self.num_classes, self.num_classes), dtype=jnp.int32)
self.count.value = jnp.array(0, dtype=jnp.int32)
self.confusion_matrix[...] = jnp.zeros((self.num_classes, self.num_classes), dtype=jnp.int32)
self.count[...] = jnp.array(0, dtype=jnp.int32)

def _check_shape(self, y_pred: jax.Array, y: jax.Array):
if y_pred.shape[-1] != self.num_classes:
Expand Down
10 changes: 5 additions & 5 deletions docs_nnx/guides/bridge_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
"nnx.display(model)\n",
"\n",
"# In-place swap your weight array and the model still works!\n",
"model.w.value = jax.random.normal(jax.random.key(1), (32, 64))\n",
"model.w[...] = jax.random.normal(jax.random.key(1), (32, 64))\n",
"assert not jnp.allclose(y, model(x))"
]
},
Expand Down Expand Up @@ -199,7 +199,7 @@
"outputs": [],
"source": [
"assert isinstance(model.dot.w, nnx.Param)\n",
"assert isinstance(model.dot.w.value, jax.Array)"
"assert isinstance(model.dot.w[...], jax.Array)"
]
},
{
Expand Down Expand Up @@ -606,8 +606,8 @@
" self.count = Count(jnp.array(0))\n",
"\n",
" def __call__(self, x):\n",
" self.count.value += 1\n",
" return (x @ self.w.value) + self.lora(x)\n",
" self.count[...] += 1\n",
" return (x @ self.w[...]) + self.lora(x)\n",
"\n",
"xkey, pkey, dkey = jax.random.split(jax.random.key(0), 3)\n",
"x = jax.random.normal(xkey, (2, 4))\n",
Expand Down Expand Up @@ -694,7 +694,7 @@
"\n",
"print(type(model.w)) # `nnx.Param`\n",
"print(model.w.sharding) # The partition annotation attached with `w`\n",
"print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh"
"print(model.w[...].sharding) # The underlying JAX array is sharded across the 2x4 mesh"
]
},
{
Expand Down
10 changes: 5 additions & 5 deletions docs_nnx/guides/bridge_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ y = model(x) # => `y = model.apply(var, x)` in Linen
nnx.display(model)

# In-place swap your weight array and the model still works!
model.w.value = jax.random.normal(jax.random.key(1), (32, 64))
model.w[...] = jax.random.normal(jax.random.key(1), (32, 64))
assert not jnp.allclose(y, model(x))
```

Expand All @@ -116,7 +116,7 @@ We will talk more about different collections and types in the [NNX Variable <->

```{code-cell} ipython3
assert isinstance(model.dot.w, nnx.Param)
assert isinstance(model.dot.w.value, jax.Array)
assert isinstance(model.dot.w[...], jax.Array)
```

If you create this model without using `nnx.bridge.lazy_init`, the NNX variables defined outside will be initialized as usual, but the Linen part (wrapped inside `ToNNX`) will not.
Expand Down Expand Up @@ -311,8 +311,8 @@ class NNXMultiCollections(nnx.Module):
self.count = Count(jnp.array(0))

def __call__(self, x):
self.count.value += 1
return (x @ self.w.value) + self.lora(x)
self.count[...] += 1
return (x @ self.w[...]) + self.lora(x)

xkey, pkey, dkey = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(xkey, (2, 4))
Expand Down Expand Up @@ -375,7 +375,7 @@ with jax.set_mesh(mesh):

print(type(model.w)) # `nnx.Param`
print(model.w.sharding) # The partition annotation attached with `w`
print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh
print(model.w[...].sharding) # The underlying JAX array is sharded across the 2x4 mesh
```

We have 8 fake JAX devices now to partition this model...
Expand Down
10 changes: 5 additions & 5 deletions docs_nnx/guides/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
" self.count = Count(0) # stateful variables are defined as attributes\n",
"\n",
" def __call__(self, x):\n",
" self.count.value += 1 # in-place stateful updates\n",
" self.count[...] += 1 # in-place stateful updates\n",
" for block in self.blocks:\n",
" x = block(x)\n",
" return x\n",
Expand Down Expand Up @@ -245,7 +245,7 @@
"output_type": "stream",
"text": [
"y.shape = (2, 4)\n",
"model.count.value = Array(3, dtype=int32, weak_type=True)\n"
"model.count[...] = Array(3, dtype=int32, weak_type=True)\n"
]
}
],
Expand All @@ -265,7 +265,7 @@
"model.update(state)\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{model.count.value = }')"
"print(f'{model.count[...] = }')"
]
},
{
Expand Down Expand Up @@ -313,7 +313,7 @@
"output_type": "stream",
"text": [
"y.shape = (2, 4)\n",
"parent.model.count.value = Array(4, dtype=int32, weak_type=True)\n"
"parent.model.count[...] = Array(4, dtype=int32, weak_type=True)\n"
]
}
],
Expand Down Expand Up @@ -342,7 +342,7 @@
"y = parent(jnp.ones((2, 4)))\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{parent.model.count.value = }')"
"print(f'{parent.model.count[...] = }')"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions docs_nnx/guides/demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class MLP(nnx.Module):
self.count = Count(0) # stateful variables are defined as attributes

def __call__(self, x):
self.count.value += 1 # in-place stateful updates
self.count[...] += 1 # in-place stateful updates
for block in self.blocks:
x = block(x)
return x
Expand Down Expand Up @@ -113,7 +113,7 @@ y, state = forward(graphdef,state, x)
model.update(state)

print(f'{y.shape = }')
print(f'{model.count.value = }')
print(f'{model.count[...] = }')
```

```{code-cell} ipython3
Expand Down Expand Up @@ -160,7 +160,7 @@ parent = Parent(model)
y = parent(jnp.ones((2, 4)))

print(f'{y.shape = }')
print(f'{parent.model.count.value = }')
print(f'{parent.model.count[...] = }')
```

```{code-cell} ipython3
Expand Down
10 changes: 5 additions & 5 deletions docs_nnx/guides/surgery.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"When you fill every `nnx.Variable` pytree leaf's `value` attributes with real `jax.Array`s, the abstract model becomes equivalent to a real model."
"When you fill every `nnx.Variable` pytree leaf with real `jax.Array`s, the abstract model becomes equivalent to a real model."
]
},
{
Expand All @@ -172,10 +172,10 @@
"outputs": [],
"source": [
"model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
"abs_state['linear1']['kernel'].value = model.linear1.kernel.value\n",
"abs_state['linear1']['bias'].value = model.linear1.bias.value\n",
"abs_state['linear2']['kernel'].value = model.linear2.kernel.value\n",
"abs_state['linear2']['bias'].value = model.linear2.bias.value\n",
"abs_state['linear1']['kernel'][...] = model.linear1.kernel[...]\n",
"abs_state['linear1']['bias'][...] = model.linear1.bias[...]\n",
"abs_state['linear2']['kernel'][...] = model.linear2.kernel[...]\n",
"abs_state['linear2']['bias'][...] = model.linear2.bias[...]\n",
"nnx.update(abs_model, abs_state)\n",
"np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now!"
]
Expand Down
10 changes: 5 additions & 5 deletions docs_nnx/guides/surgery.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ gdef, abs_state = nnx.split(abs_model)
pprint(abs_state)
```

When you fill every `nnx.Variable` pytree leaf's `value` attributes with real `jax.Array`s, the abstract model becomes equivalent to a real model.
When you fill every `nnx.Variable` pytree leaf with real `jax.Array`s, the abstract model becomes equivalent to a real model.

```{code-cell} ipython3
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
abs_state['linear1']['kernel'].value = model.linear1.kernel.value
abs_state['linear1']['bias'].value = model.linear1.bias.value
abs_state['linear2']['kernel'].value = model.linear2.kernel.value
abs_state['linear2']['bias'].value = model.linear2.bias.value
abs_state['linear1']['kernel'][...] = model.linear1.kernel[...]
abs_state['linear1']['bias'][...] = model.linear1.bias[...]
abs_state['linear2']['kernel'][...] = model.linear2.kernel[...]
abs_state['linear2']['bias'][...] = model.linear2.bias[...]
nnx.update(abs_model, abs_state)
np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now!
```
Expand Down
14 changes: 7 additions & 7 deletions docs_nnx/key_concepts.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@
"\n",
"# Flax created a `Param` wrapper over the actual `jax.Array` parameter to track metadata\n",
"print(type(linear.kernel)) # flax.nnx.Param\n",
"print(type(linear.kernel.value)) # jax.Array\n",
"print(type(linear.kernel[...])) # jax.Array\n",
"\n",
"# The computation of the two are the same\n",
"x = jax.random.normal(jax.random.key(0), (2, 4))\n",
"flax_y = linear(x)\n",
"jax_y = jax_linear(x, linear.kernel.value, linear.bias.value)\n",
"jax_y = jax_linear(x, linear.kernel[...], linear.bias[...])\n",
"assert jnp.array_equal(flax_y, jax_y)"
]
},
Expand Down Expand Up @@ -245,10 +245,10 @@
"output_type": "stream",
"text": [
"model.dim = 4\n",
"model.traced_dim.value = JitTracer<~int32[]>\n",
"model.traced_dim.get_value() = JitTracer<~int32[]>\n",
"Code path based on static data value works fine.\n",
"Code path based on JAX data value throws error: Attempted boolean conversion of traced array with shape bool[].\n",
"The error occurred while tracing the function jitted at /var/folders/4c/ylxxyg_n67957jf6616c7z5000gbn1/T/ipykernel_69242/584946237.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument model.traced_dim.value.\n",
"The error occurred while tracing the function jitted at /var/folders/4c/ylxxyg_n67957jf6616c7z5000gbn1/T/ipykernel_69242/584946237.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument model.traced_dim.get_value().\n",
"See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n"
]
}
Expand All @@ -257,11 +257,11 @@
"@jax.jit\n",
"def jitted(model):\n",
" print(f'{model.dim = }')\n",
" print(f'{model.traced_dim.value = }') # This is being traced\n",
" print(f'{model.traced_dim.get_value() = }') # This is being traced\n",
" if model.dim == 4:\n",
" print('Code path based on static data value works fine.')\n",
" try:\n",
" if model.traced_dim.value == 4:\n",
" if model.traced_dim.get_value() == 4:\n",
" print('This will never run :(')\n",
" except jax.errors.TracerBoolConversionError as e:\n",
" print(f'Code path based on JAX data value throws error: {e}')\n",
Expand Down Expand Up @@ -455,7 +455,7 @@
" model = MLP(dim, nlayers, nnx.Rngs(0))\n",
" return jax.lax.with_sharding_constraint(model, model_shardings)\n",
"model = sharded_init(dim, nlayers)\n",
"jax.debug.visualize_array_sharding(model.blocks[0].kernel.value)"
"jax.debug.visualize_array_sharding(model.blocks[0].kernel[...])"
]
},
{
Expand Down
14 changes: 7 additions & 7 deletions docs_nnx/key_concepts.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ linear = nnx.Linear(in_features=4, out_features=2, rngs=nnx.Rngs(42))

# Flax created a `Param` wrapper over the actual `jax.Array` parameter to track metadata
print(type(linear.kernel)) # flax.nnx.Param
print(type(linear.kernel.value)) # jax.Array
print(type(linear.kernel[...])) # jax.Array

# The computation of the two are the same
x = jax.random.normal(jax.random.key(0), (2, 4))
flax_y = linear(x)
jax_y = jax_linear(x, linear.kernel.value, linear.bias.value)
jax_y = jax_linear(x, linear.kernel[...], linear.bias[...])
assert jnp.array_equal(flax_y, jax_y)
```

Expand Down Expand Up @@ -156,11 +156,11 @@ When compiling a function using this pytree, you'll notice the difference betwee
@jax.jit
def jitted(model):
print(f'{model.dim = }')
print(f'{model.traced_dim.value = }') # This is being traced
print(f'{model.traced_dim.get_value() = }') # This is being traced
if model.dim == 4:
print('Code path based on static data value works fine.')
try:
if model.traced_dim.value == 4:
if model.traced_dim.get_value() == 4:
print('This will never run :(')
except jax.errors.TracerBoolConversionError as e:
print(f'Code path based on JAX data value throws error: {e}')
Expand All @@ -169,10 +169,10 @@ jitted(foo)
```

model.dim = 4
model.traced_dim.value = JitTracer<~int32[]>
model.traced_dim.get_value() = JitTracer<~int32[]>
Code path based on static data value works fine.
Code path based on JAX data value throws error: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function jitted at /var/folders/4c/ylxxyg_n67957jf6616c7z5000gbn1/T/ipykernel_69242/584946237.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument model.traced_dim.value.
The error occurred while tracing the function jitted at /var/folders/4c/ylxxyg_n67957jf6616c7z5000gbn1/T/ipykernel_69242/584946237.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument model.traced_dim.get_value().
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError


Expand Down Expand Up @@ -270,7 +270,7 @@ def sharded_init(dim, nlayers):
model = MLP(dim, nlayers, nnx.Rngs(0))
return jax.lax.with_sharding_constraint(model, model_shardings)
model = sharded_init(dim, nlayers)
jax.debug.visualize_array_sharding(model.blocks[0].kernel.value)
jax.debug.visualize_array_sharding(model.blocks[0].kernel[...])
```

Param(
Expand Down
9 changes: 4 additions & 5 deletions docs_nnx/migrating/haiku_to_flax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ If you want to access Flax model parameters in the stateless, dictionary-like fa

# Parameters were already initialized during model instantiation.

assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
assert model.linear.bias[...].shape == (10,)
assert model.block.linear.kernel[...].shape == (784, 256)

Training step and compilation
=============================
Expand Down Expand Up @@ -692,9 +692,9 @@ be set and accessed as normal using regular Python class semantics.
nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
)
def __call__(self, x):
output = x + self.multiplier * self.counter.value
output = x + self.multiplier * self.counter[...]

self.counter.value += 1
self.counter[...] += 1
return output

model = FooModule(rngs=nnx.Rngs(0))
Expand All @@ -703,4 +703,3 @@ be set and accessed as normal using regular Python class semantics.




6 changes: 3 additions & 3 deletions docs_nnx/migrating/linen_to_nnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ If you want to access Flax NNX model parameters in the stateless, dictionary-lik

# Parameters were already initialized during model instantiation.

assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
assert model.linear.bias[...].shape == (10,)
assert model.block.linear.kernel[...].shape == (784, 256)


Training step and compilation
Expand Down Expand Up @@ -255,7 +255,7 @@ For all the built-in Flax Linen layers and collections, Flax NNX already creates
def __call__(self, x):
x = self.linear(x)
x = self.batchnorm(x)
self.count.value += 1
self.count[...] += 1
x = jax.nn.relu(x)
return x

Expand Down
Loading