diff --git a/test/objectives/test_values.py b/test/objectives/test_values.py index 45ae59bba87..a1fa01e651b 100644 --- a/test/objectives/test_values.py +++ b/test/objectives/test_values.py @@ -254,6 +254,94 @@ def test_shifted_gae_accepts_noncanonical_strides(self): assert torch.isfinite(out["advantage"]).all() assert torch.isfinite(out["value_target"]).all() + @pytest.mark.parametrize( + "estimator_cls,kwargs", + [ + (TD0Estimator, {"gamma": 0.9}), + (TD1Estimator, {"gamma": 0.9}), + (TDLambdaEstimator, {"gamma": 0.9, "lmbda": 0.95}), + (GAE, {"gamma": 0.9, "lmbda": 0.95}), + ], + ) + def test_final_obs_bootstrap_shifted(self, estimator_cls, kwargs): + """``("final", obs)`` carries the true bootstrap obs at the window edge. + + Without it, shifted-GAE under ``compact_obs=True`` falls back to + ``V(s_{T-1})`` via :meth:`_fill_missing_next_inputs`, corrupting the + last step's advantage when the window boundary is not a real done. + The fix overrides those positions with the values carried under + ``("final", obs)`` and matches the non-compact reference exactly. + + Also verifies the consumer drops ``("final", ...)`` from the returned + tensordict so it survives a contiguous-storage replay buffer. + """ + from tensordict import UnbatchedTensor + + torch.manual_seed(0) + value_net = TensorDictModule( + nn.Linear(3, 1, bias=False), + in_keys=["obs"], + out_keys=["state_value"], + ) + B, T, F = 2, 5, 3 + obs = torch.randn(B, T, F) + # No real done inside the window: the last step is a "soft" boundary. + done = torch.zeros(B, T, 1, dtype=torch.bool) + reward = torch.ones(B, T, 1) + # The "true" next obs after step T-1 (one per env, no time dim). + final_obs = torch.randn(B, F) + + # Reference: full ('next', obs) at every step. 
+ next_obs_full = torch.empty(B, T, F) + next_obs_full[:, :-1] = obs[:, 1:] + next_obs_full[:, -1] = final_obs + td_ref = TensorDict( + { + "obs": obs, + "next": { + "obs": next_obs_full, + "reward": reward, + "done": done.clone(), + "terminated": done.clone(), + }, + }, + [B, T], + ) + + # Compact + final_obs: no ('next', obs) but a ('final', obs) UnbatchedTensor. + td_compact = TensorDict( + { + "obs": obs, + "next": { + "reward": reward, + "done": done.clone(), + "terminated": done.clone(), + }, + "final": TensorDict( + {"obs": UnbatchedTensor(final_obs)}, + batch_size=(B, T), + ), + }, + [B, T], + ) + + est = estimator_cls(**kwargs, value_network=value_net, shifted=True) + out_ref = est(td_ref.clone()) + out_compact = est(td_compact.clone()) + + # Must match the non-compact reference exactly at the boundary + # (and everywhere else). + torch.testing.assert_close(out_compact["advantage"], out_ref["advantage"]) + torch.testing.assert_close(out_compact["value_target"], out_ref["value_target"]) + + # Drop must have happened — the rollout is now safe to extend into a + # contiguous-storage RB. + out_inplace = td_compact.clone() + est(out_inplace) + assert ( + "final" not in out_inplace.keys() + ), "('final', ...) 
should have been consumed and dropped" + @pytest.mark.skipif(not _has_gym, reason="requires gym") def test_gae_multi_done(self): diff --git a/test/test_collectors.py b/test/test_collectors.py index e6f054e41c0..11c1dbfffa3 100644 --- a/test/test_collectors.py +++ b/test/test_collectors.py @@ -1425,6 +1425,117 @@ def make_env(): root_shifted = ref_data.get(k)[..., 1:, :] torch.testing.assert_close(ref_next[mask], root_shifted[mask]) + def test_final_obs_requires_compact_obs(self): + """``final_obs=True`` without ``compact_obs=True`` must raise.""" + + def make_env(): + return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker()) + + with pytest.raises(ValueError, match="requires compact_obs=True"): + Collector( + create_env_fn=make_env, + policy=RandomPolicy(make_env().action_spec), + frames_per_batch=10, + total_frames=10, + compact_obs=False, + final_obs=True, + ) + + @pytest.mark.parametrize("use_buffers", [True, False]) + def test_final_obs_matches_compact_off(self, use_buffers): + """``final_obs=True`` carries the same boundary obs as a non-compact run. + + Runs two rollouts with identical seeds: one with + ``compact_obs=False`` (full ``('next', obs)`` retained at every step) + and one with ``compact_obs=True, final_obs=True`` (boundary obs + stored under ``('final', obs)`` as + :class:`~tensordict.UnbatchedTensor`). The two must agree on the + boundary obs (non-done envs only — at done envs the bootstrap is + masked downstream so the value there is unconstrained). 
+ """ + from tensordict import UnbatchedTensor + + def make_env(): + return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker()) + + dummy_env = make_env() + obs_keys = list(dummy_env._observation_keys_step_mdp) + dummy_env.close() + + def run(compact, final): + torch.manual_seed(0) + return Collector( + create_env_fn=make_env, + policy=RandomPolicy(make_env().action_spec), + frames_per_batch=20, + total_frames=20, + use_buffers=use_buffers, + compact_obs=compact, + final_obs=final, + ) + + ref = run(False, False) + ref_data = next(iter(ref)) + ref.shutdown() + del ref + + comp = run(True, True) + comp_data = next(iter(comp)) + comp.shutdown() + del comp + + # Batch shape preserved (no time dim leakage from the UnbatchedTensor). + assert comp_data.batch_size == ref_data.batch_size + + # ('final', k) is present and is an UnbatchedTensor. + for k in obs_keys: + full_final = ("final", *k) if isinstance(k, tuple) else ("final", k) + assert full_final in comp_data.keys(True, True), f"missing {full_final}" + val = comp_data.get(full_final) + assert isinstance(val, UnbatchedTensor), type(val) + + # Compare against the reference's ('next', k) at the last step, + # masked to non-done envs. + full_next = ("next", *k) if isinstance(k, tuple) else ("next", k) + ref_last_next = ref_data.get(full_next)[..., -1, :] + done_last = ref_data.get(("next", "done"))[..., -1, :].squeeze(-1) + mask = ~done_last + torch.testing.assert_close(val[mask], ref_last_next[mask]) + + def test_final_obs_unbatched_survives_indexing(self): + """The ``("final", obs)`` UnbatchedTensor must not collapse on reshape. + + Rationale: if the leaf were a regular tensor, indexing or reshaping + the rollout along the time axis would either drop the time dim or + propagate a shape mismatch into a contiguous-storage replay buffer. 
+ """ + from tensordict import UnbatchedTensor + + def make_env(): + return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker()) + + c = Collector( + create_env_fn=make_env, + policy=RandomPolicy(make_env().action_spec), + frames_per_batch=20, + total_frames=20, + compact_obs=True, + final_obs=True, + ) + data = next(iter(c)) + c.shutdown() + + original = data.get(("final", "observation")) + # Slicing along time must preserve the same underlying tensor. + sliced = data[..., :5].get(("final", "observation")) + assert isinstance(sliced, UnbatchedTensor) + assert torch.equal(original, sliced) + # exclude("final") must yield a td whose batch shape reshapes cleanly. + without = data.exclude("final") + assert "final" not in without.keys() + flat = without.reshape(-1) + assert flat.batch_size.numel() == 20 + @pytest.mark.parametrize("env_class", [CountingEnv, CountingBatchedEnv]) def test_initial_obs_consistency(self, env_class, seed=1): # non regression test on #938 diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 01928b751a3..695e8416a49 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -329,6 +329,13 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta): :class:`~torchrl.envs.transforms.rb_transforms.NextStateReconstructor` at sampling time. Defaults to ``False``. + final_obs (bool, optional): if ``True`` (requires ``compact_obs=True``), + each worker additionally stores the true next-observation reached + after the last step of its rollout under ``("final", k)`` as an + :class:`tensordict.UnbatchedTensor`. Closes the shifted-GAE + bootstrap-correctness gap at window boundaries. See + :class:`~torchrl.collectors.SyncDataCollector` for details. + Defaults to ``False``. worker_idx (int, optional): the index of the worker. 
Examples: @@ -416,6 +423,7 @@ def __init__( pre_collect_hook: Callable[[], None] | None = None, post_collect_hook: Callable[[TensorDictBase], None] | None = None, compact_obs: bool = False, + final_obs: bool = False, ): self.closed = True self.worker_idx = worker_idx @@ -527,6 +535,13 @@ def __init__( self.reset_at_each_iter = reset_at_each_iter self.postproc = postproc self.compact_obs = bool(compact_obs) + self.final_obs = bool(final_obs) + if self.final_obs and not self.compact_obs: + raise ValueError( + "final_obs=True requires compact_obs=True; otherwise the true " + "next observation is already stored at every step under " + "('next', ...)." + ) self.max_frames_per_traj = ( int(max_frames_per_traj) if max_frames_per_traj is not None else 0 ) @@ -1323,6 +1338,7 @@ def _run_processes(self) -> None: "pre_collect_hook": self._worker_pre_collect_hook, "post_collect_hook": self._worker_post_collect_hook, "compact_obs": self.compact_obs, + "final_obs": self.final_obs, } proc = _ProcessNoWarnCtx( target=_main_async_collector, diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index ed30337176e..5682fc2c45c 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -71,6 +71,7 @@ def _main_async_collector( pre_collect_hook: Callable[[], None] | None = None, post_collect_hook: Callable[[TensorDictBase], None] | None = None, compact_obs: bool = False, + final_obs: bool = False, ) -> None: # Process-level initialisation hook (e.g. Isaac Lab ``AppLauncher``). # Runs before any CUDA/torchrl work in the child process. @@ -142,6 +143,7 @@ def _main_async_collector( pre_collect_hook=pre_collect_hook, post_collect_hook=post_collect_hook, compact_obs=compact_obs, + final_obs=final_obs, ) # Set up weight receivers for worker process using the standard register_scheme_receiver API. # This properly initializes the schemes on the receiver side and stores them in _receiver_schemes. 
diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index e3fe344d1a7..4aa0149eaa0 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -11,7 +11,12 @@ import torch -from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict import ( + LazyStackedTensorDict, + TensorDict, + TensorDictBase, + UnbatchedTensor, +) from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase from torch import nn from torchrl import compile_with_warmup @@ -275,7 +280,49 @@ class Collector(BaseCollector): keys can be re-hydrated at sampling time with :class:`~torchrl.envs.transforms.rb_transforms.NextStateReconstructor` when consuming a :class:`~torchrl.data.SliceSampler`-backed replay - buffer. Defaults to ``False``. + buffer. + + ``compact_obs=True`` composes cleanly with + :class:`~torchrl.objectives.value.advantages.GAE` configured with + ``shifted=True``: shifted-GAE only needs the value at the boundary + between steps, which it reads via the root key of the next step + rather than the ``("next", "observation")`` mirror, so no rehydration + is required for the on-policy advantage pass. For vectorized + environments with large observations this is typically a sizeable + GPU-memory win at near-zero CPU cost. Defaults to ``False``. + final_obs (bool, optional): if ``True`` (and ``compact_obs=True``), the + collector additionally stores the true next-observation reached + after the last step of the rollout under a top-level ``("final", k)`` + sub-tensordict for each observation/state key ``k`` that was + compacted away. The value is wrapped in + :class:`tensordict.UnbatchedTensor` (one obs per env, no time + dimension) so the rollout's batch shape ``[*envs, T]`` is preserved. + + .. warning:: ``final_obs`` is experimental and may change or be + removed without notice. 
In practice, bootstrapping the last + step with the root observation of the same step (the default + fallback under ``compact_obs=True``) works well; ``final_obs`` + exists for correctness when the bias matters. + + This closes the bootstrap-correctness gap when running with short + rollout windows: under ``compact_obs=True``, the ``("next", obs)`` + of the very last step of each window is dropped, and a shifted + value estimator (e.g. :class:`~torchrl.objectives.value.GAE`, + :class:`~torchrl.objectives.value.TD0Estimator`, + :class:`~torchrl.objectives.value.TD1Estimator`, + :class:`~torchrl.objectives.value.TDLambdaEstimator`, + :class:`~torchrl.objectives.value.VTrace`) falls back to + bootstrapping ``V(s_T) ≈ V(s_{T-1})`` for that step (a 1/T + fraction of corruption). With ``final_obs=True``, the value + estimator reads the true ``s_T`` from ``("final", obs)`` instead. + + The pipeline assumption is: + ``collector -> value_estimator(shifted=True) -> ReplayBuffer.extend()``. + The value estimator consumes and drops ``("final", ...)`` from + the returned tensordict, so the downstream replay buffer never + sees an :class:`~tensordict.UnbatchedTensor` (which would + otherwise be incompatible with a contiguous storage). + Defaults to ``False``. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -383,6 +430,7 @@ def __init__( pre_collect_hook: Callable[[], None] | None = None, post_collect_hook: Callable[[TensorDictBase], None] | None = None, compact_obs: bool = False, + final_obs: bool = False, **kwargs, ): self.closed = True @@ -475,7 +523,11 @@ def __init__( # Set up compact_obs: keys to drop from ("next", ...) to avoid # storing two copies of each observation. See `_setup_compact_obs`. - self._setup_compact_obs(compact_obs) + # `final_obs` (requires `compact_obs`) additionally captures the true + # next-observation after the last step as an UnbatchedTensor under + # ("final", ...) 
so GAE(shifted=True) can bootstrap correctly without + # re-storing every step's next obs. + self._setup_compact_obs(compact_obs, final_obs) # Calculate frames per batch self._setup_frames_per_batch(frames_per_batch) @@ -512,6 +564,7 @@ def __init__( # Create shuttle and rollout buffers self._make_shuttle() self._maybe_make_final_rollout(make_rollout=self._use_buffers) + self._setup_final_obs_buffer() self._set_truncated_keys() # Set up interruptor and frame tracking @@ -945,7 +998,7 @@ def _setup_postproc(self, postproc: Callable | None) -> None: if postproc is not self.postproc and postproc is not None: self.postproc = postproc - def _setup_compact_obs(self, compact_obs: bool) -> None: + def _setup_compact_obs(self, compact_obs: bool, final_obs: bool) -> None: """Resolve the ``("next", ...)`` keys to drop when ``compact_obs=True``. When enabled, the collector drops the observation and state keys from @@ -959,21 +1012,68 @@ def _setup_compact_obs(self, compact_obs: bool) -> None: The user can re-hydrate the dropped keys at sampling time with :class:`~torchrl.envs.transforms.rb_transforms.NextStateReconstructor` when consuming a ``SliceSampler``-backed replay buffer. + + If ``final_obs=True`` (requires ``compact_obs=True``), the corresponding + ``("final", ...)`` keys are resolved so the rollout can carry the true + next-observation after the last step as an + :class:`tensordict.UnbatchedTensor` (no time dim). """ self.compact_obs = bool(compact_obs) + self.final_obs = bool(final_obs) + if self.final_obs and not self.compact_obs: + raise ValueError( + "final_obs=True requires compact_obs=True; otherwise the true " + "next observation is already stored at every step under " + "('next', ...)." 
+ ) if not self.compact_obs: self._compact_next_keys: tuple = () + self._final_obs_keys: tuple = () return leaf_keys = list(self.env._observation_keys_step_mdp) + list( self.env._state_keys_step_mdp ) compact: list[tuple] = [] + final: list[tuple] = [] for k in leaf_keys: if isinstance(k, tuple): compact.append(("next", *k)) + final.append(("final", *k)) else: compact.append(("next", k)) + final.append(("final", k)) self._compact_next_keys = tuple(compact) + self._final_obs_keys = tuple(final) if self.final_obs else () + + def _setup_final_obs_buffer(self) -> None: + """Allocate the per-env buffer that holds the final next obs/state. + + When ``final_obs=True``, the collector keeps a persistent buffer of + shape ``[*env.batch_size, *leaf_shape]`` (one entry per env, no time + dim) for each obs/state leaf that was compacted away. The buffer is + updated in-place from ``self._carrier["next"]`` each step; the final + value (after the last step) is wrapped in + :class:`tensordict.UnbatchedTensor` and attached to the rollout under + ``("final", *leaf)``. + """ + if not self.final_obs: + self._final_obs_buffer = None + return + with torch.no_grad(): + fake = self.env.fake_tensordict() + nxt = fake.get("next") + # Read inside ("next", ...). Single-element paths must be unwrapped + # to a plain str for TensorDict.get/select to behave consistently. + leaf_paths = [] + for compact_k in self._compact_next_keys: + leaf = compact_k[1:] + leaf_paths.append(leaf[0] if len(leaf) == 1 else leaf) + buf = nxt.select(*leaf_paths, strict=False).clone() + if self.storing_device is not None: + buf = buf.to(self.storing_device, non_blocking=True) + else: + buf.clear_device_() + self._final_obs_buffer = buf def _setup_frames_per_batch(self, frames_per_batch: int) -> None: """Calculate and validate frames per batch.""" @@ -1746,7 +1846,20 @@ def rollout(self) -> TensorDictBase: # When compact_obs is enabled, drop the obs/state keys from # ("next", ...) before persisting the per-step td. 
The dropped # keys are recoverable from the root keys of the next step. + # If final_obs is also on, snapshot the true next obs/state + # into the side buffer *before* the drop, so the last + # iteration's snapshot becomes the rollout's final obs. if self._compact_next_keys: + if self.final_obs: + nxt = self._carrier.get("next") + for compact_k in self._compact_next_keys: + leaf = compact_k[1:] + leaf = leaf[0] if len(leaf) == 1 else leaf + dst = self._final_obs_buffer.get(leaf) + src = nxt.get(leaf) + if dst.device != src.device: + src = src.to(dst.device, non_blocking=True) + dst.copy_(src, non_blocking=True) carrier_for_out = self._carrier.exclude(*self._compact_next_keys) else: carrier_for_out = self._carrier @@ -1855,8 +1968,33 @@ def rollout(self) -> TensorDictBase: result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) result.refine_names(..., "time") + result = self._maybe_attach_final_obs(result) return self._maybe_set_truncated(result) + def _maybe_attach_final_obs(self, result): + """Attach ``("final", *leaf)`` entries on ``result`` as UnbatchedTensor. + + No-op unless ``final_obs=True``. Wraps the tensors of the side buffer + (which holds the most-recent next-obs/state, in-place updated by the + rollout loop) into :class:`tensordict.UnbatchedTensor` so the + rollout's batch shape ``[*envs, T]`` is preserved. Downstream, + :class:`~torchrl.objectives.value.advantages.GAE` (shifted) consumes + and drops these before the rollout reaches the replay buffer. 
+ """ + if not self.final_obs: + return result + for final_k in self._final_obs_keys: + leaf = final_k[1:] + buf_path = leaf[0] if len(leaf) == 1 else leaf + val = self._final_obs_buffer.get(buf_path) + wrapped = UnbatchedTensor(val) + if result.is_locked: + with result.unlock_(): + result.set(final_k, wrapped) + else: + result.set(final_k, wrapped) + return result + def _maybe_set_truncated(self, final_rollout): last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,) for truncated_key in self._truncated_keys: diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 557b24453f3..93c0169f14b 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1562,6 +1562,15 @@ def get(self, indices: slice) -> TensorDictBase | torch.Tensor | Any: class LazyMemmapStorage(LazyTensorStorage): """A memory-mapped storage for tensors and tensordicts. + .. note:: ``LazyMemmapStorage`` trades sampling throughput for GPU memory: + keeping the buffer off-device avoids the GPU-memory pressure of a + large ``LazyTensorStorage(..., device='cuda')``, but every sample + incurs a host-to-device copy (and a disk read on cold pages). For + vectorized RL with large rollouts this is often the right call when + the buffer would otherwise dominate GPU memory; for tight inner loops + where the buffer fits comfortably on device, ``LazyTensorStorage`` + on the training device is usually faster. + Args: max_size (int): size of the storage, i.e. maximum number of elements stored in the buffer. 
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index a918676ae8d..f76ed4e3b89 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -491,6 +491,82 @@ def _fill_missing_next_inputs( next_data.set(key, value) return next_data + @staticmethod + def _apply_final_obs_to_next_done( + next_done_data: TensorDictBase, + data: TensorDictBase, + done_view: torch.Tensor, + in_keys: list[NestedKey], + ndim: int, + ) -> TensorDictBase: + """Override last-step entries of ``next_done_data`` with ``("final", k)``. + + When the collector was configured with ``final_obs=True`` and + ``compact_obs=True``, the rollout carries a top-level ``("final", k)`` + sub-tensordict (with :class:`~tensordict.UnbatchedTensor` leaves) that + holds the true next observation reached after the last step. Without + this override, shifted-GAE bootstraps the last step's + ``V(s_T)`` from the *root* obs of step ``T-1`` (via + :meth:`_fill_missing_next_inputs`) — a 1/T fraction of corruption. + + This helper substitutes those entries with the true ``("final", k)`` + values for each in-key that has one. No-op when ``("final", k)`` is + absent. + """ + final_root = data.get("final", default=None) + if final_root is None: + return next_done_data + # Identify which of the K done positions are the synthetic last-step done. + # In flat order, last-step positions of a (..., T) tensor are + # ``i*T + (T-1)`` for i in range(B_total). 
+ batch_dims = data.batch_size[: ndim - 1] + T = data.batch_size[ndim - 1] + B_total = 1 + for d in batch_dims: + B_total *= int(d) + if B_total == 0 or T == 0: + return next_done_data + device = done_view.device + last_step_positions = torch.arange(B_total, device=device) * T + (T - 1) + last_step_mask_flat = torch.zeros(B_total * T, dtype=torch.bool, device=device) + last_step_mask_flat[last_step_positions] = True + last_step_in_done = last_step_mask_flat[done_view] + if not bool(last_step_in_done.any()): + return next_done_data + copied = False + for key in in_keys: + key_tuple = (key,) if isinstance(key, str) else tuple(key) + final_val = final_root.get( + key_tuple[0] if len(key_tuple) == 1 else key_tuple, default=None + ) + if final_val is None: + continue + nxt_k = next_done_data.get(key, default=None) + if nxt_k is None: + continue + obs_shape = final_val.shape[ndim - 1 :] + final_flat = final_val.reshape(-1, *obs_shape) + if not copied: + next_done_data = next_done_data.copy() + copied = True + nxt_k = nxt_k.clone() + nxt_k[last_step_in_done] = final_flat + next_done_data.set(key, nxt_k) + return next_done_data + + @staticmethod + def _maybe_drop_final_obs(tensordict: TensorDictBase) -> None: + """Drop the ``("final", ...)`` sub-tensordict from ``tensordict``. + + The collector writes this sub-tensordict with + :class:`~tensordict.UnbatchedTensor` leaves when ``final_obs=True`` so + :meth:`_apply_final_obs_to_next_done` can bootstrap shifted-GAE + correctly. After consumption the entries are stale and would block + contiguous-storage replay buffers, so we strip them here in-place. 
+ """ + if "final" in tensordict.keys(): + del tensordict["final"] + def _call_value_nets( self, data: TensorDictBase, @@ -581,6 +657,13 @@ def _call_value_net(data_in: TensorDictBase) -> torch.Tensor: next_done_data = self._fill_missing_next_inputs( next_done_data, root_done_data, in_keys ) + # When the collector ran with `final_obs=True`, override the + # synthetic-done last-step entries with the true next obs + # carried under ("final", k), otherwise shifted-GAE would + # bootstrap with V(s_{T-1}) at every window boundary. + next_done_data = self._apply_final_obs_to_next_done( + next_done_data, data, done_view, in_keys, ndim + ) data_in[indices_interleaved] = next_done_data if next_params is not None and next_params is not params: raise ValueError( @@ -828,6 +911,7 @@ def forward( value_target = self.value_estimate(tensordict, next_value=next_value) tensordict.set(self.tensor_keys.advantage, value_target - value) tensordict.set(self.tensor_keys.value_target, value_target) + self._maybe_drop_final_obs(tensordict) return tensordict def value_estimate( @@ -1061,6 +1145,7 @@ def forward( tensordict.set(self.tensor_keys.advantage, value_target - value) tensordict.set(self.tensor_keys.value_target, value_target) + self._maybe_drop_final_obs(tensordict) return tensordict def value_estimate( @@ -1313,6 +1398,7 @@ def forward( tensordict.set(self.tensor_keys.advantage, value_target - value) tensordict.set(self.tensor_keys.value_target, value_target) + self._maybe_drop_final_obs(tensordict) return tensordict def value_estimate( @@ -1692,6 +1778,7 @@ def forward( tensordict.set(self.tensor_keys.advantage, adv) tensordict.set(self.tensor_keys.value_target, value_target) + self._maybe_drop_final_obs(tensordict) return tensordict @@ -2071,6 +2158,7 @@ def forward( tensordict.set(self.tensor_keys.advantage, adv) tensordict.set(self.tensor_keys.value_target, value_target) + self._maybe_drop_final_obs(tensordict) return tensordict