diff --git a/test/objectives/test_values.py b/test/objectives/test_values.py
index 45ae59bba87..9e19e2ccfb8 100644
--- a/test/objectives/test_values.py
+++ b/test/objectives/test_values.py
@@ -378,6 +378,105 @@ def test_gae_recurrent(self, module, vectorized):
         a1 = r1["advantage"]
         torch.testing.assert_close(a0, a1)
 
+    def _build_shifted_test_td(self, *, with_internal_done: bool):
+        """Build a rollout-shaped tensordict for shifted-mode tests.
+
+        The rollout has batch size ``[B, T] = [4, 8]`` (last dim named
+        ``"time"``) and is contiguous: ``next_obs[:, t] == obs[:, t + 1]``
+        for every ``t < T - 1``. The last step of every row is flagged
+        ``done`` and, since ``terminated`` is all ``False``, truncated.
+
+        Args:
+            with_internal_done: if ``True``, row 0 is additionally done at
+                ``t=3`` and its next observation there is replaced by fresh
+                noise, so ``next_obs[0, 3] != obs[0, 4]``.
+
+        Returns:
+            A ``(tensordict, obs_dim)`` tuple.
+        """
+        B, T, obs_dim = 4, 8, 6
+        all_obs = torch.randn(B, T + 1, obs_dim)
+        obs = all_obs[:, :T].clone()
+        next_obs = all_obs[:, 1:].clone()
+        reward = torch.randn(B, T, 1)
+        done = torch.zeros(B, T, 1, dtype=torch.bool)
+        terminated = torch.zeros(B, T, 1, dtype=torch.bool)
+        if with_internal_done:
+            done[0, 3, 0] = True
+            # Break the obs[t + 1] == next_obs[t] identity at the done step.
+            next_obs[0, 3] = torch.randn(obs_dim)
+        done[:, -1, 0] = True
+        td = TensorDict(
+            {
+                "observation": obs,
+                "next": TensorDict(
+                    {
+                        "observation": next_obs,
+                        "reward": reward,
+                        "done": done,
+                        "terminated": terminated,
+                        "truncated": done & ~terminated,
+                    },
+                    batch_size=[B, T],
+                ),
+            },
+            batch_size=[B, T],
+        )
+        td.refine_names(..., "time")
+        return td, obs_dim
+
+    @pytest.mark.parametrize("with_internal_done", [False, True])
+    def test_gae_shifted_compact_and_legacy(self, with_internal_done):
+        # 'legacy' must match shifted=False exactly. 'compact' must always
+        # return a finite advantage of the same shape. It reads V(obs[t + 1])
+        # in place of V(next_obs[t]) for t < T - 1, so it is only compared
+        # against shifted=False when the rollout has no internal done (the
+        # two observations then coincide and ("next", "observation") fills
+        # the boundary slot).
+        torch.manual_seed(0)
+        td, obs_dim = self._build_shifted_test_td(with_internal_done=with_internal_done)
+        value_net = TensorDictModule(
+            nn.Linear(obs_dim, 1),
+            in_keys=["observation"],
+            out_keys=["state_value"],
+        )
+        gae_compact = GAE(
+            gamma=0.9, lmbda=0.95, value_network=value_net, shifted="compact"
+        )
+        gae_legacy = GAE(
+            gamma=0.9, lmbda=0.95, value_network=value_net, shifted="legacy"
+        )
+        gae_unshifted = GAE(
+            gamma=0.9, lmbda=0.95, value_network=value_net, shifted=False
+        )
+        adv_compact = gae_compact(td.copy())["advantage"]
+        adv_legacy = gae_legacy(td.copy())["advantage"]
+        adv_unshifted = gae_unshifted(td.copy())["advantage"]
+        torch.testing.assert_close(adv_legacy, adv_unshifted)
+        assert adv_compact.shape == adv_unshifted.shape
+        assert torch.isfinite(adv_compact).all()
+        if not with_internal_done:
+            torch.testing.assert_close(adv_compact, adv_unshifted)
+
+    def test_gae_shifted_true_deprecation_aliases_legacy(self):
+        # shifted=True must warn and behave exactly like shifted='legacy'.
+        torch.manual_seed(0)
+        td, obs_dim = self._build_shifted_test_td(with_internal_done=True)
+        value_net = TensorDictModule(
+            nn.Linear(obs_dim, 1),
+            in_keys=["observation"],
+            out_keys=["state_value"],
+        )
+        with pytest.warns(DeprecationWarning, match="shifted=True is deprecated"):
+            gae_true = GAE(gamma=0.9, lmbda=0.95, value_network=value_net, shifted=True)
+        gae_legacy = GAE(
+            gamma=0.9, lmbda=0.95, value_network=value_net, shifted="legacy"
+        )
+        assert gae_true.shifted == "legacy"
+        adv_true = gae_true(td.copy())["advantage"]
+        adv_legacy = gae_legacy(td.copy())["advantage"]
+        torch.testing.assert_close(adv_true, adv_legacy)
+
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99])
     @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99])
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index a918676ae8d..cff513c77cb 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -11,6 +11,7 @@
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
 from functools import wraps
+from typing import Literal
 
 import torch
 from tensordict import is_tensor_collection, TensorDictBase
@@ 
-229,7 +230,7 @@ def __init__(
         self,
         *,
         value_network: TensorDictModule,
-        shifted: bool = False,
+        shifted: bool | Literal["compact", "legacy"] = False,
         differentiable: bool = False,
         skip_existing: bool | None = None,
         advantage_key: NestedKey = None,
@@ -252,7 +253,7 @@ def __init__(
         self.skip_existing = skip_existing
         self.__dict__["value_network"] = value_network
         self.dep_keys = {}
-        self.shifted = shifted
+        self.shifted = self._normalize_shifted(shifted)
 
         if advantage_key is not None:
             raise RuntimeError(
@@ -491,18 +492,183 @@ def _fill_missing_next_inputs(
                 next_data.set(key, value)
         return next_data
 
+    @staticmethod
+    def _normalize_shifted(
+        shifted: bool | Literal["compact", "legacy"],
+    ) -> Literal[False, "compact", "legacy"]:
+        """Normalize the ``shifted`` argument.
+
+        ``shifted=True`` is deprecated; users must opt in explicitly to
+        either ``"compact"`` (compile-friendly constant-shape single call,
+        small bias at trajectory boundaries) or ``"legacy"`` (current
+        flatten/interleave path, exact ``V(next_obs)`` but variable shape).
+
+        Args:
+            shifted: raw value of the constructor's ``shifted`` argument.
+
+        Returns:
+            ``False``, ``"compact"`` or ``"legacy"``. ``True`` is mapped to
+            ``"legacy"`` after a :class:`DeprecationWarning` is emitted.
+
+        Raises:
+            ValueError: if ``shifted`` is none of the accepted values.
+        """
+        if shifted is False:
+            return False
+        if shifted is True:
+            # NOTE(review): stacklevel=3 points at the caller of the
+            # ``__init__`` that invoked this helper. If a subclass reaches
+            # that ``__init__`` through ``super().__init__``, the warning is
+            # attributed to the subclass rather than to user code — confirm.
+            warnings.warn(
+                "shifted=True is deprecated and will be removed in v0.15. "
+                "Pass shifted='legacy' to preserve the current "
+                "flatten/interleave behavior (exact V(next_obs), variable "
+                "shape, not compile-friendly), or shifted='compact' to opt "
+                "into the new constant-shape single-call path (small bias "
+                "at trajectory boundaries, compile-friendly). The default "
+                "for shifted=True is currently 'legacy'; this default will "
+                "be removed in v0.15.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            return "legacy"
+        # ``in`` compares with ``==``; restrict it to strings so array-like
+        # inputs reach the ValueError below instead of failing inside the
+        # membership test with an unrelated "ambiguous truth value" error.
+        if isinstance(shifted, str) and shifted in ("compact", "legacy"):
+            return shifted
+        raise ValueError(
+            f"shifted must be one of False, 'compact', 'legacy' (or the "
+            f"deprecated True), got {shifted!r}."
+        )
+
+    def _call_value_net_compact(
+        self,
+        data: TensorDictBase,
+        params: TensorDictBase | None,
+        next_params: TensorDictBase | None,
+        value_key: NestedKey,
+        ndim: int,
+        value_net: TensorDictModuleBase,
+        _call_value_net,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Compact single-call path: constant-shape value-net call.
+
+        Always runs ``value_net`` once on a ``[..., T+1]`` input along the
+        time dim. The boundary slot at ``T`` is filled per value-net
+        in-key with the first available of:
+
+        1. ``("next", k)[..., T-1:T]``: env-returned "next" value at the
+           last step. Present whenever the rollout is collected without
+           ``compact_obs=True``.
+        2. A duplicate of ``root[T-1]``: smoothness proxy used when
+           ``("next", k)`` is unavailable (e.g. Isaac with
+           ``compact_obs=True``).
+
+        The duplication in case 2 is intentional for recurrent value
+        nets: feeding the RNN one extra step with the duplicated obs
+        advances the hidden state, so ``V(next_obs[T-1])`` is
+        approximated by ``f(obs_{T-1}, h_T)`` — not by a scalar copy of
+        ``V_T = f(obs_{T-1}, h_{T-1})``. The two coincide for
+        non-recurrent value nets.
+
+        For ``t < T-1``, ``V(next_obs[t])`` is taken as ``V(obs[t+1])`` —
+        exact at non-trajectory-boundary steps, small bias at internal
+        truncations.
+
+        Shape and code path are constant within a training run (the
+        collector config determines availability of ``("next", ...)``
+        deterministically across calls), so ``torch.compile`` specializes
+        once and stays specialized.
+
+        Args:
+            data: rollout tensordict whose time dim is ``ndim - 1``.
+            params: parameters applied to ``value_net`` through
+                ``params.to_module`` for the call; ``None`` calls it as is.
+            next_params: must be ``None`` or the same object as ``params``.
+            value_key: key read from the value-net output.
+            ndim: index of the time dim plus one.
+            value_net: value network; its ``in_keys`` select the inputs.
+            _call_value_net: callable taking the ``T+1``-long input
+                tensordict and returning the value tensor.
+
+        Returns:
+            ``(value, next_value)``: the ``[0:T]`` and ``[1:T+1]`` slices of
+            the value-net output along the time dim, viewed as
+            ``("next", "done")`` when that entry exists and the view
+            succeeds.
+
+        Raises:
+            ValueError: if ``next_params`` is neither ``None`` nor
+                ``params``.
+        """
+        if next_params is not None and next_params is not params:
+            raise ValueError(
+                "the value at t and t+1 cannot be retrieved in a single call when both params and next params are passed."
+            )
+        in_keys = value_net.in_keys
+        time_idx = ndim - 1
+        T = data.shape[time_idx]
+        # ``strict=False``: keys absent from ``data`` (e.g. ``value_key``
+        # before the first call) are skipped instead of raising.
+        root_part = data.select(*in_keys, value_key, strict=False)
+        # Length-1 slice keeps the time dim so the slot can be concatenated.
+        boundary_index = (slice(None),) * time_idx + (slice(T - 1, T),)
+        next_root = data.get("next", default=None)
+        boundary_overrides: dict = {}
+        for k in in_keys:
+            if root_part.get(k, default=None) is None:
+                continue
+            if next_root is not None:
+                nv = next_root.get(k, default=None)
+                if nv is not None:
+                    boundary_overrides[k] = nv[boundary_index]
+        # Default boundary slot is a duplicate of root[T-1]; entries found
+        # under "next" then replace the duplicate.
+        boundary_part = root_part[boundary_index].copy()
+        for k, v in boundary_overrides.items():
+            boundary_part.set(k, v)
+        data_in = torch.cat([root_part, boundary_part], dim=time_idx)
+        if params is not None:
+            with params.to_module(value_net):
+                values_full = _call_value_net(data_in)
+        else:
+            values_full = _call_value_net(data_in)
+        root_idx = (slice(None),) * time_idx + (slice(0, T),)
+        next_idx = (slice(None),) * time_idx + (slice(1, T + 1),)
+        value = values_full[root_idx]
+        value_ = values_full[next_idx]
+        done = data.get(("next", "done"), default=None)
+        if done is not None:
+            # Best effort: align with the done flags' shape; on a
+            # RuntimeError from ``view_as`` the unviewed slices are returned.
+            try:
+                value = value.view_as(done)
+                value_ = value_.view_as(done)
+            except RuntimeError:
+                pass
+        return value, value_
+
     def _call_value_nets(
         self,
         data: TensorDictBase,
         params: TensorDictBase,
         next_params: TensorDictBase,
-        single_call: bool,
+        single_call: bool | Literal["compact", "legacy"],
         value_key: NestedKey,
         detach_next: bool,
         vmap_randomness: str = "error",
         *,
         value_net: TensorDictModuleBase | None = None,
     ):
+        # ``single_call`` is passed by callers as ``self.shifted`` and is
+        # one of ``False``, ``"compact"``, or ``"legacy"`` after
+        # normalization. ``True`` still arrives untransformed from a few
+        # direct callers — fold it into the legacy path for backwards
+        # compat (the constructor already warned).
+ if single_call is True: + single_call = "legacy" if value_net is None: value_net = self.value_network in_keys = value_net.in_keys @@ -529,7 +651,17 @@ def _call_value_net(data_in: TensorDictBase) -> torch.Tensor: values.append(value_net(chunk).get(value_key)) return torch.cat(values, dim=0) - if single_call: + if single_call == "compact": + value, value_ = self._call_value_net_compact( + data=data, + params=params, + next_params=next_params, + value_key=value_key, + ndim=ndim, + value_net=value_net, + _call_value_net=_call_value_net, + ) + elif single_call: # We are going to flatten the data, then interleave the last observation of each trajectory in between its # previous obs (from the root TD) and the first of the next trajectory. Eventually, each trajectory will # have T+1 elements (or, for a batch of N trajectories, we will have \Sum_{t=0}^{T-1} length_t + T @@ -650,18 +782,35 @@ class TD0Estimator(ValueEstimatorBase): gamma (scalar): exponential mean discount. value_network (TensorDictModule): value operator used to retrieve the value estimates. - shifted (bool, optional): if ``True``, the value and next value are - estimated with a single call to the value network. This is faster - but is only valid whenever (1) the ``"next"`` value is shifted by - only one time step (which is not the case with multi-step value - estimation, for instance) and (2) when the parameters used at time - ``t`` and ``t+1`` are identical (which is not the case when target - parameters are to be used). For recurrent policies or compact - rollouts, the input should contain long, contiguous trajectory - windows with valid boundary next states; short partial rollouts - that drop the final next observation can bias bootstrapping. In - that case, keep or reconstruct boundary next states, or use - ``shifted=False``. Defaults to ``False``. + shifted (bool or str, optional): controls how value and next-value + are obtained from the value network. 
``False`` (default) calls + the value network twice (once on the root tensordict, once on + ``"next"``), which is correct whenever ``"next"`` may differ + non-trivially from ``obs[t+1]``. Truthy values request a single + call: + + - ``"compact"``: constant-shape single call along the time + dim of length ``T+1``. ``V(next_obs[t])`` is taken as + ``V(obs[t+1])`` for ``t