diff --git a/test/objectives/test_values.py b/test/objectives/test_values.py index 9e19e2ccfb8..922da3e06c3 100644 --- a/test/objectives/test_values.py +++ b/test/objectives/test_values.py @@ -10,6 +10,7 @@ import pytest import torch +from packaging import version from tensordict import assert_allclose_td, TensorDict from tensordict.nn import ( @@ -53,6 +54,8 @@ PENDULUM_VERSIONED, ) +_TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) + class TestValues: @pytest.mark.parametrize( @@ -411,7 +414,8 @@ def _build_shifted_test_td(self, *, with_internal_done: bool): return td, obs_dim @pytest.mark.parametrize("with_internal_done", [False, True]) - def test_gae_shifted_compact_and_legacy(self, with_internal_done): + @pytest.mark.parametrize("compact_cat_dim", ["batch", "time"]) + def test_gae_shifted_compact_and_legacy(self, with_internal_done, compact_cat_dim): # Both shifted='compact' and shifted='legacy' must produce a valid # advantage. 'legacy' must match shifted=False exactly. 'compact' # is allowed a small boundary bias from copying V(obs[T-1]) at @@ -424,7 +428,11 @@ def test_gae_shifted_compact_and_legacy(self, with_internal_done): out_keys=["state_value"], ) gae_compact = GAE( - gamma=0.9, lmbda=0.95, value_network=value_net, shifted="compact" + gamma=0.9, + lmbda=0.95, + value_network=value_net, + shifted="compact", + compact_cat_dim=compact_cat_dim, ) gae_legacy = GAE( gamma=0.9, lmbda=0.95, value_network=value_net, shifted="legacy" @@ -455,6 +463,152 @@ def test_gae_shifted_true_deprecation_aliases_legacy(self): adv_legacy = gae_legacy(td.copy())["advantage"] torch.testing.assert_close(adv_true, adv_legacy) + @pytest.mark.skipif( + _TORCH_VERSION < version.parse("2.7"), + reason="GAE compact recurrent path uses torch.vmap chunked semantics that fall " + "back to _pseudo_vmap on torch<2.7 (NotImplementedError).", + ) + @pytest.mark.parametrize("module", ["lstm", "gru"]) + @pytest.mark.parametrize("compact_cat_dim", ["batch", "time"]) + def test_gae_recurrent_shifted_compact_matches_unshifted_isaac_shape( + self, module, compact_cat_dim + ): + # Isaac-shaped regression test: recurrent value network, multi-trajectory + # rollout with truncations every `episode_len` steps (never terminations), + # and ``compact_obs=False`` semantics — ``("next", obs)`` is populated + # everywhere, in particular at internal-done positions where it carries + # the true pre-reset terminal observation (not the post-reset first obs + # of the new episode). + # + # Under these conditions shifted="compact" must match shifted=False + # to within a small tolerance. The compact path currently builds + # ``data_in = [root_obs[0:T], boundary_obs]`` and reads + # ``value_[t] = V(data_in[t+1])``; for ``t < T-1`` that is + # ``V(root_obs[t+1])``, which at internal-done positions is the + # **post-reset** obs rather than ``("next", obs)[t]``. The + # boundary-override mechanism in ``_call_value_net_compact`` only fills + # the rollout-edge slot, leaving internal-done positions corrupted. + # GAE then bootstraps with ``(1 - terminated)`` (truncations are + # *not* masked), so the wrong ``next_state_value`` propagates straight + # into the value target / advantage. + # + # See ``examples/collectors/isaaclab_rnn_ppo_memory.py`` and + # ``torchrl/objectives/value/advantages.py:_call_value_net_compact``. + torch.manual_seed(0) + B, T, obs_dim, hidden = 4, 16, 6, 8 + episode_len = 4 # internal truncation every 4 steps + g = torch.Generator(device="cpu").manual_seed(0) + all_obs = torch.randn(B, T + 1, obs_dim, generator=g) + obs = all_obs[:, :T].clone() + next_obs = all_obs[:, 1:].clone() + done = torch.zeros(B, T, 1, dtype=torch.bool) + for t in range(episode_len - 1, T, episode_len): + done[:, t, 0] = True + if t < T - 1: + # Decouple next_obs[t] from obs[t+1]: env returned the true + # truncation obs, then auto-reset gave a fresh obs[t+1]. + next_obs[:, t] = torch.randn(B, obs_dim, generator=g) + # Isaac-Ant only ever truncates (max_episode_steps); never terminates. + terminated = torch.zeros_like(done) + truncated = done.clone() + is_init = torch.zeros(B, T, 1, dtype=torch.bool) + is_init[:, 0, 0] = True + is_init[:, 1:][done[:, :-1]] = True + next_is_init = done.clone() + reward = torch.randn(B, T, 1, generator=g) * 0.1 + td = TensorDict( + { + "observation": obs, + "is_init": is_init, + "next": TensorDict( + { + "observation": next_obs, + "reward": reward, + "done": done, + "terminated": terminated, + "truncated": truncated, + "is_init": next_is_init, + }, + [B, T], + ), + }, + [B, T], + ) + + if module == "lstm": + recurrent_module = LSTMModule( + input_size=obs_dim, + hidden_size=hidden, + num_layers=1, + in_keys=["observation", "rs_h", "rs_c"], + out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")], + python_based=True, + recurrent_backend="pad", + dropout=0, + ) + else: + recurrent_module = GRUModule( + input_size=obs_dim, + hidden_size=hidden, + num_layers=1, + in_keys=["observation", "rs_h"], + out_keys=["intermediate", ("next", "rs_h")], + python_based=True, + recurrent_backend="pad", + dropout=0, + ) + recurrent_module.eval() + value_net = Seq( + recurrent_module, + Mod( + nn.Linear(hidden, 1), in_keys=["intermediate"], out_keys=["state_value"] + ), + ) + + gae_unshifted = GAE( + gamma=0.99, + lmbda=0.95, + value_network=value_net, + shifted=False, + deactivate_vmap=True, + average_gae=False, + ) + gae_compact = GAE( + gamma=0.99, + lmbda=0.95, + value_network=value_net, + shifted="compact", + compact_cat_dim=compact_cat_dim, + deactivate_vmap=False, + average_gae=False, + ) + with set_recurrent_mode(True), torch.no_grad(): + adv_unshifted = gae_unshifted(td.clone())["advantage"] + adv_compact = gae_compact(td.clone())["advantage"] + # Tolerance is generous because the recurrent value net has its own + # set of mild approximations (legacy/False stack-and-vmap; compact + # single-call with boundary overrides). The bound here is the level + # at which we have empirically observed the Isaac PPO run diverge + # from the shifted=False baseline; values above ~5% mean-rel-err + # corresponded to a ~20% relative reward shortfall at iter 1000 on + # Isaac-Ant. See the wandb runs cited above. + mean_abs_diff = (adv_compact - adv_unshifted).abs().mean() + mean_unshifted_mag = adv_unshifted.abs().mean().clamp_min(1e-6) + rel = mean_abs_diff / mean_unshifted_mag + assert rel < 0.05, ( + f"shifted='compact' advantage diverges from shifted=False by " + f"mean rel-err={float(rel):.4f} on the Isaac-shaped fixture. " + "This indicates the compact path's _call_value_net_compact is " + "not overriding internal-done positions of `data_in` with the " + "env-returned `('next', obs)` even when it is populated, so the " + "bootstrap value at every truncation step is computed against " + "the post-reset observation instead of the true truncation " + "observation. Bootstraps for truncations are not masked by " + "GAE's (1 - terminated) factor on Isaac-Ant (where every " + "episode boundary is a truncation), so the bias propagates " + "into the value target." + ) + @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index cff513c77cb..2c1443b6967 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -239,10 +239,16 @@ def __init__( device: torch.device | None = None, deactivate_vmap: bool = False, value_chunk_size: int | None = None, + compact_cat_dim: Literal["batch", "time"] = "batch", ): super().__init__() if device is None: device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() + if compact_cat_dim not in ("batch", "time"): + raise ValueError( + "compact_cat_dim must be one of 'batch' or 'time', " + f"got {compact_cat_dim!r}." + ) # this is saved for tracking only and should not be used to cast anything else than buffers during # init. self._device = device @@ -250,6 +256,7 @@ def __init__( self.differentiable = differentiable self.deactivate_vmap = deactivate_vmap self.value_chunk_size = value_chunk_size + self.compact_cat_dim = compact_cat_dim self.skip_existing = skip_existing self.__dict__["value_network"] = value_network self.dep_keys = {} @@ -538,27 +545,24 @@ def _call_value_net_compact( ) -> tuple[torch.Tensor, torch.Tensor]: """Compact single-call path: constant-shape value-net call. - Always runs ``value_net`` once on a ``[..., T+1]`` input along the - time dim. The boundary slot at ``T`` is filled per value-net - in-key with the first available of: + Always runs ``value_net`` once. The root and ``("next", ...)`` streams + are concatenated along either a non-time batch dimension or the time + dimension according to ``compact_cat_dim``, so populated next + observations are evaluated directly without a second value-network + call. The boundary slot at ``T-1`` on the next side is filled per + value-net in-key with the first available of: 1. ``("next", k)[..., T-1:T]``: env-returned "next" value at the last step. Present whenever the rollout is collected without ``compact_obs=True``. 2. A duplicate of ``root[T-1]``: smoothness proxy used when ``("next", k)`` is unavailable (e.g. Isaac with - ``compact_obs=True``). + ``compact_obs=True``), supplied by + :meth:`_fill_missing_next_inputs`. - The duplication in case 2 is intentional for recurrent value - nets: feeding the RNN one extra step with the duplicated obs - advances the hidden state, so ``V(next_obs[T-1])`` is - approximated by ``f(obs_{T-1}, h_T)`` — not by a scalar copy of - ``V_T = f(obs_{T-1}, h_{T-1})``. The two coincide for - non-recurrent value nets. - - For ``t < T-1``, ``V(next_obs[t])`` is taken as ``V(obs[t+1])`` — - exact at non-trajectory-boundary steps, small bias at internal - truncations. + For recurrent value nets, ``("next", "is_init")`` is OR-ed with + the root ``is_init`` so the RNN resets at every trajectory + boundary, matching the ``shifted=False`` reference. Shape and code path are constant within a training run (the collector config determines availability of ``("next", ...)`` @@ -573,30 +577,43 @@ def _call_value_net_compact( time_idx = ndim - 1 T = data.shape[time_idx] root_part = data.select(*in_keys, value_key, strict=False) - boundary_index = (slice(None),) * time_idx + (slice(T - 1, T),) - next_root = data.get("next", default=None) - boundary_overrides: dict = {} - for k in in_keys: - if root_part.get(k, default=None) is None: - continue - if next_root is not None: - nv = next_root.get(k, default=None) - if nv is not None: - boundary_overrides[k] = nv[boundary_index] - continue - boundary_part = root_part[boundary_index].copy() - for k, v in boundary_overrides.items(): - boundary_part.set(k, v) - data_in = torch.cat([root_part, boundary_part], dim=time_idx) + next_part = data.get("next").select(*in_keys, value_key, strict=False) + next_part = self._fill_missing_next_inputs(next_part, root_part, in_keys) + next_part = next_part.copy() + if "is_init" in root_part.keys() and "is_init" in next_part.keys(): + next_part["is_init"] = next_part["is_init"] | root_part["is_init"] + if self.compact_cat_dim == "batch": + added_batch_dim = time_idx == 0 + cat_dim = 0 + if added_batch_dim: + root_part = root_part.unsqueeze(0) + next_part = next_part.unsqueeze(0) + time_idx = 1 + data_in = torch.cat([root_part, next_part], dim=cat_dim) + else: + if "is_init" in next_part.keys(): + first_index = (slice(None),) * time_idx + (slice(0, 1),) + next_is_init = next_part["is_init"].clone() + next_is_init[first_index] = True + next_part["is_init"] = next_is_init + data_in = torch.cat([root_part, next_part], dim=time_idx) if params is not None: with params.to_module(value_net): values_full = _call_value_net(data_in) else: values_full = _call_value_net(data_in) - root_idx = (slice(None),) * time_idx + (slice(0, T),) - next_idx = (slice(None),) * time_idx + (slice(1, T + 1),) - value = values_full[root_idx] - value_ = values_full[next_idx] + if self.compact_cat_dim == "batch": + batch_root = root_part.shape[cat_dim] + value = values_full[:batch_root] + value_ = values_full[batch_root : 2 * batch_root] + if added_batch_dim: + value = value.squeeze(0) + value_ = value_.squeeze(0) + else: + root_idx = (slice(None),) * time_idx + (slice(0, T),) + next_idx = (slice(None),) * time_idx + (slice(T, 2 * T),) + value = values_full[root_idx] + value_ = values_full[next_idx] done = data.get(("next", "done"), default=None) if done is not None: try: @@ -838,6 +855,10 @@ class TD0Estimator(ValueEstimatorBase): value_chunk_size (int, optional): if set, splits value-network calls into chunks of this many elements along the leading dimension. Defaults to ``None``. + compact_cat_dim ("batch" or "time", optional): layout used by + ``shifted="compact"``. ``"batch"`` concatenates root and next + streams along a non-time batch dimension. ``"time"`` concatenates + them along the time dimension. Defaults to ``"batch"``. """ @@ -856,6 +877,7 @@ def __init__( device: torch.device | None = None, deactivate_vmap: bool = False, value_chunk_size: int | None = None, + compact_cat_dim: Literal["batch", "time"] = "batch", ): super().__init__( value_network=value_network, @@ -868,6 +890,7 @@ def __init__( device=device, deactivate_vmap=deactivate_vmap, value_chunk_size=value_chunk_size, + compact_cat_dim=compact_cat_dim, ) self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) self.average_rewards = average_rewards @@ -1086,6 +1109,10 @@ class TD1Estimator(ValueEstimatorBase): value_chunk_size (int, optional): if set, splits value-network calls into chunks of this many elements along the leading dimension. Defaults to ``None``. + compact_cat_dim ("batch" or "time", optional): layout used by + ``shifted="compact"``. ``"batch"`` concatenates root and next + streams along a non-time batch dimension. ``"time"`` concatenates + them along the time dimension. Defaults to ``"batch"``. """ @@ -1105,6 +1132,7 @@ def __init__( time_dim: int | None = None, deactivate_vmap: bool = False, value_chunk_size: int | None = None, + compact_cat_dim: Literal["batch", "time"] = "batch", ): super().__init__( value_network=value_network, @@ -1117,6 +1145,7 @@ def __init__( device=device, deactivate_vmap=deactivate_vmap, value_chunk_size=value_chunk_size, + compact_cat_dim=compact_cat_dim, ) self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) self.average_rewards = average_rewards @@ -1341,6 +1370,10 @@ class TDLambdaEstimator(ValueEstimatorBase): value_chunk_size (int, optional): if set, splits value-network calls into chunks of this many elements along the leading dimension. Defaults to ``None``. + compact_cat_dim ("batch" or "time", optional): layout used by + ``shifted="compact"``. ``"batch"`` concatenates root and next + streams along a non-time batch dimension. ``"time"`` concatenates + them along the time dimension. Defaults to ``"batch"``. """ @@ -1362,6 +1395,7 @@ def __init__( time_dim: int | None = None, deactivate_vmap: bool = False, value_chunk_size: int | None = None, + compact_cat_dim: Literal["batch", "time"] = "batch", ): super().__init__( value_network=value_network, @@ -1374,6 +1408,7 @@ def __init__( device=device, deactivate_vmap=deactivate_vmap, value_chunk_size=value_chunk_size, + compact_cat_dim=compact_cat_dim, ) self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device)) @@ -1636,6 +1671,10 @@ class GAE(ValueEstimatorBase): value_chunk_size (int, optional): if set, splits value-network calls into chunks of this many elements along the leading dimension. Defaults to ``None``. + compact_cat_dim ("batch" or "time", optional): layout used by + ``shifted="compact"``. ``"batch"`` concatenates root and next + streams along a non-time batch dimension. ``"time"`` concatenates + them along the time dimension. Defaults to ``"batch"``. GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also return a :obj:`"value_target"` entry with the return value that is to be used @@ -1682,6 +1721,7 @@ def __init__( auto_reset_env: bool = False, deactivate_vmap: bool = False, value_chunk_size: int | None = None, + compact_cat_dim: Literal["batch", "time"] = "batch", ): super().__init__( shifted=shifted, @@ -1694,6 +1734,7 @@ def __init__( device=device, deactivate_vmap=deactivate_vmap, value_chunk_size=value_chunk_size, + compact_cat_dim=compact_cat_dim, ) self.register_buffer( "gamma", @@ -2038,6 +2079,10 @@ class VTrace(ValueEstimatorBase): value_chunk_size (int, optional): if set, splits value-network calls into chunks of this many elements along the leading dimension. Defaults to ``None``. + compact_cat_dim ("batch" or "time", optional): layout used by + ``shifted="compact"``. ``"batch"`` concatenates root and next + streams along a non-time batch dimension. ``"time"`` concatenates + them along the time dimension. Defaults to ``"batch"``. VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also return a :obj:`"value_target"` entry with the V-Trace target value. @@ -2067,6 +2112,7 @@ def __init__( device: torch.device | None = None, time_dim: int | None = None, value_chunk_size: int | None = None, + compact_cat_dim: Literal["batch", "time"] = "batch", ): super().__init__( shifted=shifted, @@ -2078,6 +2124,7 @@ def __init__( skip_existing=skip_existing, device=device, value_chunk_size=value_chunk_size, + compact_cat_dim=compact_cat_dim, ) if not isinstance(gamma, torch.Tensor): gamma = torch.tensor(gamma, device=self._device)