diff --git a/test/test_collectors.py b/test/test_collectors.py
index 642a10b0ed1..87d67e16a27 100644
--- a/test/test_collectors.py
+++ b/test/test_collectors.py
@@ -1538,6 +1538,89 @@ def make_env():
         flat = without.reshape(-1)
         assert flat.batch_size.numel() == 20
 
+    @pytest.mark.parametrize("use_buffers", [True, False])
+    def test_fake_tensordict_single_matches_iter(self, use_buffers):
+        """``Collector.fake_tensordict()`` mirrors the shape and keys of a real batch."""
+
+        def make_env():
+            return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())
+
+        c = Collector(
+            create_env_fn=make_env,
+            policy=RandomPolicy(make_env().action_spec),
+            frames_per_batch=20,
+            total_frames=20,
+            use_buffers=use_buffers,
+        )
+        try:
+            fake = c.fake_tensordict()
+            torch.manual_seed(0)
+            real = next(iter(c))
+            assert fake.batch_size == real.batch_size, (
+                fake.batch_size,
+                real.batch_size,
+            )
+            assert fake.names == real.names
+            fake_keys = sorted(map(str, fake.keys(True, True)))
+            real_keys = sorted(map(str, real.keys(True, True)))
+            assert fake_keys == real_keys, set(real_keys) ^ set(fake_keys)
+            # Floating-point leaves must be all-zero; integer and boolean
+            # leaves are skipped by this check.
+            for key, val in fake.items(True, True):
+                if not val.is_floating_point():
+                    continue
+                assert not val.any(), f"{key} is not zeroed"
+        finally:
+            c.shutdown()
+
+    def test_fake_tensordict_single_with_compact_and_final_obs(self):
+        """``compact_obs`` + ``final_obs`` effects are visible on the fake batch."""
+        from tensordict import UnbatchedTensor
+
+        def make_env():
+            return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())
+
+        c = Collector(
+            create_env_fn=make_env,
+            policy=RandomPolicy(make_env().action_spec),
+            frames_per_batch=10,
+            total_frames=10,
+            compact_obs=True,
+            final_obs=True,
+        )
+        try:
+            fake = c.fake_tensordict()
+            # compact_obs dropped ('next', obs)
+            assert ("next", "observation") not in fake.keys(True, True)
+            # final_obs attached ('final', obs) as UnbatchedTensor
+            val = fake.get(("final", "observation"))
+            assert isinstance(val, UnbatchedTensor)
+        finally:
+            c.shutdown()
+
+    def test_fake_tensordict_multi_sync_stacks_workers(self):
+        """``MultiSyncCollector.fake_tensordict()`` stacks per-worker batches along dim 0."""
+
+        def make_env():
+            return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())
+
+        c = MultiCollector(
+            create_env_fn=[make_env, make_env],
+            policy=RandomPolicy(make_env().action_spec),
+            frames_per_batch=20,
+            total_frames=20,
+            sync=True,
+        )
+        try:
+            fake = c.fake_tensordict()
+            # 2 workers, 10 frames each, no env batch dim on the mock env.
+            assert fake.batch_size == torch.Size([2, 10])
+            assert fake.names[-1] == "time"
+            assert ("collector", "traj_ids") in fake.keys(True, True)
+            assert ("next", "reward") in fake.keys(True, True)
+        finally:
+            c.shutdown()
+
     @pytest.mark.parametrize("env_class", [CountingEnv, CountingBatchedEnv])
     def test_initial_obs_consistency(self, env_class, seed=1):
         # non regression test on #938
diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py
index 12f14137346..659b59a8792 100644
--- a/torchrl/collectors/_multi_base.py
+++ b/torchrl/collectors/_multi_base.py
@@ -1054,6 +1054,164 @@ def _add_policy_outputs_to_fake_td(self, fake_td):
                 fake_td.set(key, policy_output.get(key))
         return fake_td
 
+    @torch.no_grad()
+    def fake_tensordict(self) -> TensorDictBase:
+        """Return a zero-filled tensordict shaped like one batch from this multi-collector.
+
+        Mirrors what one iteration of this collector yields, including
+        policy out-keys, ``("collector", "traj_ids")``, ``compact_obs`` /
+        ``final_obs`` effects, ``split_trajs``, and ``postproc``.
+
+        Shape:
+
+        - :class:`~torchrl.collectors.MultiSyncCollector` with
+          ``cat_results="stack"`` (the default): ``(num_workers,
+          *env.batch_size, frames_per_worker)``, last dim named ``"time"``.
+        - :class:`~torchrl.collectors.MultiSyncCollector` with an integer
+          ``cat_results``: per-worker tensordicts concatenated along that dim.
+        - :class:`~torchrl.collectors.MultiAsyncCollector`:
+          ``(*env.batch_size, frames_per_worker)`` — async yields one
+          worker batch at a time.
+
+        Intended for storage initialization and ``torch.compile`` /
+        cudagraph warmup without spinning up the worker processes.
+
+        The template is read from the first entry of ``create_env_fn``. An env
+        built here is closed before returning; an env instance passed in
+        directly is left open.
+
+        Returns:
+            TensorDictBase: the zero-filled batch, after ``split_trajs``,
+            ``postproc`` and private-key exclusion have been applied.
+
+        .. note::
+            ``("final", ...)`` entries are only produced when both
+            ``compact_obs`` and ``final_obs`` are set.
+        """
+        # Build / borrow one env to read fake_tensordict, compact-obs leaf
+        # keys, and final-obs leaf shapes from.
+        env_fn = self.create_env_fn[0]
+        owns_env = False
+        if isinstance(env_fn, EnvBase):
+            env = env_fn
+        elif isinstance(env_fn, EnvCreator):
+            env = env_fn()
+            owns_env = True
+        else:
+            env = env_fn(**self.create_env_kwargs[0])
+            owns_env = True
+
+        try:
+            per_worker = env.fake_tensordict()
+
+            # Normalize every leaf key to a tuple so the ("next", ...) and
+            # ("final", ...) keys are both derived from the same form.
+            compact_keys: tuple = ()
+            final_leaves: tuple = ()
+            if self.compact_obs:
+                leaves = tuple(
+                    k if isinstance(k, tuple) else (k,)
+                    for k in (
+                        *env._observation_keys_step_mdp,
+                        *env._state_keys_step_mdp,
+                    )
+                )
+                compact_keys = tuple(("next", *leaf) for leaf in leaves)
+                if self.final_obs:
+                    final_leaves = leaves
+            # Path of each final leaf relative to the "next" sub-tensordict.
+            final_paths = [leaf[0] if len(leaf) == 1 else leaf for leaf in final_leaves]
+
+            # Capture leaf shapes for final_obs *before* the exclude below
+            # removes them from ``per_worker``.
+            final_obs_template = None
+            if final_paths:
+                final_obs_template = per_worker.get("next").select(
+                    *final_paths, strict=False
+                )
+
+            if compact_keys:
+                per_worker = per_worker.exclude(*compact_keys)
+
+            per_worker = self._add_policy_outputs_to_fake_td(per_worker)
+
+            frames_per_worker = self.frames_per_batch_worker(worker_idx=0)
+            per_worker = (
+                per_worker.unsqueeze(-1)
+                .expand(*env.batch_size, frames_per_worker)
+                .clone()
+                .zero_()
+            )
+
+            per_worker.set(
+                ("collector", "traj_ids"),
+                torch.zeros(
+                    per_worker.shape, dtype=torch.int64, device=per_worker.device
+                ),
+            )
+
+            if final_obs_template is not None:
+                from tensordict import UnbatchedTensor
+
+                for leaf, path in zip(final_leaves, final_paths):
+                    src = final_obs_template.get(path, default=None)
+                    if src is None:
+                        # Leaf was absent under "next" (``strict=False`` select).
+                        continue
+                    per_worker.set(
+                        ("final", *leaf), UnbatchedTensor(torch.zeros_like(src))
+                    )
+
+            per_worker.refine_names(..., "time")
+
+            cat_results = getattr(self, "cat_results", None)
+            if cat_results is None:
+                # MultiAsyncCollector yields one worker batch at a time;
+                # MultiSyncCollector defaults to stacking along dim 0.
+                from torchrl.collectors._multi_async import MultiAsyncCollector
+
+                if not isinstance(self, MultiAsyncCollector):
+                    cat_results = "stack"
+
+            if cat_results is None:
+                result = per_worker
+            else:
+                # Worker 0 reuses the template; the others get their own clone.
+                copies = [per_worker]
+                copies += [per_worker.clone() for _ in range(self.num_workers - 1)]
+                if cat_results == "stack":
+                    result = torch.stack(copies, 0)
+                    name_time_dim = True
+                else:
+                    result = torch.cat(copies, cat_results)
+                    # The last dim may be spelled ``-1`` or as its positive index.
+                    name_time_dim = cat_results in (-1, result.ndim - 1)
+                if name_time_dim:
+                    result.refine_names(*[None] * (result.ndim - 1), "time")
+
+            if self.split_trajs:
+                result = split_trajectories(result, prefix="collector")
+            if self.postproc is not None:
+                postproc = (
+                    self.postproc.to(result.device)
+                    if hasattr(self.postproc, "to")
+                    else self.postproc
+                )
+                result = postproc(result)
+            if self._exclude_private_keys:
+                excluded_keys = [
+                    key
+                    for key in result.keys()
+                    if isinstance(key, str) and key.startswith("_")
+                ]
+                if excluded_keys:
+                    result = result.exclude(*excluded_keys)
+
+            return result
+        finally:
+            if owns_env:
+                env.close()
+
     @classmethod
     def _total_workers_from_env(cls, env_creators):
         if isinstance(env_creators, (tuple, list)):
diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py
index ec3755a222e..2022549068f 100644
--- a/torchrl/collectors/_single.py
+++ b/torchrl/collectors/_single.py
@@ -1996,6 +1996,53 @@ def _maybe_set_truncated(self, final_rollout):
             )
         return final_rollout
 
+    @torch.no_grad()
+    def fake_tensordict(self) -> TensorDictBase:
+        """Return a zero-filled tensordict shaped like one batch from this collector.
+
+        The result mirrors what ``next(iter(collector))`` would yield:
+
+        - batch shape ``(*env.batch_size, frames_per_batch)`` with the last
+          dim named ``"time"``;
+        - env keys (observation / reward / done / terminated / truncated /
+          ``is_init`` when an :class:`~torchrl.envs.InitTracker` is on the
+          env), policy out-keys, and ``("collector", "traj_ids")`` when
+          trajectory tracking is enabled;
+        - ``compact_obs=True`` exclusions and ``final_obs=True``
+          ``("final", k)`` entries applied;
+        - ``set_truncated=True`` last-step ``truncated``/``done`` masking
+          applied;
+        - ``postproc`` / ``split_trajs`` / private-key exclusion applied,
+          mirroring :meth:`_postproc`.
+
+        Intended for storage initialization and ``torch.compile`` /
+        cudagraph warmup without having to step the environment first.
+
+        .. note::
+            If no rollout buffer exists yet, one is built and left on the
+            collector (``_final_rollout``), even when ``use_buffers=False``.
+
+        Returns:
+            TensorDictBase: a zeroed copy of the rollout buffer with the
+            steps listed above applied.
+
+        Raises:
+            RuntimeError: if the rollout buffer could not be built.
+        """
+        if getattr(self, "_final_rollout", None) is None:
+            # Build the rollout buffer on demand even when use_buffers=False
+            # so we have a structural template to clone from.
+            self._maybe_make_final_rollout(make_rollout=True)
+        template = getattr(self, "_final_rollout", None)
+        if template is None:
+            raise RuntimeError(
+                "fake_tensordict() could not build a rollout template."
+            )
+        result = template.clone().zero_()
+        result = self._maybe_attach_final_obs(result)
+        result = self._maybe_set_truncated(result)
+        return self._postproc(result)
+
     @torch.no_grad()
     def reset(self, index=None, **kwargs) -> None:
         """Resets the environments to a new initial state."""