Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions test/envs/test_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_td_to_device_mps_safe,
_to_device_mps_safe,
)
from torchrl.envs.env_creator import get_env_metadata
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.transforms.transforms import Tokenizer
from torchrl.envs.utils import check_env_specs
Expand Down Expand Up @@ -979,3 +980,160 @@ def test_parallel_env_no_buffers_mps_rollout(self):
assert td["observation"].dtype == torch.float32
finally:
env.close(raise_if_closed=False)


_MPS_USE_BUFFERS_WARNING = (
"The environment specs have leaves on an MPS device, which cannot be placed "
"in shared memory"
)
_MPS_USE_BUFFERS_ERROR = (
"use_buffers=True is incompatible with environments whose specs have leaves "
"on an MPS device"
)


class TestParallelEnvMPSBuffers:
    """ParallelEnv use_buffers checks for MPS sub-envs (issue #3066).

    These tests fake the device map reported by the env metadata, so they run
    on CPU-only machines too. Every batched env that gets built is closed in a
    ``finally`` block, matching the other tests in this file.
    """

    @staticmethod
    def _patch_device_map_to_mps(monkeypatch):
        """Patch ``batched_envs.get_env_metadata`` to report only MPS devices.

        The real metadata is still computed; only its ``device_map`` values are
        replaced by ``torch.device("mps")`` (creating a ``torch.device`` does
        not require the backend to be available).
        """

        def get_env_metadata_mps(*args, **kwargs):
            meta_data = get_env_metadata(*args, **kwargs)
            meta_data.device_map = {
                key: torch.device("mps") for key in meta_data.device_map
            }
            return meta_data

        monkeypatch.setattr(
            "torchrl.envs.batched_envs.get_env_metadata", get_env_metadata_mps
        )

    def test_parallel_env_mps_leaves_default_use_buffers_false(self, monkeypatch):
        """Leaving ``use_buffers`` unset warns and falls back to ``False``."""
        self._patch_device_map_to_mps(monkeypatch)
        with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
            env = ParallelEnv(2, ContinuousActionVecMockEnv)
        try:
            assert env._use_buffers is False
        finally:
            env.close(raise_if_closed=False)

    def test_parallel_env_mps_leaves_use_buffers_true_raises(self, monkeypatch):
        """Requesting ``use_buffers=True`` at construction time is an error."""
        self._patch_device_map_to_mps(monkeypatch)
        with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR):
            ParallelEnv(2, ContinuousActionVecMockEnv, use_buffers=True)

    def test_parallel_env_mps_leaves_configure_parallel_raises(self, monkeypatch):
        """Requesting ``use_buffers=True`` via ``configure_parallel`` is an error."""
        self._patch_device_map_to_mps(monkeypatch)
        with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
            env = ParallelEnv(2, ContinuousActionVecMockEnv)
        try:
            with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR):
                env.configure_parallel(use_buffers=True)
        finally:
            env.close(raise_if_closed=False)

    def test_parallel_env_mps_leaves_explicit_use_buffers_false(self, monkeypatch):
        """An explicit ``use_buffers=False`` is accepted and silences the warning."""
        import warnings

        self._patch_device_map_to_mps(monkeypatch)
        with warnings.catch_warnings():
            # Escalate only the MPS fallback warning: the warning text itself
            # promises that an explicit use_buffers=False silences it.
            warnings.filterwarnings("error", message=_MPS_USE_BUFFERS_WARNING)
            env = ParallelEnv(2, ContinuousActionVecMockEnv, use_buffers=False)
        try:
            assert env._use_buffers is False
        finally:
            env.close(raise_if_closed=False)

    def test_serial_env_mps_leaves_keeps_buffers(self, monkeypatch):
        """The MPS check does not apply to SerialEnv, which keeps its buffers."""
        # SerialEnv runs in-process, so MPS buffers are fine there
        self._patch_device_map_to_mps(monkeypatch)
        env = SerialEnv(2, ContinuousActionVecMockEnv)
        try:
            assert env._use_buffers is True
        finally:
            env.close(raise_if_closed=False)


@pytest.mark.skipif(not _has_mps(), reason="MPS device not available")
class TestMPSSubEnvs:
    """Batched envs and collectors whose sub-envs sit on MPS (issue #3066)."""

    class _MPSObsEnv(EnvBase):
        """Tiny env whose spec leaves are all created on ``device`` (MPS by default).

        Each step copies the incoming action into the observation, which lets
        the tests verify that actions sampled in the parent reached the env.
        """

        def __init__(self, device="mps"):
            super().__init__(device=device)
            obs_leaf = Unbounded(shape=(3,), device=device)
            self.observation_spec = Composite(observation=obs_leaf, device=device)
            self.action_spec = Unbounded(shape=(1,), device=device)
            self.reward_spec = Unbounded(shape=(1,), device=device)

        def _reset(self, tensordict):
            # Start every episode from an all-zero observation.
            initial = {"observation": torch.zeros(3, device=self.device)}
            return TensorDict(initial, batch_size=[], device=self.device)

        def _step(self, tensordict):
            dev = self.device
            payload = {
                # Echo the (1,)-shaped action across the 3 observation entries.
                "observation": tensordict["action"].expand(3).clone(),
                "reward": torch.zeros(1, device=dev),
                "done": torch.zeros(1, dtype=torch.bool, device=dev),
            }
            return TensorDict(payload, batch_size=[], device=dev)

        def _set_seed(self, seed):
            # The env has no randomness of its own: hand the seed back as-is.
            return seed

    def test_parallel_env_mps_sub_envs_default_warns_and_runs(self):
        """With ``use_buffers`` unset, ParallelEnv warns, disables buffers and runs."""
        with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
            env = ParallelEnv(2, self._MPSObsEnv)
        try:
            assert env._use_buffers is False
            first = env.reset()
            assert first.device.type == "mps"
            assert first["observation"].device.type == "mps"
            random_policy = RandomPolicy(env.action_spec)
            traj = env.rollout(max_steps=3, policy=random_policy)
            assert traj.device.type == "mps"
            # the worker must have seen the actions sampled in the parent
            echoed = traj["next", "observation"] == traj["action"]
            assert echoed.all()
        finally:
            env.close(raise_if_closed=False)

    def test_parallel_env_mps_sub_envs_use_buffers_true_raises(self):
        """Asking for shared buffers with MPS sub-envs raises at construction."""
        with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR):
            ParallelEnv(2, self._MPSObsEnv, use_buffers=True)

    @pytest.mark.parametrize("consolidate", [True, False])
    def test_parallel_env_mps_sub_envs_no_buffers_rollout(self, consolidate):
        """Rollouts work without buffers, with and without consolidation."""
        env = ParallelEnv(
            2, self._MPSObsEnv, use_buffers=False, consolidate=consolidate
        )
        try:
            random_policy = RandomPolicy(env.action_spec)
            traj = env.rollout(max_steps=3, policy=random_policy)
            assert traj.device.type == "mps"
            echoed = traj["next", "observation"] == traj["action"]
            assert echoed.all()
        finally:
            env.close(raise_if_closed=False)

    def test_collector_parallel_env_mps_sub_envs(self):
        """A Collector built over a ParallelEnv of MPS sub-envs yields batches."""
        # the setup reported in issue #3066
        with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
            collector = Collector(
                lambda: ParallelEnv(2, self._MPSObsEnv),
                frames_per_batch=4,
                total_frames=8,
            )
        try:
            for batch in collector:
                assert batch.numel() == 4
        finally:
            collector.shutdown()

    def test_serial_env_mps_sub_envs_buffers(self):
        """SerialEnv keeps ``use_buffers=True`` and rolls out on MPS."""
        env = SerialEnv(2, self._MPSObsEnv)
        try:
            assert env._use_buffers is True
            traj = env.rollout(max_steps=3)
            assert traj.device.type == "mps"
        finally:
            env.close(raise_if_closed=False)
106 changes: 99 additions & 7 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,12 @@ class BatchedEnvBase(EnvBase):
one of the environment has dynamic specs.

.. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.

.. note:: MPS tensors cannot be placed in shared memory. If the environment
specs have leaves on an MPS device, :class:`~torchrl.envs.ParallelEnv`
defaults to ``use_buffers=False`` (with a warning) and the data exchanged
between processes is staged on CPU. Passing ``use_buffers=True`` in that
case raises a ``RuntimeError``.
daemon (bool, optional): whether the processes should be daemonized.
This is only applicable to parallel environments such as :class:`~torchrl.envs.ParallelEnv`.
Defaults to ``False``.
Expand Down Expand Up @@ -580,6 +586,7 @@ def configure_parallel(
)
if use_buffers is not None:
self._use_buffers = use_buffers
self._check_mps_use_buffers()
if shared_memory is not None:
self._share_memory = shared_memory
if memmap is not None:
Expand Down Expand Up @@ -732,6 +739,30 @@ def _validate_worker_env(env) -> None:
if transform is not None:
transform._check_batched_worker_compat()

def _check_mps_use_buffers(self) -> None:
    """Reconcile ``self._use_buffers`` with sub-envs whose specs live on MPS.

    MPS storages can neither be placed in shared memory nor pickled through
    multiprocessing pipes, so a :class:`~torchrl.envs.ParallelEnv` with MPS
    leaves has to run with ``use_buffers=False`` and stage the data exchanged
    between processes on CPU.

    The check only applies when ``self._has_mps_leaves`` is truthy and
    ``self`` is a :class:`~torchrl.envs.ParallelEnv`; otherwise nothing
    happens. When it applies:

    - a truthy ``self._use_buffers`` is rejected;
    - ``self._use_buffers is None`` emits a warning and is replaced by
      ``False``;
    - an explicit ``False`` is left untouched, silently.

    Raises:
        RuntimeError: if ``self._use_buffers`` is truthy while the check
            applies.
    """
    # Short-circuits exactly like the guard it replaces: the isinstance test
    # is only evaluated when MPS leaves were detected.
    applies = self._has_mps_leaves and isinstance(self, ParallelEnv)
    if applies:
        requested = self._use_buffers
        if requested:
            error_msg = (
                "use_buffers=True is incompatible with environments whose specs have "
                "leaves on an MPS device, because MPS tensors cannot be placed in shared "
                "memory. Pass use_buffers=False or move the environments to CPU."
            )
            raise RuntimeError(error_msg)
        elif requested is None:
            # The user expressed no preference: fall back to the only
            # supported mode and say so.
            fallback_msg = (
                "The environment specs have leaves on an MPS device, which cannot be placed "
                "in shared memory. use_buffers will default to False and the data will be "
                "passed between processes through CPU memory. To silence this warning, pass "
                "use_buffers=False to the ParallelEnv constructor."
            )
            warn(fallback_msg)
            self._use_buffers = False

def _get_metadata(
self, create_env_fn: list[Callable], create_env_kwargs: list[dict]
):
Expand All @@ -745,6 +776,10 @@ def _get_metadata(
self.meta_data = meta_data.expand(
*(self.num_workers, *meta_data.batch_size)
)
self._has_mps_leaves = any(
device.type == "mps" for device in meta_data.device_map.values()
)
self._check_mps_use_buffers()
if self._use_buffers is not False:
_use_buffers = not self.meta_data.has_dynamic_specs
if self._use_buffers and not _use_buffers:
Expand Down Expand Up @@ -776,15 +811,22 @@ def _get_metadata(
"be True to accommodate non-stackable tensors."
)
self.share_individual_td = share_individual_td
_use_buffers = all(
not metadata.has_dynamic_specs for metadata in self.meta_data
self._has_mps_leaves = any(
device.type == "mps"
for metadata in self.meta_data
for device in metadata.device_map.values()
)
if self._use_buffers and not _use_buffers:
warn(
"A value of use_buffers=True was passed but this is incompatible "
"with the list of environments provided. Turning use_buffers to False."
self._check_mps_use_buffers()
if self._use_buffers is not False:
_use_buffers = all(
not metadata.has_dynamic_specs for metadata in self.meta_data
)
self._use_buffers = _use_buffers
if self._use_buffers and not _use_buffers:
warn(
"A value of use_buffers=True was passed but this is incompatible "
"with the list of environments provided. Turning use_buffers to False."
)
self._use_buffers = _use_buffers

self._set_properties()

Expand Down Expand Up @@ -1878,6 +1920,10 @@ def _step_and_maybe_reset_no_buffers(
else:
workers_range = range(self.num_workers)

if self._has_mps_leaves:
# MPS storages cannot be pickled through mp pipes: stage the data
# on CPU before sending it (the workers cast it back to their device)
tensordict = tensordict.to("cpu")
if self.consolidate:
try:
td = tensordict.consolidate(
Expand Down Expand Up @@ -2219,6 +2265,10 @@ def _step_no_buffers(
else:
workers_range = range(self.num_workers)

if self._has_mps_leaves:
# MPS storages cannot be pickled through mp pipes: stage the data
# on CPU before sending it (the workers cast it back to their device)
tensordict = tensordict.to("cpu")
if self.consolidate:
try:
data = tensordict.consolidate(
Expand All @@ -2242,6 +2292,10 @@ def _step_no_buffers(
if data.device != env_device:
if env_device is None:
local_data.clear_device_()
elif env_device.type == "mps":
# MPS data cannot be sent through the pipe: the worker casts
# the data back to the env device upon reception.
pass
else:
local_data = local_data.to(env_device)
self.parent_channels[i].send(("step", local_data))
Expand Down Expand Up @@ -2456,6 +2510,10 @@ def _reset_no_buffers(
needs_resetting,
) -> tuple[TensorDictBase, TensorDictBase]:
if is_tensor_collection(tensordict):
if self._has_mps_leaves:
# MPS storages cannot be pickled through mp pipes: stage the data
# on CPU before sending it (the workers cast it back to their device)
tensordict = tensordict.to("cpu")
if self.consolidate:
try:
tensordict = tensordict.consolidate(
Expand Down Expand Up @@ -3108,6 +3166,20 @@ def _run_worker_pipe_direct(
event = torch.cuda.Event()
else:
event = None
# MPS storages cannot be pickled through pipes, so when the env specs have
# leaves on MPS the data sent to the parent is staged on CPU and the data
# received from the parent is cast back to the env device.
for spec in env.output_spec.values(True, True):
if spec.device is not None and spec.device.type == "mps":
has_mps = True
break
else:
for spec in env.input_spec.values(True, True):
if spec.device is not None and spec.device.type == "mps":
has_mps = True
break
else:
has_mps = False

i = -1
import torchrl
Expand Down Expand Up @@ -3158,13 +3230,20 @@ def _run_worker_pipe_direct(
data._fast_apply(
lambda x: x.clone() if x.device.type == "cuda" else x, out=data
)
if has_mps and env.device is not None and data.device != env.device:
# the parent sent CPU-staged data: cast it back to the env device
data = _td_to_device_mps_safe(data, env.device)
cur_td = env.reset(
tensordict=data,
**reset_kwargs,
)
if event is not None:
event.record()
event.synchronize()
if has_mps:
# MPS storages cannot be pickled through mp pipes: stage the data
# on CPU before sending it (the parent casts it back to its device)
cur_td = cur_td.to("cpu")
if consolidate:
try:
cur_td = cur_td.consolidate(
Expand All @@ -3187,10 +3266,17 @@ def _run_worker_pipe_direct(
if not initialized:
raise RuntimeError("called 'init' before step")
i += 1
if has_mps and env.device is not None and data.device != env.device:
# the parent sent CPU-staged data: cast it back to the env device
data = _td_to_device_mps_safe(data, env.device)
next_td = env._step(data)
if event is not None:
event.record()
event.synchronize()
if has_mps:
# MPS storages cannot be pickled through mp pipes: stage the data
# on CPU before sending it (the parent casts it back to its device)
next_td = next_td.to("cpu")
if consolidate:
try:
next_td = next_td.consolidate(
Expand Down Expand Up @@ -3218,11 +3304,17 @@ def _run_worker_pipe_direct(
data._fast_apply(
lambda x: x.clone() if x.device.type == "cuda" else x, out=data
)
if has_mps and env.device is not None and data.device != env.device:
# the parent sent CPU-staged data: cast it back to the env device
data = _td_to_device_mps_safe(data, env.device)
td, root_next_td = env.step_and_maybe_reset(data)

if event is not None:
event.record()
event.synchronize()
if has_mps:
td = td.to("cpu")
root_next_td = root_next_td.to("cpu")
child_pipe.send((td, root_next_td))
mp_event.set()
del td, root_next_td
Expand Down
Loading