Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions test/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,6 +1538,88 @@ def make_env():
flat = without.reshape(-1)
assert flat.batch_size.numel() == 20

@pytest.mark.parametrize("use_buffers", [True, False])
def test_fake_tensordict_single_matches_iter(self, use_buffers):
    """Check that ``Collector.fake_tensordict()`` is structured like a collected batch.

    The fake batch is compared with ``next(iter(collector))`` on batch size,
    dimension names and the full set of nested leaf keys. Its floating-point
    leaves must additionally be zero. The check is run with ``use_buffers``
    both enabled and disabled.
    """

    def make_env():
        return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())

    collector = Collector(
        create_env_fn=make_env,
        policy=RandomPolicy(make_env().action_spec),
        frames_per_batch=20,
        total_frames=20,
        use_buffers=use_buffers,
    )
    try:
        # The fake batch is requested before any real data is collected.
        fake = collector.fake_tensordict()
        torch.manual_seed(0)
        real = next(iter(collector))

        assert fake.batch_size == real.batch_size, (
            fake.batch_size,
            real.batch_size,
        )
        assert fake.names == real.names

        # Compare nested leaf keys by their string form so that plain and
        # tuple keys can be sorted together.
        fake_keys = sorted(str(key) for key in fake.keys(True, True))
        real_keys = sorted(str(key) for key in real.keys(True, True))
        assert fake_keys == real_keys, set(real_keys) ^ set(fake_keys)

        # Only floating-point leaves are required to be zero; leaves of any
        # other dtype (bool, integer, ...) are not checked.
        for key, val in fake.items(True, True):
            if val.is_floating_point():
                assert not val.any(), f"{key} is not zeroed"
    finally:
        collector.shutdown()

def test_fake_tensordict_single_with_compact_and_final_obs(self):
    """Check that ``compact_obs`` and ``final_obs`` are reflected in the fake batch.

    With both options enabled, the fake batch must not contain
    ``("next", "observation")`` and must expose ``("final", "observation")``
    as an ``UnbatchedTensor``.
    """
    from tensordict import UnbatchedTensor

    def make_env():
        return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())

    collector = Collector(
        create_env_fn=make_env,
        policy=RandomPolicy(make_env().action_spec),
        frames_per_batch=10,
        total_frames=10,
        compact_obs=True,
        final_obs=True,
    )
    try:
        fake = collector.fake_tensordict()

        # compact_obs: the next-step observation is dropped from the batch.
        leaf_keys = fake.keys(True, True)
        assert ("next", "observation") not in leaf_keys

        # final_obs: the final observation is attached as an UnbatchedTensor.
        final_observation = fake.get(("final", "observation"))
        assert isinstance(final_observation, UnbatchedTensor)
    finally:
        collector.shutdown()

def test_fake_tensordict_multi_sync_stacks_workers(self):
    """Check the fake batch of a synchronous ``MultiCollector`` with two workers.

    The collector is built with ``sync=True`` and two env constructors. The
    fake batch is expected to have batch size ``(2, 10)``, a last dimension
    named ``"time"``, and to carry both ``("collector", "traj_ids")`` and
    ``("next", "reward")``.
    """

    def make_env():
        return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())

    collector = MultiCollector(
        create_env_fn=[make_env, make_env],
        policy=RandomPolicy(make_env().action_spec),
        frames_per_batch=20,
        total_frames=20,
        sync=True,
    )
    try:
        fake = collector.fake_tensordict()

        # 20 frames split across 2 workers -> 10 frames per worker; the mock
        # env contributes no batch dimension of its own.
        expected_batch_size = torch.Size([2, 10])
        assert fake.batch_size == expected_batch_size
        assert fake.names[-1] == "time"

        leaf_keys = fake.keys(True, True)
        assert ("collector", "traj_ids") in leaf_keys
        assert ("next", "reward") in leaf_keys
    finally:
        collector.shutdown()

@pytest.mark.parametrize("env_class", [CountingEnv, CountingBatchedEnv])
def test_initial_obs_consistency(self, env_class, seed=1):
# non regression test on #938
Expand Down
150 changes: 150 additions & 0 deletions torchrl/collectors/_multi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,156 @@ def _add_policy_outputs_to_fake_td(self, fake_td):
fake_td.set(key, policy_output.get(key))
return fake_td

@torch.no_grad()
def fake_tensordict(self) -> TensorDictBase:
    """Build a zero-filled stand-in for one batch yielded by this multi-collector.

    The result is meant to mirror one iteration of the collector, including
    policy out-keys, ``("collector", "traj_ids")``, ``compact_obs`` /
    ``final_obs`` effects, ``split_trajs``, and ``postproc``.

    Shape:

    - :class:`~torchrl.collectors.MultiSyncCollector` with
      ``cat_results="stack"`` (the default): ``(num_workers,
      *env.batch_size, frames_per_worker)``, last dim named ``"time"``.
    - :class:`~torchrl.collectors.MultiSyncCollector` with an integer
      ``cat_results``: per-worker tensordicts concatenated along that dim.
    - :class:`~torchrl.collectors.MultiAsyncCollector`:
      ``(*env.batch_size, frames_per_worker)`` — async yields one
      worker batch at a time.

    Intended for storage initialization and ``torch.compile`` /
    cudagraph warmup without spinning up the worker processes.

    Returns:
        The fake batch, after ``split_trajs``, ``postproc`` and private-key
        exclusion have been applied when they are enabled on the collector.
    """
    # NOTE(review): the PR reviewer objected to this implementation. It
    # creates an env in the main process, which defeats the purpose of the
    # MultiCollector. The reviewer's position: if the fake data cannot easily
    # be obtained from the inner collector, raise NotImplementedError rather
    # than emulate the inner collector with custom code on the main process.

    def _with_prefix(prefix, key):
        # Nest ``key`` (a plain or tuple key) under ``prefix``.
        return (prefix, *key) if isinstance(key, tuple) else (prefix, key)

    def _leaf_path(prefixed_key):
        # Drop the leading "next"/"final" entry; a single remaining entry is
        # returned as a plain key instead of a 1-tuple.
        tail = prefixed_key[1:]
        return tail[0] if len(tail) == 1 else tail

    # Build or borrow one env to read the fake tensordict, the compact-obs
    # leaf keys and the final-obs leaf shapes from. Only envs constructed here
    # are closed on exit.
    env_fn = self.create_env_fn[0]
    if isinstance(env_fn, EnvBase):
        env, owns_env = env_fn, False
    elif isinstance(env_fn, EnvCreator):
        env, owns_env = env_fn(), True
    else:
        env, owns_env = env_fn(**self.create_env_kwargs[0]), True

    try:
        per_worker = env.fake_tensordict()

        compact_keys: tuple = ()
        final_keys: tuple = ()
        if self.compact_obs:
            leaf_keys = [
                *env._observation_keys_step_mdp,
                *env._state_keys_step_mdp,
            ]
            compact_keys = tuple(_with_prefix("next", key) for key in leaf_keys)
            if self.final_obs:
                final_keys = tuple(
                    _with_prefix("final", key) for key in leaf_keys
                )

        # Capture the leaf shapes for final_obs *before* the exclude below
        # removes those entries from ``per_worker``.
        final_obs_template = None
        if final_keys:
            final_obs_template = per_worker.get("next").select(
                *[_leaf_path(key) for key in final_keys],
                strict=False,
            )

        if compact_keys:
            per_worker = per_worker.exclude(*compact_keys)

        per_worker = self._add_policy_outputs_to_fake_td(per_worker)

        # Append a trailing dimension of length ``frames_per_worker`` and
        # materialize a zeroed copy.
        frames_per_worker = self.frames_per_batch_worker(worker_idx=0)
        per_worker = per_worker.unsqueeze(-1)
        per_worker = per_worker.expand(*env.batch_size, frames_per_worker)
        per_worker = per_worker.clone().zero_()

        traj_ids = torch.zeros(per_worker.shape, dtype=torch.int64)
        per_worker.set(("collector", "traj_ids"), traj_ids)

        if final_obs_template is not None:
            from tensordict import UnbatchedTensor

            for final_key in final_keys:
                src = final_obs_template.get(_leaf_path(final_key), default=None)
                if src is not None:
                    per_worker.set(
                        final_key, UnbatchedTensor(torch.zeros_like(src))
                    )

        per_worker.refine_names(..., "time")

        def _worker_copies():
            # One entry per worker: the template itself plus independent clones.
            clones = [per_worker.clone() for _ in range(self.num_workers - 1)]
            return [per_worker, *clones]

        def _name_last_dim_time(td):
            td.refine_names(*[None] * (td.ndim - 1) + ["time"])

        def _stack_workers():
            stacked = torch.stack(_worker_copies(), 0)
            _name_last_dim_time(stacked)
            return stacked

        cat_results = getattr(self, "cat_results", None)
        if cat_results is None:
            # MultiAsyncCollector yields one worker batch at a time;
            # MultiSyncCollector defaults to stacking along dim 0.
            from torchrl.collectors._multi_async import MultiAsyncCollector

            if isinstance(self, MultiAsyncCollector):
                result = per_worker
            else:
                result = _stack_workers()
        elif cat_results == "stack":
            result = _stack_workers()
        else:
            result = torch.cat(_worker_copies(), cat_results)
            if cat_results == -1:
                _name_last_dim_time(result)

        if self.split_trajs:
            result = split_trajectories(result, prefix="collector")

        if self.postproc is not None:
            postproc = self.postproc
            if hasattr(postproc, "to"):
                postproc = postproc.to(result.device)
            result = postproc(result)

        if self._exclude_private_keys:
            # Private keys are top-level string keys starting with "_".
            private_keys = [
                key
                for key in result.keys()
                if isinstance(key, str) and key.startswith("_")
            ]
            if private_keys:
                result = result.exclude(*private_keys)

        return result
    finally:
        if owns_env:
            env.close()

@classmethod
def _total_workers_from_env(cls, env_creators):
if isinstance(env_creators, (tuple, list)):
Expand Down
31 changes: 31 additions & 0 deletions torchrl/collectors/_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,6 +1996,37 @@ def _maybe_set_truncated(self, final_rollout):
)
return final_rollout

@torch.no_grad()
def fake_tensordict(self) -> TensorDictBase:
    """Build a zero-filled stand-in for one batch yielded by this collector.

    The result mirrors what ``next(iter(collector))`` would yield:

    - batch shape ``(*env.batch_size, frames_per_batch)`` with the last
      dim named ``"time"``;
    - env keys (observation / reward / done / terminated / truncated /
      ``is_init`` when an :class:`~torchrl.envs.InitTracker` is on the
      env), policy out-keys, and ``("collector", "traj_ids")`` when
      trajectory tracking is enabled;
    - ``compact_obs=True`` exclusions and ``final_obs=True``
      ``("final", k)`` entries applied;
    - ``set_truncated=True`` last-step ``truncated``/``done`` masking
      applied;
    - ``postproc`` / ``split_trajs`` / private-key exclusion applied,
      mirroring :meth:`_postproc`.

    Intended for storage initialization and ``torch.compile`` /
    cudagraph warmup without having to step the environment first.

    Returns:
        The zeroed clone of the rollout buffer after the final-obs,
        truncation and post-processing hooks have been applied to it.
    """
    template = getattr(self, "_final_rollout", None)
    if template is None:
        # Build the rollout buffer on demand, even when use_buffers=False,
        # so there is a structural template to clone from.
        self._maybe_make_final_rollout(make_rollout=True)
        template = self._final_rollout

    fake = template.clone().zero_()
    fake = self._maybe_attach_final_obs(fake)
    fake = self._maybe_set_truncated(fake)
    return self._postproc(fake)

@torch.no_grad()
def reset(self, index=None, **kwargs) -> None:
"""Resets the environments to a new initial state."""
Expand Down
Loading