From 3cb1fce2ede6b7970738e276ed178b375fcd5a3c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 11 Jun 2026 16:04:40 +0100 Subject: [PATCH] [Refactor] ActionChunkTransform as a CatFrames recipe Eliminate the duplicated sliding-window implementation between ActionChunkTransform and CatFrames by making ActionChunkTransform a thin Compose recipe over CatFrames (the R3MTransform pattern): an UnsqueezeTransform opens the chunk dim and a forward-looking CatFrames does the windowing. The public API and outputs are unchanged (pinned by a byte-identical equivalence test against the previous gather). CatFrames gains two offline-only options: - future=True: forward-looking windows [t, ..., t+N-1], implemented by time-reversing the input around the existing unfold/pad core so that padding="same" becomes repeat-last end padding. Raises on the env step path (the online buffer cannot see future frames); the done entry is optional in this mode (absent done = one contiguous trajectory per row). - mask_key=...: exposes the per-window validity mask that unfold_done already computed and discarded (True = fabricated/padded slot, the action_is_pad convention). Available for history windows too. Also: - _apply_same_padding vectorized (a gather instead of a Python loop over batch*time samples), making the offline path compile-friendly. - unfold_done builds its leading no-reset block explicitly instead of slicing the done entry, fixing offline windows longer than the trajectory (N > T). - Chunks become boundary-aware: when the sampled window carries ("next", done), steps past a trajectory boundary inside the window are padded and masked instead of leaking actions across episodes (done_key=None opts out; absent done keeps the previous behavior). - CatFramesConfig brought to kwarg parity (padding, padding_value, as_inverse, reset_key, done_key, future, mask_key). - Benchmark for the offline windowing path (ActionChunkTransform forward and CatFrames unfolding). Co-Authored-By: Claude Fable 5 --- benchmarks/test_replaybuffer_benchmark.py | 29 +++ docs/source/reference/vla.rst | 3 + test/transforms/test_action_transforms.py | 87 ++++++++- .../transforms/test_observation_transforms.py | 147 ++++++++++++++ torchrl/envs/transforms/_action.py | 184 +++++++++++++++--- torchrl/envs/transforms/_observation.py | 108 +++++++++- torchrl/envs/transforms/functional.py | 36 ++-- .../trainers/algorithms/configs/transforms.py | 7 + tutorials/sphinx-tutorials/vla.py | 6 +- 9 files changed, 542 insertions(+), 65 deletions(-) diff --git a/benchmarks/test_replaybuffer_benchmark.py b/benchmarks/test_replaybuffer_benchmark.py index c89eb6df23e..cbe6c9a5984 100644 --- a/benchmarks/test_replaybuffer_benchmark.py +++ b/benchmarks/test_replaybuffer_benchmark.py @@ -24,6 +24,7 @@ SamplerWithoutReplacement, SliceSampler, ) +from torchrl.envs.transforms import ActionChunkTransform, CatFrames _TensorDictPrioritizedReplayBuffer = functools.partial( TensorDictPrioritizedReplayBuffer, alpha=1, beta=0.9 @@ -449,6 +450,34 @@ def test_rb_extend_sample( ) +class TestWindowingTransformsBenchmark: + """Offline (sample-path) sliding-window transforms: CatFrames.unfolding + and the ActionChunkTransform recipe built on top of it.""" + + @pytest.mark.parametrize("done_key", ["done", None], ids=["done_aware", "no_done"]) + def test_action_chunk_transform(self, benchmark, done_key): + t = ActionChunkTransform(chunk_size=8, done_key=done_key) + td = TensorDict( + { + "action": torch.randn(64, 32, 7), + ("next", "done"): torch.zeros(64, 32, 1, dtype=torch.bool), + }, + batch_size=[64], + ) + benchmark(t, td) + + def test_catframes_offline(self, benchmark): + t = CatFrames(N=4, dim=-3, in_keys=["pixels"], out_keys=["pixels_cat"]) + td = TensorDict( + { + "pixels": torch.randn(8, 32, 3, 32, 32), + ("next", "done"): torch.zeros(8, 32, 1, dtype=torch.bool), + }, + batch_size=[8, 32], + ).refine_names(None, "time") + benchmark(t, td) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/docs/source/reference/vla.rst b/docs/source/reference/vla.rst index 1a5161019c2..f0dad0fcefa 100644 --- a/docs/source/reference/vla.rst +++ b/docs/source/reference/vla.rst @@ -79,6 +79,9 @@ pipeline) and they live alongside the other transforms, documented in full on th - :class:`~torchrl.envs.transforms.ActionChunkTransform` -- build fixed-length action chunks (``[*B, T, H, action_dim]``) and a padding mask from a sampled trajectory window, the standard training target for chunked VLA policies. + Internally a thin recipe over a forward-looking + :class:`~torchrl.envs.transforms.CatFrames`; when the window carries its + done state, chunks stop (pad and mask) at the trajectory boundaries. - :class:`~torchrl.envs.transforms.ActionScaling` -- affine action normalization; built with the :meth:`~torchrl.envs.transforms.ActionScaling.from_metadata` / diff --git a/test/transforms/test_action_transforms.py b/test/transforms/test_action_transforms.py index 06fb92153e5..5761121a82e 100644 --- a/test/transforms/test_action_transforms.py +++ b/test/transforms/test_action_transforms.py @@ -2515,14 +2515,87 @@ def test_trailing_dim_enforced(self): with pytest.raises(ValueError, match="immediately follow"): t(TensorDict({"action": torch.randn(2, 4, 3)}, batch_size=[2, 4])) - def test_compile_build_chunk(self): + @staticmethod + def _reference_gather(action, H): + # the pre-0.14 dedicated gather (arange + clamp + index_select), kept + # as the ground truth the CatFrames recipe must reproduce exactly + T = action.shape[-2] + idx = torch.arange(T).unsqueeze(-1) + torch.arange(H).unsqueeze(0) + is_pad = idx >= T + idx = idx.clamp_max(T - 1).reshape(-1) + chunk = action.index_select(-2, idx).unflatten(-2, (T, H)) + is_pad = is_pad.expand(chunk.shape[:-1]).contiguous() + return chunk, is_pad + + @pytest.mark.parametrize("batch_size", [(), (2,), (3, 2)]) + @pytest.mark.parametrize("chunk_size", [1, 3, 7]) + def test_equivalence_with_reference_gather(self, batch_size, chunk_size): + torch.manual_seed(0) + action = torch.randn(*batch_size, 5, 4) + t = ActionChunkTransform(chunk_size=chunk_size) + out = t(TensorDict({"action": action}, batch_size=batch_size)) + ref_chunk, ref_pad = self._reference_gather(action, chunk_size) + # byte-identical: the recipe copies the same source elements + assert torch.equal(out["action_chunk"], ref_chunk) + assert torch.equal(out["action_is_pad"], ref_pad) + + def test_done_aware_chunking(self): + # a done inside the window: chunks must not cross the trajectory + # boundary (steps past it are padded with the last valid action) + action = torch.arange(4.0).view(1, 4, 1) + done = torch.zeros(1, 4, 1, dtype=torch.bool) + done[0, 1] = True # boundary between steps 1 and 2 + td = TensorDict({"action": action, ("next", "done"): done}, batch_size=[1, 4]) + out = ActionChunkTransform(3)(td) + expected_chunk = torch.tensor( + [[0, 1, 1], [1, 1, 1], [2, 3, 3], [3, 3, 3]] + ).float() + torch.testing.assert_close(out["action_chunk"][0, :, :, 0], expected_chunk) + expected_pad = torch.tensor([[0, 0, 1], [0, 1, 1], [0, 0, 1], [0, 1, 1]]).bool() + assert torch.equal(out["action_is_pad"][0], expected_pad) + + def test_done_key_none_ignores_dones(self): + action = torch.arange(4.0).view(1, 4, 1) + done = torch.zeros(1, 4, 1, dtype=torch.bool) + done[0, 1] = True + td = TensorDict({"action": action, ("next", "done"): done}, batch_size=[1, 4]) + out = ActionChunkTransform(3, done_key=None)(td) + # the boundary is ignored: the chunk at t=0 reads across the done + expected_chunk = torch.tensor( + [[0, 1, 2], [1, 2, 3], [2, 3, 3], [3, 3, 3]] + ).float() + torch.testing.assert_close(out["action_chunk"][0, :, :, 0], expected_chunk) + + def test_done_shape_mismatch_raises(self): + td = TensorDict( + { + "action": torch.randn(2, 4, 3), + ("next", "done"): torch.zeros(2, 5, 1, dtype=torch.bool), + }, + batch_size=[2], + ) + with pytest.raises(ValueError, match="does not line up"): + ActionChunkTransform(2)(td) + + def test_clone_keeps_recipe(self): + t = ActionChunkTransform( + chunk_size=3, action_key=("data", "action"), done_key=None + ).clone() + assert isinstance(t, ActionChunkTransform) + assert t.chunk_size == 3 + assert t.action_key == ("data", "action") + assert t.done_key is None + td = TensorDict({"data": {"action": torch.randn(2, 4, 3)}}, batch_size=[2, 4]) + assert t(td)["action_chunk"].shape == torch.Size([2, 4, 3, 3]) + + def test_compile(self): t = ActionChunkTransform(chunk_size=3) - action = torch.randn(2, 5, 2) - eager_chunk, eager_pad = t._build_chunk(action) - compiled = torch.compile(t._build_chunk, fullgraph=True) - c_chunk, c_pad = compiled(action) - torch.testing.assert_close(c_chunk, eager_chunk) - assert torch.equal(c_pad, eager_pad) + td = TensorDict({"action": torch.randn(2, 5, 2)}, batch_size=[2, 5]) + eager = t(td.clone()) + compiled = torch.compile(t) + out = compiled(td.clone()) + torch.testing.assert_close(out["action_chunk"], eager["action_chunk"]) + assert torch.equal(out["action_is_pad"], eager["action_is_pad"]) class TestActionTokenizerTransform(TransformBase): diff --git a/test/transforms/test_observation_transforms.py b/test/transforms/test_observation_transforms.py index 6738313356f..59511740e83 100644 --- a/test/transforms/test_observation_transforms.py +++ b/test/transforms/test_observation_transforms.py @@ -667,6 +667,153 @@ def test_constant_padding(self, padding_value): cat_td = cat_frames._call(cat_td) assert (cat_td.get("cat_first_key") == padding_value).sum() == N - 4 + def test_unfolding_n_larger_than_t(self): + # history windows longer than the sampled trajectory: the leading + # no-reset block used to be built by slicing the done entry (capped + # at the time length) and crashed for N > T + t = CatFrames(N=6, dim=-2, in_keys=["obs"], out_keys=["obs_cat"]) + obs = torch.arange(3.0).view(1, 3, 1, 1) + done = torch.zeros(1, 3, 1, dtype=torch.bool) + td = TensorDict( + {"obs": obs, ("next", "done"): done}, batch_size=[1, 3] + ).refine_names(None, "time") + out = t(td) + assert out["obs_cat"].shape == torch.Size([1, 3, 6, 1]) + # padding="same": the first frame fills the missing history + torch.testing.assert_close(out["obs_cat"][0, 0, :, 0], torch.zeros(6)) + torch.testing.assert_close( + out["obs_cat"][0, 2, :, 0], + torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0, 2.0]), + ) + + def test_unfolding_missing_done_raises(self): + t = CatFrames(N=3, dim=-2, in_keys=["obs"], out_keys=["obs_cat"]) + td = TensorDict( + {"obs": torch.randn(1, 4, 1, 1)}, batch_size=[1, 4] + ).refine_names(None, "time") + with pytest.raises(KeyError, match="delimit"): + t(td) + + def test_future_windows(self): + t = CatFrames( + N=3, + dim=-2, + in_keys=["obs"], + out_keys=["obs_cat"], + future=True, + mask_key="mask", + ) + obs = torch.arange(4.0).view(1, 4, 1, 1) + # the done entry is optional with forward windows: each row is one + # contiguous trajectory + td = TensorDict({"obs": obs}, batch_size=[1, 4]).refine_names(None, "time") + out = t(td.clone()) + expected = torch.tensor( + [[0.0, 1.0, 2.0], [1.0, 2.0, 3.0], [2.0, 3.0, 3.0], [3.0, 3.0, 3.0]] + ) + torch.testing.assert_close(out["obs_cat"][0, :, :, 0], expected) + expected_mask = torch.tensor( + [[0, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 1]] + ).bool() + assert torch.equal(out["mask"][0], expected_mask) + + def test_future_windows_with_done(self): + t = CatFrames( + N=3, + dim=-2, + in_keys=["obs"], + out_keys=["obs_cat"], + future=True, + mask_key="mask", + ) + obs = torch.arange(4.0).view(1, 4, 1, 1) + done = torch.zeros(1, 4, 1, dtype=torch.bool) + done[0, 1] = True # boundary between steps 1 and 2 + td = TensorDict( + {"obs": obs, ("next", "done"): done}, batch_size=[1, 4] + ).refine_names(None, "time") + out = t(td.clone()) + expected = torch.tensor( + [[0.0, 1.0, 1.0], [1.0, 1.0, 1.0], [2.0, 3.0, 3.0], [3.0, 3.0, 3.0]] + ) + torch.testing.assert_close(out["obs_cat"][0, :, :, 0], expected) + expected_mask = torch.tensor( + [[0, 0, 1], [0, 1, 1], [0, 0, 1], [0, 1, 1]] + ).bool() + assert torch.equal(out["mask"][0], expected_mask) + + def test_future_windows_constant_padding(self): + t = CatFrames( + N=3, + dim=-2, + in_keys=["obs"], + out_keys=["obs_cat"], + future=True, + padding="constant", + padding_value=-1.0, + ) + obs = torch.arange(4.0).view(1, 4, 1, 1) + td = TensorDict({"obs": obs}, batch_size=[1, 4]).refine_names(None, "time") + out = t(td.clone()) + expected = torch.tensor( + [[0.0, 1.0, 2.0], [1.0, 2.0, 3.0], [2.0, 3.0, -1.0], [3.0, -1.0, -1.0]] + ) + torch.testing.assert_close(out["obs_cat"][0, :, :, 0], expected) + + def test_mask_key_history_windows(self): + # the mask is also available for plain (backward) history windows + t = CatFrames( + N=3, dim=-2, in_keys=["obs"], out_keys=["obs_cat"], mask_key="mask" + ) + obs = torch.arange(4.0).view(1, 4, 1, 1) + done = torch.zeros(1, 4, 1, dtype=torch.bool) + done[0, 1] = True + td = TensorDict( + {"obs": obs, ("next", "done"): done}, batch_size=[1, 4] + ).refine_names(None, "time") + out = t(td.clone()) + # window slots read oldest-to-newest; True = fabricated by padding + expected_mask = torch.tensor( + [[1, 1, 0], [1, 0, 0], [1, 1, 0], [1, 0, 0]] + ).bool() + assert torch.equal(out["mask"][0], expected_mask) + + def test_future_env_raises(self): + env = TransformedEnv( + ContinuousActionVecMockEnv(), + CatFrames(N=2, dim=-1, in_keys=["observation"], future=True), + ) + with pytest.raises(RuntimeError, match="offline"): + env.reset() + + def test_mask_key_env_raises(self): + env = TransformedEnv( + ContinuousActionVecMockEnv(), + CatFrames(N=2, dim=-1, in_keys=["observation"], mask_key="mask"), + ) + with pytest.raises(RuntimeError, match="offline"): + env.reset() + + def test_future_next_key_overlap_raises(self): + t = CatFrames( + N=2, + dim=-1, + in_keys=["obs", ("next", "obs")], + out_keys=["obs_cat", ("next", "obs_cat")], + future=True, + ) + obs = torch.randn(1, 4, 2) + td = TensorDict( + { + "obs": obs, + ("next", "obs"): obs, + ("next", "done"): torch.zeros(1, 4, 1, dtype=torch.bool), + }, + batch_size=[1, 4], + ).refine_names(None, "time") + with pytest.raises(NotImplementedError, match="history"): + t(td) + @staticmethod def _unfold_done(done, N, ndim): # Mirror of ``CatFrames.unfolding.unfold_done`` (window-padding mask) diff --git a/torchrl/envs/transforms/_action.py b/torchrl/envs/transforms/_action.py index ad248c65094..066d8c93be7 100644 --- a/torchrl/envs/transforms/_action.py +++ b/torchrl/envs/transforms/_action.py @@ -15,7 +15,7 @@ import torch -from tensordict import TensorDictBase +from tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey, unravel_key from torch import nn @@ -46,7 +46,8 @@ ACTION_TOKENS_KEY, ) from torchrl.data.vla.tokenizers import ActionTokenizerBase -from torchrl.envs.transforms._base import FORWARD_NOT_IMPLEMENTED, Transform +from torchrl.envs.transforms._base import Compose, FORWARD_NOT_IMPLEMENTED, Transform +from torchrl.envs.transforms._observation import CatFrames, UnsqueezeTransform __all__ = [ "ActionChunkTransform", @@ -1697,7 +1698,7 @@ def __repr__(self) -> str: ) -class ActionChunkTransform(Transform): +class ActionChunkTransform(Compose): """Build fixed-length action chunks from a trajectory window. Action *chunking* is the defining trait of modern VLA policies (ACT, @@ -1710,6 +1711,19 @@ class ActionChunkTransform(Transform): ``action_is_pad`` mask ``[*B, T, H]`` marking the steps that ran past the end of the window (and were filled by repeating the last available action). + Internally this is a recipe over the generic transforms (the same pattern + as :class:`~torchrl.envs.transforms.R3MTransform`): an + :class:`~torchrl.envs.transforms.UnsqueezeTransform` opens the chunk dim + and a forward-looking :class:`~torchrl.envs.transforms.CatFrames` + (``future=True, padding="same", mask_key=...``) does the windowing, so + chunking shares one sliding-window implementation with frame stacking. + + .. versionchanged:: 0.14 + ``ActionChunkTransform`` is now a :class:`~torchrl.envs.transforms.Compose` + recipe over :class:`~torchrl.envs.transforms.CatFrames`. The output is + unchanged, and additionally chunks become *boundary-aware* when the + sampled data carries its done state (see ``done_key``). + .. note:: **How to read "many actions in one tensor".** The ``H`` actions of a chunk are *predictions* -- overlapping, stride-1 training targets (each dataset step ``t`` gets its own window ``a[t..t+H-1]``, so a @@ -1735,7 +1749,11 @@ class ActionChunkTransform(Transform): ``time_dim`` must be a single contiguous trajectory window. Chunks are built independently per row and never cross a row boundary; the downstream chunked behavior-cloning loss masks the padded steps out using - ``action_is_pad``. + ``action_is_pad``. When the input additionally carries its done state at + ``("next", done_key)``, chunks are also cut at the trajectory boundaries + *inside* a row: the steps past a done are padded (repeating the last + in-trajectory action) and flagged in ``action_is_pad``, exactly like the + end of the window. .. note:: A :class:`~torchrl.data.SliceSampler` returns a *flat* ``[B * T, ...]`` @@ -1760,6 +1778,16 @@ class ActionChunkTransform(Transform): Defaults to ``"action_is_pad"``. time_dim (int): the time dimension of the action tensor (the action dimension must come right after it). Defaults to ``-2``. + done_key (NestedKey or None): the leaf done key: when the input + tensordict has a ``("next", done_key)`` entry (shaped like the + action without its trailing ``action_dim``, with or without a + trailing singleton), chunks do not cross the trajectory boundaries + it marks. When the entry is absent, each row is treated as a + single contiguous trajectory (the pre-0.14 behavior). Pass + ``None`` to ignore the done state altogether. + Defaults to ``"done"``. + + .. versionadded:: 0.14 Examples: >>> import torch @@ -1782,6 +1810,22 @@ class ActionChunkTransform(Transform): [False, False, False], [False, False, True], [False, True, True]]) + >>> # when the window carries its done state, chunks are also cut at + >>> # the trajectory boundary inside the window (here after step 1) + >>> td = TensorDict( + ... { + ... "action": torch.arange(4).view(1, 4, 1).float(), + ... ("next", "done"): torch.tensor( + ... [False, True, False, False] + ... ).view(1, 4, 1), + ... }, + ... batch_size=[1, 4], + ... ) + >>> t(td)["action_chunk"][0, :, :, 0] + tensor([[0., 1., 1.], + [1., 1., 1.], + [2., 3., 3.], + [3., 3., 3.]]) >>> # on a replay buffer: extend with raw [T, action_dim] trajectory >>> # windows (stored as-is), the chunks are built on the sample path >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer @@ -1806,19 +1850,33 @@ def __init__( chunk_key: NestedKey = ACTION_CHUNK_KEY, pad_key: NestedKey = ACTION_IS_PAD_KEY, time_dim: int = -2, + done_key: NestedKey | None = "done", ) -> None: if chunk_size < 1: raise ValueError(f"chunk_size must be >= 1, got {chunk_size}.") + # The recipe: open a singleton chunk dim on a copy of the action, then + # concatenate the N upcoming actions along it. ``padding="same"`` + # repeats the last in-trajectory action past the boundaries and + # ``mask_key`` flags those fabricated slots. + super().__init__( + UnsqueezeTransform(dim=-2, in_keys=[action_key], out_keys=[chunk_key]), + CatFrames( + N=int(chunk_size), + dim=-2, + in_keys=[chunk_key], + out_keys=[chunk_key], + padding="same", + future=True, + mask_key=pad_key, + done_key=done_key, + ), + ) self.chunk_size = int(chunk_size) self.time_dim = int(time_dim) - # ``forward`` is fully overridden (it writes the chunk and the pad mask - # from the action), so no transform keys are declared: this is a pure - # data-path transform and the chunk/pad entries never appear in env - # specs. - super().__init__(in_keys=[], out_keys=[]) self._action_key = action_key self._chunk_key = chunk_key self._pad_key = pad_key + self._done_key = done_key @property def action_key(self) -> NestedKey: @@ -1832,27 +1890,30 @@ def chunk_key(self) -> NestedKey: def pad_key(self) -> NestedKey: return self._pad_key - def _build_chunk(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - H = self.chunk_size - dim = self.time_dim if self.time_dim >= 0 else action.dim() + self.time_dim - if dim != action.dim() - 2: + @property + def done_key(self) -> NestedKey | None: + return self._done_key + + def _maybe_get_done( + self, tensordict: TensorDictBase, action: torch.Tensor, dim: int + ) -> torch.Tensor | None: + if self._done_key is None: + return None + done = tensordict.get(("next", self._done_key), default=None) + if done is None: + return None + lead = action.shape[: dim + 1] + if done.shape == lead: + done = done.unsqueeze(-1) + elif done.shape != torch.Size((*lead, 1)): raise ValueError( - f"{type(self).__name__} expects the action dimension to immediately " - f"follow the time dimension (action shaped [..., T, action_dim]); got " - f"action.shape={tuple(action.shape)} with time_dim={self.time_dim}." + f"{type(self).__name__}: the ('next', {self._done_key!r}) entry " + f"of shape {tuple(done.shape)} does not line up with the action " + f"of shape {tuple(action.shape)}: expected {(*lead, 1)} or " + f"{tuple(lead)}. Pass done_key=None to chunk without " + "trajectory-boundary information." ) - T = action.shape[dim] - device = action.device - # idx[t, h] = t + h, clamped to the last valid step; is_pad marks h that - # ran past the end of the window. - idx = torch.arange(T, device=device).unsqueeze(-1) + torch.arange( - H, device=device - ).unsqueeze(0) - is_pad = idx >= T - idx = idx.clamp_max(T - 1).reshape(-1) - chunk = action.index_select(dim, idx).unflatten(dim, (T, H)) - is_pad = is_pad.expand(chunk.shape[:-1]).contiguous() - return chunk, is_pad + return done def forward(self, tensordict: TensorDictBase) -> TensorDictBase: action = tensordict.get(self.action_key, default=None) @@ -1863,11 +1924,72 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"{type(self).__name__}: '{self.action_key}' not found in tensordict " f"{tensordict}." ) - chunk, is_pad = self._build_chunk(action) - tensordict.set(self.chunk_key, chunk) - tensordict.set(self.pad_key, is_pad) + dim = self.time_dim if self.time_dim >= 0 else action.dim() + self.time_dim + if dim != action.dim() - 2 or dim < 0: + raise ValueError( + f"{type(self).__name__} expects the action dimension to immediately " + f"follow the time dimension (action shaped [..., T, action_dim]); got " + f"action.shape={tuple(action.shape)} with time_dim={self.time_dim}." + ) + # CatFrames' offline path keys the windowing on the *tensordict* batch + # dims (time last), while the chunk convention is keyed on the action's + # shape ([*B, T, action_dim]) -- the sampled tensordict may well be + # flat. Bridge the two by windowing a minimal time-structured view of + # the action (and of the done state, when available). + inner = TensorDict(batch_size=action.shape[: dim + 1]) + inner.set(self.action_key, action) + done = self._maybe_get_done(tensordict, action, dim) + if done is not None: + inner.set(("next", self._done_key), done) + inner = inner.refine_names(*[None] * dim, "time") + inner = super().forward(inner) + tensordict.set(self.chunk_key, inner.get(self.chunk_key)) + tensordict.set(self.pad_key, inner.get(self.pad_key)) + return tensordict + + def clone(self) -> Self: + # Compose.clone returns a plain Compose, which would drop the + # env-path overrides below; rebuild the recipe instead. + return type(self)( + self.chunk_size, + action_key=self._action_key, + chunk_key=self._chunk_key, + pad_key=self._pad_key, + time_dim=self.time_dim, + done_key=self._done_key, + ) + + # Attached to an environment, the transform is a documented no-op: chunk + # execution belongs to MultiStepActorWrapper / MultiAction, and the inner + # CatFrames is offline-only (future=True). The Compose machinery is + # bypassed on every env-facing path. + def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: + return next_tensordict + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + return next_tensordict + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + return tensordict_reset + + def _reset_on_native_autoreset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + return tensordict_reset + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict + def transform_input_spec(self, input_spec: Composite) -> Composite: + return input_spec + + def transform_output_spec(self, output_spec: Composite) -> Composite: + return output_spec + class ActionTokenizerTransform(Transform): """Encode and decode actions with an :class:`~torchrl.data.vla.ActionTokenizerBase`. diff --git a/torchrl/envs/transforms/_observation.py b/torchrl/envs/transforms/_observation.py index 53db1e1f352..404111470e7 100644 --- a/torchrl/envs/transforms/_observation.py +++ b/torchrl/envs/transforms/_observation.py @@ -896,6 +896,27 @@ class CatFrames(ObservationTransform): and raises an exception otherwise. done_key (NestedKey, optional): the done key to be used as partial done indicator. Must be unique. If not provided, defaults to ``"done"``. + future (bool, optional): if ``True``, each step's window gathers the + ``N`` *upcoming* frames ``[t, t + 1, ..., t + N - 1]`` instead of + the ``N`` most recent ones ``[t - N + 1, ..., t]``. With + ``padding="same"`` the slots that run past the end of the + trajectory repeat the last in-trajectory frame. Forward-looking + windows require the full trajectory, so this mode is only + available offline (replay buffer / data pipelines): attaching the + transform to an environment raises a ``RuntimeError`` on the step + path. Defaults to ``False``. + + .. versionadded:: 0.14 + mask_key (NestedKey, optional): if provided, the offline (forward / + unfolding) path also writes a boolean mask of shape + ``[*batch, time, N]`` flagging, for each window, the slots that + were fabricated by padding (``True`` = padded slot, either out of + the trajectory or out of the sampled window). This is the + convention of the ``action_is_pad`` entry of chunked-action + datasets. The mask is not available on the online (env step) + path. Defaults to ``None`` (no mask is written). + + .. versionadded:: 0.14 Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -978,6 +999,10 @@ class CatFrames(ObservationTransform): "different batch-sizes (since negative dims are batch invariant)." ) ACCEPTED_PADDING = {"same", "constant", "zeros"} + # class-level defaults double as fallbacks for instances pickled before + # these options existed + future = False + mask_key = None def __init__( self, @@ -990,6 +1015,8 @@ def __init__( as_inverse=False, reset_key: NestedKey | None = None, done_key: NestedKey | None = None, + future: bool = False, + mask_key: NestedKey | None = None, ): if in_keys is None: in_keys = IMAGE_KEYS @@ -997,6 +1024,8 @@ def __init__( out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) self.N = N + self.future = bool(future) + self.mask_key = mask_key if dim >= 0: raise ValueError(self._CAT_DIM_ERR) self.dim = dim @@ -1077,6 +1106,8 @@ def make_rb_transform_and_sampler( as_inverse=False, reset_key=self.reset_key, done_key=self.done_key, + future=self.future, + mask_key=self.mask_key, ) sampler = SliceSampler(slice_len=self.N, **sampler_kwargs) sampler._batch_size_multiplier = self.N @@ -1175,6 +1206,19 @@ def _inv_call(self, tensordict: TensorDictBase) -> torch.Tensor: def _call(self, next_tensordict: TensorDictBase, _reset=None) -> TensorDictBase: """Update the episode tensordict with max pooled keys.""" + if self.future: + raise RuntimeError( + "CatFrames(future=True) cannot run on the environment step " + "path: forward-looking windows require the full trajectory " + "and are only available offline (replay buffer / data " + "pipelines)." + ) + if self.mask_key is not None: + raise RuntimeError( + "CatFrames(mask_key=...) is only available offline (forward " + "/ unfolding): the online step path does not build " + "per-window validity masks." + ) _just_reset = _reset is not None for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): # Lazy init of buffers @@ -1241,6 +1285,13 @@ def _call(self, next_tensordict: TensorDictBase, _reset=None) -> TensorDictBase: @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if self.future: + raise RuntimeError( + "CatFrames(future=True) cannot be attached to an " + "environment: forward-looking windows require the full " + "trajectory and are only available offline (replay buffer / " + "data pipelines)." + ) space = observation_spec.space if isinstance(space, ContinuousBox): space.low = torch.cat([space.low] * self.N, self.dim) @@ -1301,9 +1352,14 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: def unfold_done(done, N): prefix = (slice(None),) * (tensordict.ndim - 1) + # the leading no-reset block is built explicitly rather than by + # slicing ``done`` (which would cap it at the time length and + # break windows longer than the trajectory, N > T) + zeros_shape = list(done.shape) + zeros_shape[tensordict.ndim - 1] = self.N - 1 reset = torch.cat( [ - torch.zeros_like(done[prefix + (slice(self.N - 1),)]), + torch.zeros(zeros_shape, dtype=done.dtype, device=done.device), torch.ones_like(done[prefix + (slice(1),)]), done[prefix + (slice(None, -1),)], ], @@ -1320,9 +1376,43 @@ def unfold_done(done, N): reset[prefix + (0,)] = 1 return reset_unfold, reset - done = tensordict.get(("next", self.done_key)) + # The time axis is the last batch dim of (the possibly transposed) + # ``tensordict``; the same index addresses it in every entry since the + # batch dims lead the tensors. + tdim = tensordict.ndim - 1 + done = tensordict.get(("next", self.done_key), default=None) + if done is None: + if not self.future: + raise KeyError( + f"CatFrames.unfolding requires the {('next', self.done_key)} " + "entry to delimit trajectories. Make sure the sampled data " + "carries its done state, or use forward-looking windows " + "(future=True) to treat each batch row as a single " + "contiguous trajectory." + ) + # Absent done in future mode: each batch row is one contiguous + # trajectory and only the windows that run past its end are padded. + done = torch.zeros( + (*tensordict.shape, 1), + dtype=torch.bool, + device=tensordict.get(keys[0][0]).device, + ) + if self.future: + # Forward windows are backward windows of the time-reversed data: + # the chunk ``[t, ..., t + N - 1]`` is the reversed window at + # ``T - 1 - t`` read backwards. A boundary between steps ``t`` and + # ``t + 1`` (``done[t]``) sits between reversed steps ``T - 2 - t`` + # and ``T - 1 - t``, hence the flip + shift; the rolled-in last + # entry is never read (``unfold_done`` drops the final done). + done = done.flip(tdim).roll(-1, dims=tdim) done_mask, reset = unfold_done(done, self.N) + if self.mask_key is not None: + mask = done_mask + if self.future: + mask = mask.flip(tdim).flip(-1) + tensordict.set(self.mask_key, mask.reshape(*tensordict.shape, self.N)) + for in_key, out_key in keys: # check if we have an obs in "next" that has already been processed. # If so, we must add an offset @@ -1340,6 +1430,13 @@ def unfold_done(done, N): first_val = prev_val.unflatten( data.ndim + self.dim, (self.N, n_feat) ) + if first_val is not None and self.future: + raise NotImplementedError( + "CatFrames(future=True) does not support processing a " + "('next', key) entry alongside its root counterpart: the " + "one-step-offset fixup is only implemented for " + "history (backward) windows." + ) # The time axis sits at ``tensordict.ndim - 1`` within ``data`` (the # tensordict batch dims lead the tensor). Expressed relative to @@ -1348,6 +1445,8 @@ def unfold_done(done, N): # ``cat_frames`` functional so that the offline transform stays # byte-for-byte identical to its stateless core. time_dim = (tensordict.ndim - 1) - data.ndim + if self.future: + data = data.flip(tdim) data = F._cat_frames_windows( data, self.N, @@ -1357,6 +1456,11 @@ def unfold_done(done, N): time_dim=time_dim, done_mask=done_mask, ) + if self.future: + # Back to forward time, windows read oldest-to-newest: undo + # the time reversal and flip the window axis (which + # ``_cat_frames_windows`` placed just before the cat axis). + data = data.flip(tdim).flip(data.ndim + self.dim - 1) if first_val is not None: data0_pad = torch.full_like( diff --git a/torchrl/envs/transforms/functional.py b/torchrl/envs/transforms/functional.py index 295841486cc..4abd10f99f9 100644 --- a/torchrl/envs/transforms/functional.py +++ b/torchrl/envs/transforms/functional.py @@ -65,31 +65,21 @@ def _apply_same_padding(dim: int, data: Tensor, done_mask: Tensor) -> Tensor: (marked by ``done_mask``) are overwritten with the earliest in-trajectory frame of that window. ``data`` is the permuted, windowed tensor (the window axis already moved to ``data.ndim + dim - 1``); ``done_mask`` is the - un-permuted boolean mask of shape ``(*batch, time, N)``. + un-permuted boolean mask of shape ``(*batch, time, 1, N)`` (the singleton is + the done "feature" dim). """ d = data.ndim + dim - 1 - res = data.clone() - num_repeats_per_sample = done_mask.sum(dim=-1) - - if num_repeats_per_sample.dim() > 2: - extra_dims = num_repeats_per_sample.dim() - 2 - num_repeats_per_sample = num_repeats_per_sample.flatten(0, extra_dims) - res_flat_series = res.flatten(0, extra_dims) - else: - extra_dims = 0 - res_flat_series = res - - if d - 1 > extra_dims: - res_flat_series_flat_batch = res_flat_series.flatten(1, d - 1) - else: - res_flat_series_flat_batch = res_flat_series[:, None] - - for sample_idx, num_repeats in enumerate(num_repeats_per_sample): - if num_repeats > 0: - res_slice = res_flat_series_flat_batch[sample_idx] - res_slice[:, :num_repeats] = res_slice[:, num_repeats : num_repeats + 1] - - return res + N = done_mask.shape[-1] + # The out-of-trajectory slots always form a prefix of the window (the + # oldest frames), so slot ``j`` reads slot ``max(j, num_padded)``: its own + # value once past the prefix, the first in-trajectory frame otherwise. + num_padded = done_mask.sum(dim=-1) + index = torch.maximum(torch.arange(N, device=data.device), num_padded).clamp_max_( + N - 1 + ) + data = data.movedim(d, -1) + index = index.reshape(*index.shape[:-1], *(1,) * (data.ndim - index.ndim), N) + return data.gather(-1, index.expand(data.shape)).movedim(-1, d) def _cat_frames_windows( diff --git a/torchrl/trainers/algorithms/configs/transforms.py b/torchrl/trainers/algorithms/configs/transforms.py index 80e60f6cc2a..3d6ff4fb866 100644 --- a/torchrl/trainers/algorithms/configs/transforms.py +++ b/torchrl/trainers/algorithms/configs/transforms.py @@ -198,6 +198,13 @@ class CatFramesConfig(TransformConfig): dim: int = -3 in_keys: list[str] | None = None out_keys: list[str] | None = None + padding: str = "same" + padding_value: float = 0.0 + as_inverse: bool = False + reset_key: str | None = None + done_key: str | None = None + future: bool = False + mask_key: str | None = None _target_: str = "torchrl.envs.transforms.transforms.CatFrames" def __post_init__(self) -> None: diff --git a/tutorials/sphinx-tutorials/vla.py b/tutorials/sphinx-tutorials/vla.py index 867fd5e9bc4..0d9600d27c8 100644 --- a/tutorials/sphinx-tutorials/vla.py +++ b/tutorials/sphinx-tutorials/vla.py @@ -168,8 +168,10 @@ def make_observation(batch=batch): # :class:`~torchrl.envs.transforms.ActionChunkTransform` turns a per-step action # tensor ``[*B, T, action_dim]`` into the chunked training target # ``action_chunk`` ``[*B, T, H, action_dim]`` (plus an ``action_is_pad`` mask): -# for every step ``t`` it gathers the next ``H`` actions. This is the training -# target of modern chunked VLA policies (ACT, OpenVLA-OFT, pi0). +# for every step ``t`` it gathers the next ``H`` actions, stopping (padding and +# masking) at the trajectory boundaries when the sampled window carries its +# done state. This is the training target of modern chunked VLA policies (ACT, +# OpenVLA-OFT, pi0). # # Chunks mean different things on the two sides of the pipeline, and keeping # the two pictures apart avoids a classic confusion::