From 3a7ec0c2a15d479b551650adcc8fc5798118d09d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 15 May 2026 15:09:37 +0100 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- test/objectives/test_values.py | 64 ++++++++++++ test/test_collectors.py | 5 +- torchrl/objectives/value/advantages.py | 138 ++++++++++++++++++++++++- 3 files changed, 205 insertions(+), 2 deletions(-) diff --git a/test/objectives/test_values.py b/test/objectives/test_values.py index a1fa01e651b..f2c65cac162 100644 --- a/test/objectives/test_values.py +++ b/test/objectives/test_values.py @@ -466,6 +466,70 @@ def test_gae_recurrent(self, module, vectorized): a1 = r1["advantage"] torch.testing.assert_close(a0, a1) + @pytest.mark.parametrize("has_internal_truncation", [False, True]) + def test_gae_shifted_rollout_shape_call(self, has_internal_truncation): + # Shifted=True now dispatches to a rollout-shape single call when + # strictly equivalent to the flatten/interleave path (no internal + # truncations). When internal truncations exist it falls back to + # the flatten path. Either way the result must match shifted=False. + from torchrl.objectives.value.advantages import ValueEstimatorBase + + torch.manual_seed(0) + B, T, obs_dim = 4, 8, 6 + # Mimic an env rollout: obs[t+1] == next_obs[t] for non-done steps. + # At done positions, next_obs is whatever the env returned at the + # terminal/truncation step and obs[t+1] is the post-reset obs of + # a new trajectory (independent). + all_obs = torch.randn(B, T + 1, obs_dim) + obs = all_obs[:, :T].clone() + next_obs = all_obs[:, 1:].clone() + reward = torch.randn(B, T, 1) + done = torch.zeros(B, T, 1, dtype=torch.bool) + terminated = torch.zeros(B, T, 1, dtype=torch.bool) + if has_internal_truncation: + # Put a truncation (done=True, terminated=False) at t=3 on row 0 + # and decouple next_obs from obs[t+1] at that position. + done[0, 3, 0] = True + next_obs[0, 3] = torch.randn(obs_dim) + # Always have done at the rollout boundary (T-1). next_obs[T-1] is + # the rollout's exit observation; since we never see obs[T] in the + # rollout, the value at the boundary slot is dictated by next_obs. + done[:, -1, 0] = True + td = TensorDict( + { + "observation": obs, + "next": TensorDict( + { + "observation": next_obs, + "reward": reward, + "done": done, + "terminated": terminated, + "truncated": done & ~terminated, + }, + batch_size=[B, T], + ), + }, + batch_size=[B, T], + ) + td.refine_names(..., "time") + # Confirm the equivalence detector decides correctly. + assert ValueEstimatorBase._can_use_rollout_shape_call(td, ndim=2) is ( + not has_internal_truncation + ) + + value_net = TensorDictModule( + nn.Linear(obs_dim, 1), + in_keys=["observation"], + out_keys=["state_value"], + ) + gae_shifted = GAE(gamma=0.9, lmbda=0.95, value_network=value_net, shifted=True) + gae_unshifted = GAE( + gamma=0.9, lmbda=0.95, value_network=value_net, shifted=False + ) + adv_shifted = gae_shifted(td.copy())["advantage"] + adv_unshifted = gae_unshifted(td.copy())["advantage"] + torch.testing.assert_close(adv_shifted, adv_unshifted) + @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) diff --git a/test/test_collectors.py b/test/test_collectors.py index 8ad334e70dc..8705eead149 100644 --- a/test/test_collectors.py +++ b/test/test_collectors.py @@ -1564,7 +1564,10 @@ def make_env(): assert fake_keys == real_keys, set(real_keys) ^ set(fake_keys) # Must be all-zero. for key, val in fake.items(True, True): - if val.dtype in (torch.bool, torch.uint8) or not val.is_floating_point(): + if ( + val.dtype in (torch.bool, torch.uint8) + or not val.is_floating_point() + ): continue assert not val.any(), f"{key} is not zeroed" finally: diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index f76ed4e3b89..721bc2d1a93 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -567,6 +567,124 @@ def _maybe_drop_final_obs(tensordict: TensorDictBase) -> None: if "final" in tensordict.keys(): del tensordict["final"] + @staticmethod + def _can_use_rollout_shape_call(data: TensorDictBase, ndim: int) -> bool: + """Return whether the rollout-shape single-call path is safe to use. + + The rollout-shape path is strictly equivalent to the flatten/interleave + path whenever no internal truncations exist within the rollout. + + The rollout-shape path concatenates the boundary ``next_obs`` to the + rollout-shaped inputs and slices into ``value`` / ``next_value``. + It is strictly equivalent to the flatten/interleave path whenever + every internal ``done`` step (``t < T-1``) is also terminated. At + truncation-only steps (``done=True, terminated=False`` with + ``t < T-1``) the flatten path bootstraps from ``V(real next_obs)``, + whereas the rollout-shape path would substitute ``V(obs[t+1])`` from + the start of the next trajectory. Internal terminations are always + masked downstream by ``(1-terminated)``, so they don't matter. + """ + if ndim < 1: + return True + time_idx = ndim - 1 + T = data.shape[time_idx] + if T < 2: + return True + done = data.get(("next", "done"), default=None) + if done is None: + return True + truncated = data.get(("next", "truncated"), default=None) + if truncated is None: + # No truncated key: GAE treats terminated==done; no internal + # truncations possible. + return True + terminated = data.get(("next", "terminated"), default=done) + internal_idx = (slice(None),) * time_idx + (slice(0, T - 1),) + trunc = done[internal_idx].bool() & (~terminated[internal_idx].bool()) + return not bool(trunc.any()) + + def _call_value_net_rollout_shape( + self, + data: TensorDictBase, + params: TensorDictBase | None, + next_params: TensorDictBase | None, + value_key: NestedKey, + ndim: int, + value_net: TensorDictModuleBase, + _call_value_net, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Single call on rollout-shaped data: ``[..., T] → [..., T+1]``. + + Builds the value-net input by appending the rollout-boundary + ``next_obs`` along the time dim (using ``("final", k)`` when + present, else ``("next", k)[..., T-1]``), runs ``value_net`` once, + then slices the output into ``value = out[..., :T]`` and + ``next_value = out[..., 1:T+1]``. + + Caller must verify equivalence with the flatten path via + :meth:`_can_use_rollout_shape_call`. + """ + if next_params is not None and next_params is not params: + raise ValueError( + "the value at t and t+1 cannot be retrieved in a single call when both params and next params are passed." + ) + in_keys = value_net.in_keys + time_idx = ndim - 1 + T = data.shape[time_idx] + # Include value_key in the select so skip_existing on the value + # network can detect pre-filled values (matches flatten path). + root_part = data.select(*in_keys, value_key, strict=False) + boundary_index = (slice(None),) * time_idx + (slice(T - 1, T),) + next_select = data.get("next").select(*in_keys, value_key, strict=False) + boundary_part = next_select[boundary_index] + # Fill keys absent from ("next", ...) from root at the boundary + # slice. Matches the flatten path's _fill_missing_next_inputs + # contract — in particular, this gives "is_init" at the boundary + # the value of root["is_init"][..., T-1] (typically False), so the + # value net continues the trajectory's hidden state through to the + # boundary slot (correct bootstrap semantics). + boundary_part = self._fill_missing_next_inputs( + boundary_part, root_part[boundary_index], in_keys + ) + # Honor ("final", k) when present (collector contract for true + # rollout-boundary next obs), mirroring _apply_final_obs_to_next_done. + final_root = data.get("final", default=None) + if final_root is not None: + copied = False + for k in in_keys: + key_tuple = (k,) if isinstance(k, str) else tuple(k) + fv = final_root.get( + key_tuple[0] if len(key_tuple) == 1 else key_tuple, + default=None, + ) + if fv is None: + continue + if not copied: + boundary_part = boundary_part.copy() + copied = True + boundary_part.set(k, fv.unsqueeze(time_idx)) + data_in = torch.cat([root_part, boundary_part], dim=time_idx) + if params is not None: + with params.to_module(value_net): + value_est_full = _call_value_net(data_in) + else: + value_est_full = _call_value_net(data_in) + root_idx = (slice(None),) * time_idx + (slice(0, T),) + next_idx = (slice(None),) * time_idx + (slice(1, T + 1),) + value = value_est_full[root_idx] + value_ = value_est_full[next_idx] + done = data.get(("next", "done"), default=None) + if done is not None: + try: + value = value.view_as(done) + value_ = value_.view_as(done) + except RuntimeError: + # Value feat dim doesn't match done's trailing dim; leave + # shapes as produced by the network. Downstream GAE handles + # both layouts. + pass + return value, value_ + def _call_value_nets( self, data: TensorDictBase, @@ -605,7 +723,25 @@ def _call_value_net(data_in: TensorDictBase) -> torch.Tensor: values.append(value_net(chunk).get(value_key)) return torch.cat(values, dim=0) - if single_call: + if single_call and self._can_use_rollout_shape_call(data, ndim): + # Rollout-shape single call: feed [..., T+1] to the value net in + # one shot, then slice into value/next_value. Strictly equivalent + # to the flatten/interleave path when no internal truncations + # exist (verified by _can_use_rollout_shape_call). Preserves + # batch parallelism along the natural batch dim, which matters + # for RNN value nets with reset-aware backends (scan/triton/ + # cuDNN-with-resets) where the flatten path would serialise + # along a [1, B*T] time axis. + value, value_ = self._call_value_net_rollout_shape( + data=data, + params=params, + next_params=next_params, + value_key=value_key, + ndim=ndim, + value_net=value_net, + _call_value_net=_call_value_net, + ) + elif single_call: # We are going to flatten the data, then interleave the last observation of each trajectory in between its # previous obs (from the root TD) and the first of the next trajectory. Eventually, each trajectory will # have T+1 elements (or, for a batch of N trajectories, we will have \Sum_{t=0}^{T-1} length_t + T From e703abd8eb81f5a64ff375eb0720bd070dcb6782 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 15 May 2026 16:08:48 +0100 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- test/objectives/test_values.py | 122 +++++--- torchrl/objectives/value/advantages.py | 415 ++++++++++++++++--------- 2 files changed, 339 insertions(+), 198 deletions(-) diff --git a/test/objectives/test_values.py b/test/objectives/test_values.py index f2c65cac162..8de88b6c7c6 100644 --- a/test/objectives/test_values.py +++ b/test/objectives/test_values.py @@ -11,7 +11,7 @@ import pytest import torch -from tensordict import assert_allclose_td, TensorDict +from tensordict import assert_allclose_td, TensorDict, UnbatchedTensor from tensordict.nn import ( set_composite_lp_aggregate, TensorDictModule, @@ -466,69 +466,105 @@ def test_gae_recurrent(self, module, vectorized): a1 = r1["advantage"] torch.testing.assert_close(a0, a1) - @pytest.mark.parametrize("has_internal_truncation", [False, True]) - def test_gae_shifted_rollout_shape_call(self, has_internal_truncation): - # Shifted=True now dispatches to a rollout-shape single call when - # strictly equivalent to the flatten/interleave path (no internal - # truncations). When internal truncations exist it falls back to - # the flatten path. Either way the result must match shifted=False. - from torchrl.objectives.value.advantages import ValueEstimatorBase - - torch.manual_seed(0) + def _build_shifted_test_td(self, *, with_final: bool, with_internal_done: bool): + """Build a rollout-shaped tensordict for shifted-mode tests.""" B, T, obs_dim = 4, 8, 6 - # Mimic an env rollout: obs[t+1] == next_obs[t] for non-done steps. - # At done positions, next_obs is whatever the env returned at the - # terminal/truncation step and obs[t+1] is the post-reset obs of - # a new trajectory (independent). + # Real rollout invariant: obs[t+1] == next_obs[t] for non-done steps. all_obs = torch.randn(B, T + 1, obs_dim) obs = all_obs[:, :T].clone() next_obs = all_obs[:, 1:].clone() reward = torch.randn(B, T, 1) done = torch.zeros(B, T, 1, dtype=torch.bool) terminated = torch.zeros(B, T, 1, dtype=torch.bool) - if has_internal_truncation: - # Put a truncation (done=True, terminated=False) at t=3 on row 0 - # and decouple next_obs from obs[t+1] at that position. + if with_internal_done: + # Internal truncation at (0, 3): decouple next_obs from obs[t+1]. done[0, 3, 0] = True next_obs[0, 3] = torch.randn(obs_dim) - # Always have done at the rollout boundary (T-1). next_obs[T-1] is - # the rollout's exit observation; since we never see obs[T] in the - # rollout, the value at the boundary slot is dictated by next_obs. + # Rollout boundary at T-1 (truncation by collector cutoff). done[:, -1, 0] = True - td = TensorDict( - { - "observation": obs, - "next": TensorDict( - { - "observation": next_obs, - "reward": reward, - "done": done, - "terminated": terminated, - "truncated": done & ~terminated, - }, - batch_size=[B, T], - ), - }, - batch_size=[B, T], - ) + td_data = { + "observation": obs, + "next": TensorDict( + { + "observation": next_obs, + "reward": reward, + "done": done, + "terminated": terminated, + "truncated": done & ~terminated, + }, + batch_size=[B, T], + ), + } + if with_final: + # Collector with final_obs=True stores the true rollout-boundary + # next obs as an UnbatchedTensor leaf under ("final", k) so the + # leaf escapes the rollout's [B, T] batch-size enforcement. + td_data["final"] = TensorDict( + {"observation": UnbatchedTensor(all_obs[:, T])}, + batch_size=[], + ) + td = TensorDict(td_data, batch_size=[B, T]) td.refine_names(..., "time") - # Confirm the equivalence detector decides correctly. - assert ValueEstimatorBase._can_use_rollout_shape_call(td, ndim=2) is ( - not has_internal_truncation + return td, obs_dim + + @pytest.mark.parametrize("with_final", [False, True]) + @pytest.mark.parametrize("with_internal_done", [False, True]) + def test_gae_shifted_compact_and_legacy(self, with_final, with_internal_done): + # Both shifted='compact' and shifted='legacy' must produce a valid + # advantage. When the rollout has no internal truncations, the two + # paths and shifted=False must all match. When internal truncations + # exist, 'compact' is allowed a small bias (it bootstraps from + # V(obs[t+1]) instead of V(real_next_obs)); 'legacy' must still + # match shifted=False exactly. + torch.manual_seed(0) + td, obs_dim = self._build_shifted_test_td( + with_final=with_final, with_internal_done=with_internal_done ) - value_net = TensorDictModule( nn.Linear(obs_dim, 1), in_keys=["observation"], out_keys=["state_value"], ) - gae_shifted = GAE(gamma=0.9, lmbda=0.95, value_network=value_net, shifted=True) + gae_compact = GAE( + gamma=0.9, lmbda=0.95, value_network=value_net, shifted="compact" + ) + gae_legacy = GAE( + gamma=0.9, lmbda=0.95, value_network=value_net, shifted="legacy" + ) gae_unshifted = GAE( gamma=0.9, lmbda=0.95, value_network=value_net, shifted=False ) - adv_shifted = gae_shifted(td.copy())["advantage"] + adv_compact = gae_compact(td.copy())["advantage"] + adv_legacy = gae_legacy(td.copy())["advantage"] adv_unshifted = gae_unshifted(td.copy())["advantage"] - torch.testing.assert_close(adv_shifted, adv_unshifted) + # legacy is the exact reference: matches shifted=False bit-for-bit. + torch.testing.assert_close(adv_legacy, adv_unshifted) + if not with_internal_done: + # No internal truncations: compact also matches exactly when + # ("final", obs) is provided. Without final, compact has a + # small boundary bias from copying V(obs[T-1]); not asserted. + if with_final: + torch.testing.assert_close(adv_compact, adv_unshifted) + + def test_gae_shifted_true_deprecation_aliases_legacy(self): + torch.manual_seed(0) + td, obs_dim = self._build_shifted_test_td( + with_final=False, with_internal_done=True + ) + value_net = TensorDictModule( + nn.Linear(obs_dim, 1), + in_keys=["observation"], + out_keys=["state_value"], + ) + with pytest.warns(DeprecationWarning, match="shifted=True is deprecated"): + gae_true = GAE(gamma=0.9, lmbda=0.95, value_network=value_net, shifted=True) + gae_legacy = GAE( + gamma=0.9, lmbda=0.95, value_network=value_net, shifted="legacy" + ) + assert gae_true.shifted == "legacy" + adv_true = gae_true(td.copy())["advantage"] + adv_legacy = gae_legacy(td.copy())["advantage"] + torch.testing.assert_close(adv_true, adv_legacy) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99]) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 721bc2d1a93..4261a638623 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -11,6 +11,7 @@ from contextlib import nullcontext from dataclasses import asdict, dataclass from functools import wraps +from typing import Literal import torch from tensordict import is_tensor_collection, TensorDictBase @@ -229,7 +230,7 @@ def __init__( self, *, value_network: TensorDictModule, - shifted: bool = False, + shifted: bool | Literal["compact", "legacy"] = False, differentiable: bool = False, skip_existing: bool | None = None, advantage_key: NestedKey = None, @@ -252,7 +253,7 @@ def __init__( self.skip_existing = skip_existing self.__dict__["value_network"] = value_network self.dep_keys = {} - self.shifted = shifted + self.shifted = self._normalize_shifted(shifted) if advantage_key is not None: raise RuntimeError( @@ -568,42 +569,40 @@ def _maybe_drop_final_obs(tensordict: TensorDictBase) -> None: del tensordict["final"] @staticmethod - def _can_use_rollout_shape_call(data: TensorDictBase, ndim: int) -> bool: - """Return whether the rollout-shape single-call path is safe to use. - - The rollout-shape path is strictly equivalent to the flatten/interleave - path whenever no internal truncations exist within the rollout. - - The rollout-shape path concatenates the boundary ``next_obs`` to the - rollout-shaped inputs and slices into ``value`` / ``next_value``. - It is strictly equivalent to the flatten/interleave path whenever - every internal ``done`` step (``t < T-1``) is also terminated. At - truncation-only steps (``done=True, terminated=False`` with - ``t < T-1``) the flatten path bootstraps from ``V(real next_obs)``, - whereas the rollout-shape path would substitute ``V(obs[t+1])`` from - the start of the next trajectory. Internal terminations are always - masked downstream by ``(1-terminated)``, so they don't matter. + def _normalize_shifted( + shifted: bool | Literal["compact", "legacy"], + ) -> Literal[False, "compact", "legacy"]: + """Normalize the ``shifted`` argument. + + ``shifted=True`` is deprecated; users must opt in explicitly to + either ``"compact"`` (compile-friendly constant-shape single call, + small bias at trajectory boundaries) or ``"legacy"`` (current + flatten/interleave path, exact ``V(next_obs)`` but variable shape). """ - if ndim < 1: - return True - time_idx = ndim - 1 - T = data.shape[time_idx] - if T < 2: - return True - done = data.get(("next", "done"), default=None) - if done is None: - return True - truncated = data.get(("next", "truncated"), default=None) - if truncated is None: - # No truncated key: GAE treats terminated==done; no internal - # truncations possible. - return True - terminated = data.get(("next", "terminated"), default=done) - internal_idx = (slice(None),) * time_idx + (slice(0, T - 1),) - trunc = done[internal_idx].bool() & (~terminated[internal_idx].bool()) - return not bool(trunc.any()) - - def _call_value_net_rollout_shape( + if shifted is False: + return False + if shifted is True: + warnings.warn( + "shifted=True is deprecated and will be removed in v0.15. " + "Pass shifted='legacy' to preserve the current " + "flatten/interleave behavior (exact V(next_obs), variable " + "shape, not compile-friendly), or shifted='compact' to opt " + "into the new constant-shape single-call path (small bias " + "at trajectory boundaries, compile-friendly). The default " + "for shifted=True is currently 'legacy'; this default will " + "be removed in v0.15.", + DeprecationWarning, + stacklevel=3, + ) + return "legacy" + if shifted in ("compact", "legacy"): + return shifted + raise ValueError( + f"shifted must be one of False, 'compact', 'legacy' (or the " + f"deprecated True), got {shifted!r}." + ) + + def _call_value_net_compact( self, data: TensorDictBase, params: TensorDictBase | None, @@ -613,16 +612,22 @@ def _call_value_net_rollout_shape( value_net: TensorDictModuleBase, _call_value_net, ) -> tuple[torch.Tensor, torch.Tensor]: - """Single call on rollout-shaped data: ``[..., T] → [..., T+1]``. - - Builds the value-net input by appending the rollout-boundary - ``next_obs`` along the time dim (using ``("final", k)`` when - present, else ``("next", k)[..., T-1]``), runs ``value_net`` once, - then slices the output into ``value = out[..., :T]`` and - ``next_value = out[..., 1:T+1]``. - - Caller must verify equivalence with the flatten path via - :meth:`_can_use_rollout_shape_call`. + """Compact single-call path: constant-shape value-net call. + + Runs ``value_net`` once on the rollout-shaped inputs along the time + dim. When ``("final", k)`` is present for every value-net in-key + the call has shape ``[..., T+1]`` (the last slot is the + rollout-boundary observation from ``("final", ...)``); otherwise it + has shape ``[..., T]`` and ``V(next_obs[T-1])`` is approximated by + copying ``V(obs[T-1])``. In both cases ``V(next_obs[t])`` for + ``t < T-1`` is taken as ``V(obs[t+1])``, which is exact at + non-trajectory-boundary steps and a small bias at internal + truncations. + + This path has no Python branches on tensor values, no ``.item()`` + syncs, and a shape that depends only on the collector's + ``("final", ...)`` policy — i.e. constant within a training run, + so ``torch.compile`` specializes once and stays specialized. """ if next_params is not None and next_params is not params: raise ValueError( @@ -632,47 +637,59 @@ def _call_value_net_rollout_shape( time_idx = ndim - 1 T = data.shape[time_idx] # Include value_key in the select so skip_existing on the value - # network can detect pre-filled values (matches flatten path). + # network can detect pre-filled values (matches legacy path). root_part = data.select(*in_keys, value_key, strict=False) - boundary_index = (slice(None),) * time_idx + (slice(T - 1, T),) - next_select = data.get("next").select(*in_keys, value_key, strict=False) - boundary_part = next_select[boundary_index] - # Fill keys absent from ("next", ...) from root at the boundary - # slice. Matches the flatten path's _fill_missing_next_inputs - # contract — in particular, this gives "is_init" at the boundary - # the value of root["is_init"][..., T-1] (typically False), so the - # value net continues the trajectory's hidden state through to the - # boundary slot (correct bootstrap semantics). - boundary_part = self._fill_missing_next_inputs( - boundary_part, root_part[boundary_index], in_keys - ) - # Honor ("final", k) when present (collector contract for true - # rollout-boundary next obs), mirroring _apply_final_obs_to_next_done. + # Determine whether ("final", k) is available for every in-key + # whose corresponding root value exists. Decision is deterministic + # given the collector config, so the path taken stays constant + # within a training run. final_root = data.get("final", default=None) - if final_root is not None: - copied = False + final_values: dict = {} + use_final_boundary = final_root is not None + if use_final_boundary: for k in in_keys: + root_v = root_part.get(k, default=None) + if root_v is None: + continue key_tuple = (k,) if isinstance(k, str) else tuple(k) fv = final_root.get( key_tuple[0] if len(key_tuple) == 1 else key_tuple, default=None, ) if fv is None: - continue - if not copied: - boundary_part = boundary_part.copy() - copied = True + use_final_boundary = False + break + final_values[k] = fv + if use_final_boundary and T >= 1: + boundary_index = (slice(None),) * time_idx + (slice(T - 1, T),) + boundary_part = root_part[boundary_index].copy() + for k, fv in final_values.items(): boundary_part.set(k, fv.unsqueeze(time_idx)) - data_in = torch.cat([root_part, boundary_part], dim=time_idx) - if params is not None: - with params.to_module(value_net): - value_est_full = _call_value_net(data_in) + data_in = torch.cat([root_part, boundary_part], dim=time_idx) + if params is not None: + with params.to_module(value_net): + values_full = _call_value_net(data_in) + else: + values_full = _call_value_net(data_in) + root_idx = (slice(None),) * time_idx + (slice(0, T),) + next_idx = (slice(None),) * time_idx + (slice(1, T + 1),) + value = values_full[root_idx] + value_ = values_full[next_idx] else: - value_est_full = _call_value_net(data_in) - root_idx = (slice(None),) * time_idx + (slice(0, T),) - next_idx = (slice(None),) * time_idx + (slice(1, T + 1),) - value = value_est_full[root_idx] - value_ = value_est_full[next_idx] + if params is not None: + with params.to_module(value_net): + values = _call_value_net(root_part) + else: + values = _call_value_net(root_part) + value = values + if T >= 2: + shifted_idx = (slice(None),) * time_idx + (slice(1, T),) + last_idx = (slice(None),) * time_idx + (slice(T - 1, T),) + value_ = torch.cat( + [values[shifted_idx], values[last_idx]], dim=time_idx + ) + else: + value_ = values done = data.get(("next", "done"), default=None) if done is not None: try: @@ -680,8 +697,7 @@ def _call_value_net_rollout_shape( value_ = value_.view_as(done) except RuntimeError: # Value feat dim doesn't match done's trailing dim; leave - # shapes as produced by the network. Downstream GAE handles - # both layouts. + # shapes as produced by the network. pass return value, value_ @@ -690,13 +706,20 @@ def _call_value_nets( data: TensorDictBase, params: TensorDictBase, next_params: TensorDictBase, - single_call: bool, + single_call: bool | Literal["compact", "legacy"], value_key: NestedKey, detach_next: bool, vmap_randomness: str = "error", *, value_net: TensorDictModuleBase | None = None, ): + # ``single_call`` is passed by callers as ``self.shifted`` and is + # one of ``False``, ``"compact"``, or ``"legacy"`` after + # normalization. ``True`` still arrives untransformed from a few + # direct callers — fold it into the legacy path for backwards + # compat (the constructor already warned). + if single_call is True: + single_call = "legacy" if value_net is None: value_net = self.value_network in_keys = value_net.in_keys @@ -723,16 +746,8 @@ def _call_value_net(data_in: TensorDictBase) -> torch.Tensor: values.append(value_net(chunk).get(value_key)) return torch.cat(values, dim=0) - if single_call and self._can_use_rollout_shape_call(data, ndim): - # Rollout-shape single call: feed [..., T+1] to the value net in - # one shot, then slice into value/next_value. Strictly equivalent - # to the flatten/interleave path when no internal truncations - # exist (verified by _can_use_rollout_shape_call). Preserves - # batch parallelism along the natural batch dim, which matters - # for RNN value nets with reset-aware backends (scan/triton/ - # cuDNN-with-resets) where the flatten path would serialise - # along a [1, B*T] time axis. - value, value_ = self._call_value_net_rollout_shape( + if single_call == "compact": + value, value_ = self._call_value_net_compact( data=data, params=params, next_params=next_params, @@ -869,18 +884,36 @@ class TD0Estimator(ValueEstimatorBase): gamma (scalar): exponential mean discount. value_network (TensorDictModule): value operator used to retrieve the value estimates. - shifted (bool, optional): if ``True``, the value and next value are - estimated with a single call to the value network. This is faster - but is only valid whenever (1) the ``"next"`` value is shifted by - only one time step (which is not the case with multi-step value - estimation, for instance) and (2) when the parameters used at time - ``t`` and ``t+1`` are identical (which is not the case when target - parameters are to be used). For recurrent policies or compact - rollouts, the input should contain long, contiguous trajectory - windows with valid boundary next states; short partial rollouts - that drop the final next observation can bias bootstrapping. In - that case, keep or reconstruct boundary next states, or use - ``shifted=False``. Defaults to ``False``. + shifted (bool or str, optional): controls how value and next-value + are obtained from the value network. ``False`` (default) calls + the value network twice (once on the root tensordict, once on + ``"next"``), which is correct whenever ``"next"`` may differ + non-trivially from ``obs[t+1]``. Truthy values request a single + call: + + - ``"compact"``: constant-shape single call along the time dim + (``[..., T]`` or ``[..., T+1]`` depending on whether + ``("final", k)`` is available for the value-net in-keys). + ``V(next_obs[t])`` is taken as ``V(obs[t+1])`` for ``t Date: Fri, 15 May 2026 21:47:43 +0100 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- torchrl/objectives/value/advantages.py | 113 +++++++++++++------------ 1 file changed, 58 insertions(+), 55 deletions(-) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 4261a638623..ae662279e88 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -614,20 +614,35 @@ def _call_value_net_compact( ) -> tuple[torch.Tensor, torch.Tensor]: """Compact single-call path: constant-shape value-net call. - Runs ``value_net`` once on the rollout-shaped inputs along the time - dim. When ``("final", k)`` is present for every value-net in-key - the call has shape ``[..., T+1]`` (the last slot is the - rollout-boundary observation from ``("final", ...)``); otherwise it - has shape ``[..., T]`` and ``V(next_obs[T-1])`` is approximated by - copying ``V(obs[T-1])``. In both cases ``V(next_obs[t])`` for - ``t < T-1`` is taken as ``V(obs[t+1])``, which is exact at - non-trajectory-boundary steps and a small bias at internal + Always runs ``value_net`` once on a ``[..., T+1]`` input along the + time dim. The boundary slot at ``T`` is filled per value-net + in-key with the first available of: + + 1. ``("final", k)``: explicit collector contract for the + rollout-boundary observation (most authoritative; set by + ``Collector(final_obs=True)``). + 2. ``("next", k)[..., T-1:T]``: env-returned "next" value at the + last step. Present whenever the rollout is collected without + ``compact_obs=True``. + 3. A duplicate of ``root[T-1]``: smoothness proxy used when + neither of the above is available (e.g. Isaac with + ``compact_obs=True`` and no ``final_obs=True``). + + The duplication in case 3 is intentional for recurrent value + nets: feeding the RNN one extra step with the duplicated obs + advances the hidden state, so ``V(next_obs[T-1])`` is + approximated by ``f(obs_{T-1}, h_T)`` — not by a scalar copy of + ``V_T = f(obs_{T-1}, h_{T-1})``. The two coincide for + non-recurrent value nets. + + For ``t < T-1``, ``V(next_obs[t])`` is taken as ``V(obs[t+1])`` — + exact at non-trajectory-boundary steps, small bias at internal truncations. - This path has no Python branches on tensor values, no ``.item()`` - syncs, and a shape that depends only on the collector's - ``("final", ...)`` policy — i.e. constant within a training run, - so ``torch.compile`` specializes once and stays specialized. + Shape and code path are constant within a training run (the + collector config determines availability of ``("final", ...)`` + and ``("next", ...)`` deterministically across calls), so + ``torch.compile`` specializes once and stays specialized. """ if next_params is not None and next_params is not params: raise ValueError( @@ -639,57 +654,45 @@ def _call_value_net_compact( # Include value_key in the select so skip_existing on the value # network can detect pre-filled values (matches legacy path). root_part = data.select(*in_keys, value_key, strict=False) - # Determine whether ("final", k) is available for every in-key - # whose corresponding root value exists. Decision is deterministic - # given the collector config, so the path taken stays constant - # within a training run. + boundary_index = (slice(None),) * time_idx + (slice(T - 1, T),) + # Per-key boundary overrides. See docstring for priority. final_root = data.get("final", default=None) - final_values: dict = {} - use_final_boundary = final_root is not None - if use_final_boundary: - for k in in_keys: - root_v = root_part.get(k, default=None) - if root_v is None: - continue + next_root = data.get("next", default=None) + boundary_overrides: dict = {} + for k in in_keys: + if root_part.get(k, default=None) is None: + continue + # 1. ("final", k): unbatched leaf, [..., *F] with no time dim. + if final_root is not None: key_tuple = (k,) if isinstance(k, str) else tuple(k) fv = final_root.get( key_tuple[0] if len(key_tuple) == 1 else key_tuple, default=None, ) - if fv is None: - use_final_boundary = False - break - final_values[k] = fv - if use_final_boundary and T >= 1: - boundary_index = (slice(None),) * time_idx + (slice(T - 1, T),) - boundary_part = root_part[boundary_index].copy() - for k, fv in final_values.items(): - boundary_part.set(k, fv.unsqueeze(time_idx)) - data_in = torch.cat([root_part, boundary_part], dim=time_idx) - if params is not None: - with params.to_module(value_net): - values_full = _call_value_net(data_in) - else: + if fv is not None: + boundary_overrides[k] = fv.unsqueeze(time_idx) + continue + # 2. ("next", k)[..., T-1:T]: env-returned next at last step. + if next_root is not None: + nv = next_root.get(k, default=None) + if nv is not None: + boundary_overrides[k] = nv[boundary_index] + continue + # 3. (no override) — boundary_part keeps the duplicated + # root[T-1] value below. + boundary_part = root_part[boundary_index].copy() + for k, v in boundary_overrides.items(): + boundary_part.set(k, v) + data_in = torch.cat([root_part, boundary_part], dim=time_idx) + if params is not None: + with params.to_module(value_net): values_full = _call_value_net(data_in) - root_idx = (slice(None),) * time_idx + (slice(0, T),) - next_idx = (slice(None),) * time_idx + (slice(1, T + 1),) - value = values_full[root_idx] - value_ = values_full[next_idx] else: - if params is not None: - with params.to_module(value_net): - values = _call_value_net(root_part) - else: - values = _call_value_net(root_part) - value = values - if T >= 2: - shifted_idx = (slice(None),) * time_idx + (slice(1, T),) - last_idx = (slice(None),) * time_idx + (slice(T - 1, T),) - value_ = torch.cat( - [values[shifted_idx], values[last_idx]], dim=time_idx - ) - else: - value_ = values + values_full = _call_value_net(data_in) + root_idx = (slice(None),) * time_idx + (slice(0, T),) + next_idx = (slice(None),) * time_idx + (slice(1, T + 1),) + value = values_full[root_idx] + value_ = values_full[next_idx] done = data.get(("next", "done"), default=None) if done is not None: try: