Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions benchmarks/test_replaybuffer_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
SamplerWithoutReplacement,
SliceSampler,
)
from torchrl.envs.transforms import ActionChunkTransform, CatFrames

_TensorDictPrioritizedReplayBuffer = functools.partial(
TensorDictPrioritizedReplayBuffer, alpha=1, beta=0.9
Expand Down Expand Up @@ -449,6 +450,34 @@ def test_rb_extend_sample(
)


class TestWindowingTransformsBenchmark:
"""Offline (sample-path) sliding-window transforms: CatFrames.unfolding
and the ActionChunkTransform recipe built on top of it."""

@pytest.mark.parametrize("done_key", ["done", None], ids=["done_aware", "no_done"])
def test_action_chunk_transform(self, benchmark, done_key):
t = ActionChunkTransform(chunk_size=8, done_key=done_key)
td = TensorDict(
{
"action": torch.randn(64, 32, 7),
("next", "done"): torch.zeros(64, 32, 1, dtype=torch.bool),
},
batch_size=[64],
)
benchmark(t, td)

def test_catframes_offline(self, benchmark):
t = CatFrames(N=4, dim=-3, in_keys=["pixels"], out_keys=["pixels_cat"])
td = TensorDict(
{
"pixels": torch.randn(8, 32, 3, 32, 32),
("next", "done"): torch.zeros(8, 32, 1, dtype=torch.bool),
},
batch_size=[8, 32],
).refine_names(None, "time")
benchmark(t, td)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
3 changes: 3 additions & 0 deletions docs/source/reference/vla.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ pipeline) and they live alongside the other transforms, documented in full on th
- :class:`~torchrl.envs.transforms.ActionChunkTransform` -- build fixed-length
action chunks (``[*B, T, H, action_dim]``) and a padding mask from a sampled
trajectory window, the standard training target for chunked VLA policies.
Internally a thin recipe over a forward-looking
:class:`~torchrl.envs.transforms.CatFrames`; when the window carries its
done state, chunks stop (pad and mask) at the trajectory boundaries.
- :class:`~torchrl.envs.transforms.ActionScaling` -- affine action
normalization; built with the
:meth:`~torchrl.envs.transforms.ActionScaling.from_metadata` /
Expand Down
87 changes: 80 additions & 7 deletions test/transforms/test_action_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2515,14 +2515,87 @@ def test_trailing_dim_enforced(self):
with pytest.raises(ValueError, match="immediately follow"):
t(TensorDict({"action": torch.randn(2, 4, 3)}, batch_size=[2, 4]))

def test_compile_build_chunk(self):
@staticmethod
def _reference_gather(action, H):
# the pre-0.14 dedicated gather (arange + clamp + index_select), kept
# as the ground truth the CatFrames recipe must reproduce exactly
T = action.shape[-2]
idx = torch.arange(T).unsqueeze(-1) + torch.arange(H).unsqueeze(0)
is_pad = idx >= T
idx = idx.clamp_max(T - 1).reshape(-1)
chunk = action.index_select(-2, idx).unflatten(-2, (T, H))
is_pad = is_pad.expand(chunk.shape[:-1]).contiguous()
return chunk, is_pad

@pytest.mark.parametrize("batch_size", [(), (2,), (3, 2)])
@pytest.mark.parametrize("chunk_size", [1, 3, 7])
def test_equivalence_with_reference_gather(self, batch_size, chunk_size):
torch.manual_seed(0)
action = torch.randn(*batch_size, 5, 4)
t = ActionChunkTransform(chunk_size=chunk_size)
out = t(TensorDict({"action": action}, batch_size=batch_size))
ref_chunk, ref_pad = self._reference_gather(action, chunk_size)
# byte-identical: the recipe copies the same source elements
assert torch.equal(out["action_chunk"], ref_chunk)
assert torch.equal(out["action_is_pad"], ref_pad)

def test_done_aware_chunking(self):
# a done inside the window: chunks must not cross the trajectory
# boundary (steps past it are padded with the last valid action)
action = torch.arange(4.0).view(1, 4, 1)
done = torch.zeros(1, 4, 1, dtype=torch.bool)
done[0, 1] = True # boundary between steps 1 and 2
td = TensorDict({"action": action, ("next", "done"): done}, batch_size=[1, 4])
out = ActionChunkTransform(3)(td)
expected_chunk = torch.tensor(
[[0, 1, 1], [1, 1, 1], [2, 3, 3], [3, 3, 3]]
).float()
torch.testing.assert_close(out["action_chunk"][0, :, :, 0], expected_chunk)
expected_pad = torch.tensor([[0, 0, 1], [0, 1, 1], [0, 0, 1], [0, 1, 1]]).bool()
assert torch.equal(out["action_is_pad"][0], expected_pad)

def test_done_key_none_ignores_dones(self):
action = torch.arange(4.0).view(1, 4, 1)
done = torch.zeros(1, 4, 1, dtype=torch.bool)
done[0, 1] = True
td = TensorDict({"action": action, ("next", "done"): done}, batch_size=[1, 4])
out = ActionChunkTransform(3, done_key=None)(td)
# the boundary is ignored: the chunk at t=0 reads across the done
expected_chunk = torch.tensor(
[[0, 1, 2], [1, 2, 3], [2, 3, 3], [3, 3, 3]]
).float()
torch.testing.assert_close(out["action_chunk"][0, :, :, 0], expected_chunk)

def test_done_shape_mismatch_raises(self):
td = TensorDict(
{
"action": torch.randn(2, 4, 3),
("next", "done"): torch.zeros(2, 5, 1, dtype=torch.bool),
},
batch_size=[2],
)
with pytest.raises(ValueError, match="does not line up"):
ActionChunkTransform(2)(td)

def test_clone_keeps_recipe(self):
t = ActionChunkTransform(
chunk_size=3, action_key=("data", "action"), done_key=None
).clone()
assert isinstance(t, ActionChunkTransform)
assert t.chunk_size == 3
assert t.action_key == ("data", "action")
assert t.done_key is None
td = TensorDict({"data": {"action": torch.randn(2, 4, 3)}}, batch_size=[2, 4])
assert t(td)["action_chunk"].shape == torch.Size([2, 4, 3, 3])

def test_compile(self):
t = ActionChunkTransform(chunk_size=3)
action = torch.randn(2, 5, 2)
eager_chunk, eager_pad = t._build_chunk(action)
compiled = torch.compile(t._build_chunk, fullgraph=True)
c_chunk, c_pad = compiled(action)
torch.testing.assert_close(c_chunk, eager_chunk)
assert torch.equal(c_pad, eager_pad)
td = TensorDict({"action": torch.randn(2, 5, 2)}, batch_size=[2, 5])
eager = t(td.clone())
compiled = torch.compile(t)
out = compiled(td.clone())
torch.testing.assert_close(out["action_chunk"], eager["action_chunk"])
assert torch.equal(out["action_is_pad"], eager["action_is_pad"])


class TestActionTokenizerTransform(TransformBase):
Expand Down
147 changes: 147 additions & 0 deletions test/transforms/test_observation_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,153 @@ def test_constant_padding(self, padding_value):
cat_td = cat_frames._call(cat_td)
assert (cat_td.get("cat_first_key") == padding_value).sum() == N - 4

def test_unfolding_n_larger_than_t(self):
# history windows longer than the sampled trajectory: the leading
# no-reset block used to be built by slicing the done entry (capped
# at the time length) and crashed for N > T
t = CatFrames(N=6, dim=-2, in_keys=["obs"], out_keys=["obs_cat"])
obs = torch.arange(3.0).view(1, 3, 1, 1)
done = torch.zeros(1, 3, 1, dtype=torch.bool)
td = TensorDict(
{"obs": obs, ("next", "done"): done}, batch_size=[1, 3]
).refine_names(None, "time")
out = t(td)
assert out["obs_cat"].shape == torch.Size([1, 3, 6, 1])
# padding="same": the first frame fills the missing history
torch.testing.assert_close(out["obs_cat"][0, 0, :, 0], torch.zeros(6))
torch.testing.assert_close(
out["obs_cat"][0, 2, :, 0],
torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0, 2.0]),
)

def test_unfolding_missing_done_raises(self):
t = CatFrames(N=3, dim=-2, in_keys=["obs"], out_keys=["obs_cat"])
td = TensorDict(
{"obs": torch.randn(1, 4, 1, 1)}, batch_size=[1, 4]
).refine_names(None, "time")
with pytest.raises(KeyError, match="delimit"):
t(td)

def test_future_windows(self):
t = CatFrames(
N=3,
dim=-2,
in_keys=["obs"],
out_keys=["obs_cat"],
future=True,
mask_key="mask",
)
obs = torch.arange(4.0).view(1, 4, 1, 1)
# the done entry is optional with forward windows: each row is one
# contiguous trajectory
td = TensorDict({"obs": obs}, batch_size=[1, 4]).refine_names(None, "time")
out = t(td.clone())
expected = torch.tensor(
[[0.0, 1.0, 2.0], [1.0, 2.0, 3.0], [2.0, 3.0, 3.0], [3.0, 3.0, 3.0]]
)
torch.testing.assert_close(out["obs_cat"][0, :, :, 0], expected)
expected_mask = torch.tensor(
[[0, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 1]]
).bool()
assert torch.equal(out["mask"][0], expected_mask)

def test_future_windows_with_done(self):
t = CatFrames(
N=3,
dim=-2,
in_keys=["obs"],
out_keys=["obs_cat"],
future=True,
mask_key="mask",
)
obs = torch.arange(4.0).view(1, 4, 1, 1)
done = torch.zeros(1, 4, 1, dtype=torch.bool)
done[0, 1] = True # boundary between steps 1 and 2
td = TensorDict(
{"obs": obs, ("next", "done"): done}, batch_size=[1, 4]
).refine_names(None, "time")
out = t(td.clone())
expected = torch.tensor(
[[0.0, 1.0, 1.0], [1.0, 1.0, 1.0], [2.0, 3.0, 3.0], [3.0, 3.0, 3.0]]
)
torch.testing.assert_close(out["obs_cat"][0, :, :, 0], expected)
expected_mask = torch.tensor(
[[0, 0, 1], [0, 1, 1], [0, 0, 1], [0, 1, 1]]
).bool()
assert torch.equal(out["mask"][0], expected_mask)

def test_future_windows_constant_padding(self):
t = CatFrames(
N=3,
dim=-2,
in_keys=["obs"],
out_keys=["obs_cat"],
future=True,
padding="constant",
padding_value=-1.0,
)
obs = torch.arange(4.0).view(1, 4, 1, 1)
td = TensorDict({"obs": obs}, batch_size=[1, 4]).refine_names(None, "time")
out = t(td.clone())
expected = torch.tensor(
[[0.0, 1.0, 2.0], [1.0, 2.0, 3.0], [2.0, 3.0, -1.0], [3.0, -1.0, -1.0]]
)
torch.testing.assert_close(out["obs_cat"][0, :, :, 0], expected)

def test_mask_key_history_windows(self):
# the mask is also available for plain (backward) history windows
t = CatFrames(
N=3, dim=-2, in_keys=["obs"], out_keys=["obs_cat"], mask_key="mask"
)
obs = torch.arange(4.0).view(1, 4, 1, 1)
done = torch.zeros(1, 4, 1, dtype=torch.bool)
done[0, 1] = True
td = TensorDict(
{"obs": obs, ("next", "done"): done}, batch_size=[1, 4]
).refine_names(None, "time")
out = t(td.clone())
# window slots read oldest-to-newest; True = fabricated by padding
expected_mask = torch.tensor(
[[1, 1, 0], [1, 0, 0], [1, 1, 0], [1, 0, 0]]
).bool()
assert torch.equal(out["mask"][0], expected_mask)

def test_future_env_raises(self):
env = TransformedEnv(
ContinuousActionVecMockEnv(),
CatFrames(N=2, dim=-1, in_keys=["observation"], future=True),
)
with pytest.raises(RuntimeError, match="offline"):
env.reset()

def test_mask_key_env_raises(self):
env = TransformedEnv(
ContinuousActionVecMockEnv(),
CatFrames(N=2, dim=-1, in_keys=["observation"], mask_key="mask"),
)
with pytest.raises(RuntimeError, match="offline"):
env.reset()

def test_future_next_key_overlap_raises(self):
t = CatFrames(
N=2,
dim=-1,
in_keys=["obs", ("next", "obs")],
out_keys=["obs_cat", ("next", "obs_cat")],
future=True,
)
obs = torch.randn(1, 4, 2)
td = TensorDict(
{
"obs": obs,
("next", "obs"): obs,
("next", "done"): torch.zeros(1, 4, 1, dtype=torch.bool),
},
batch_size=[1, 4],
).refine_names(None, "time")
with pytest.raises(NotImplementedError, match="history"):
t(td)

@staticmethod
def _unfold_done(done, N, ndim):
# Mirror of ``CatFrames.unfolding.unfold_done`` (window-padding mask)
Expand Down
Loading
Loading