diff --git a/test/test_collectors.py b/test/test_collectors.py index 09c86c17b5d..78781824fad 100644 --- a/test/test_collectors.py +++ b/test/test_collectors.py @@ -1425,6 +1425,61 @@ def make_env(): root_shifted = ref_data.get(k)[..., 1:, :] torch.testing.assert_close(ref_next[mask], root_shifted[mask]) + @pytest.mark.parametrize("use_buffers", [True, False]) + def test_fake_tensordict_single_matches_iter(self, use_buffers): + """``Collector.fake_tensordict()`` mirrors the shape and keys of a real batch.""" + + def make_env(): + return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker()) + + c = Collector( + create_env_fn=make_env, + policy=RandomPolicy(make_env().action_spec), + frames_per_batch=20, + total_frames=20, + use_buffers=use_buffers, + ) + try: + fake = c.fake_tensordict() + torch.manual_seed(0) + real = next(iter(c)) + assert fake.batch_size == real.batch_size, ( + fake.batch_size, + real.batch_size, + ) + assert fake.names == real.names + fake_keys = sorted(map(str, fake.keys(True, True))) + real_keys = sorted(map(str, real.keys(True, True))) + assert fake_keys == real_keys, set(real_keys) ^ set(fake_keys) + for key, val in fake.items(True, True): + if ( + val.dtype in (torch.bool, torch.uint8) + or not val.is_floating_point() + ): + continue + assert not val.any(), f"{key} is not zeroed" + finally: + c.shutdown() + + def test_fake_tensordict_multi_raises(self): + """``MultiCollector.fake_tensordict()`` is intentionally not implemented.""" + + def make_env(): + return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker()) + + c = MultiCollector( + create_env_fn=[make_env, make_env], + policy=RandomPolicy(make_env().action_spec), + frames_per_batch=20, + total_frames=20, + sync=True, + ) + try: + with pytest.raises(NotImplementedError, match="fake_tensordict"): + c.fake_tensordict() + finally: + c.shutdown() + @pytest.mark.parametrize("env_class", [CountingEnv, CountingBatchedEnv]) def test_initial_obs_consistency(self, env_class, seed=1): # non regression test on #938 diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index b6b529b6677..f8f4d9be954 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -1056,6 +1056,25 @@ def _add_policy_outputs_to_fake_td(self, fake_td): fake_td.set(key, policy_output.get(key)) return fake_td + def fake_tensordict(self) -> TensorDictBase: + """Not implemented for multi-process collectors. + + Honoring the multi-collector contract here would require either + creating an env in the main process (which defeats the purpose of + a multi-process collector — Isaac Lab / mujoco-mjx etc. can only + run in workers) or routing a request to a worker over the pipe + (which requires workers to be alive and adds protocol surface). + Neither is implemented; call :meth:`~torchrl.collectors.Collector.fake_tensordict` + on a single :class:`~torchrl.collectors.Collector` instead, or + build the template directly from the env spec. + """ + raise NotImplementedError( + f"{type(self).__name__}.fake_tensordict() is not implemented. " + "Use Collector.fake_tensordict() on a single-process collector " + "for storage / cudagraph warmup, or build the template from the " + "env spec directly." + ) + @classmethod def _total_workers_from_env(cls, env_creators): if isinstance(env_creators, (tuple, list)): diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index 9170a7e9f92..50882ac5ffc 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -1891,6 +1891,33 @@ def _maybe_set_truncated(self, final_rollout): ) return final_rollout + @torch.no_grad() + def fake_tensordict(self) -> TensorDictBase: + """Return a zero-filled tensordict shaped like one batch from this collector. + + The result mirrors what ``next(iter(collector))`` would yield: + + - batch shape ``(*env.batch_size, frames_per_batch)`` with the last + dim named ``"time"``; + - env keys (observation / reward / done / terminated / truncated / + ``is_init`` when an :class:`~torchrl.envs.InitTracker` is on the + env), policy out-keys, and ``("collector", "traj_ids")`` when + trajectory tracking is enabled; + - ``compact_obs=True`` exclusions applied; + - ``set_truncated=True`` last-step ``truncated``/``done`` masking + applied; + - ``postproc`` / ``split_trajs`` / private-key exclusion applied, + mirroring :meth:`_postproc`. + + Intended for storage initialization and ``torch.compile`` / + cudagraph warmup without having to step the environment first. + """ + if getattr(self, "_final_rollout", None) is None: + self._maybe_make_final_rollout(make_rollout=True) + result = self._final_rollout.clone().zero_() + result = self._maybe_set_truncated(result) + return self._postproc(result) + @torch.no_grad() def reset(self, index=None, **kwargs) -> None: """Resets the environments to a new initial state."""