From 2569c1d6de20b8540300ae96c22548c07e7a7fbf Mon Sep 17 00:00:00 2001 From: Philipp Sinitsin Date: Sat, 13 Jun 2026 14:14:03 +0100 Subject: [PATCH 1/2] [BugFix] ParallelEnv over MPS envs: default to use_buffers=False, stage pipe data on CPU MPS storages cannot be placed in shared memory nor pickled through multiprocessing pipes, so ParallelEnv crashed with "_share_filename_: only available on CPU" whenever the sub-envs lived on MPS (#3066). Following the design discussed in the issue, _get_metadata now inspects the spec-leaf devices recorded in EnvMetaData.device_map: with MPS leaves, use_buffers defaults to False with a warning and an explicit use_buffers=True raises. The no-buffers path stages the pipe data on CPU before consolidation on both ends and casts it back to the env device on reception. SerialEnv behavior is unchanged. --- test/envs/test_special.py | 158 +++++++++++++++++++++++++++++++++++ torchrl/envs/batched_envs.py | 106 +++++++++++++++++++++-- 2 files changed, 257 insertions(+), 7 deletions(-) diff --git a/test/envs/test_special.py b/test/envs/test_special.py index f1c134a0b1b..7cfbab821f2 100644 --- a/test/envs/test_special.py +++ b/test/envs/test_special.py @@ -33,6 +33,7 @@ _td_to_device_mps_safe, _to_device_mps_safe, ) +from torchrl.envs.env_creator import get_env_metadata from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.transforms.transforms import Tokenizer from torchrl.envs.utils import check_env_specs @@ -979,3 +980,160 @@ def test_parallel_env_no_buffers_mps_rollout(self): assert td["observation"].dtype == torch.float32 finally: env.close(raise_if_closed=False) + + +_MPS_USE_BUFFERS_WARNING = ( + "The environment specs have leaves on an MPS device, which cannot be placed " + "in shared memory" +) +_MPS_USE_BUFFERS_ERROR = ( + "use_buffers=True is incompatible with environments whose specs have leaves " + "on an MPS device" +) + + +class TestParallelEnvMPSBuffers: + """ParallelEnv use_buffers checks for MPS sub-envs (issue #3066). + + These tests fake the device map reported by the env metadata, so they run + on CPU-only machines too. + """ + + @staticmethod + def _patch_device_map_to_mps(monkeypatch): + def get_env_metadata_mps(*args, **kwargs): + meta_data = get_env_metadata(*args, **kwargs) + meta_data.device_map = { + key: torch.device("mps") for key in meta_data.device_map + } + return meta_data + + monkeypatch.setattr( + "torchrl.envs.batched_envs.get_env_metadata", get_env_metadata_mps + ) + + def test_parallel_env_mps_leaves_default_use_buffers_false(self, monkeypatch): + self._patch_device_map_to_mps(monkeypatch) + with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING): + env = ParallelEnv(2, ContinuousActionVecMockEnv) + assert env._use_buffers is False + + def test_parallel_env_mps_leaves_use_buffers_true_raises(self, monkeypatch): + self._patch_device_map_to_mps(monkeypatch) + with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR): + ParallelEnv(2, ContinuousActionVecMockEnv, use_buffers=True) + + def test_parallel_env_mps_leaves_configure_parallel_raises(self, monkeypatch): + self._patch_device_map_to_mps(monkeypatch) + with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING): + env = ParallelEnv(2, ContinuousActionVecMockEnv) + with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR): + env.configure_parallel(use_buffers=True) + + def test_parallel_env_mps_leaves_explicit_use_buffers_false(self, monkeypatch): + self._patch_device_map_to_mps(monkeypatch) + env = ParallelEnv(2, ContinuousActionVecMockEnv, use_buffers=False) + assert env._use_buffers is False + + def test_serial_env_mps_leaves_keeps_buffers(self, monkeypatch): + # SerialEnv runs in-process, so MPS buffers are fine there + self._patch_device_map_to_mps(monkeypatch) + env = SerialEnv(2, ContinuousActionVecMockEnv) + assert env._use_buffers is True + + +@pytest.mark.skipif(not _has_mps(), reason="MPS device not available") +class TestMPSSubEnvs: + """ParallelEnv and collectors over sub-envs living on MPS (issue #3066).""" + + class _MPSObsEnv(EnvBase): + """Minimal env with all spec leaves on MPS. + + The observation mirrors the last action so that the parent-worker + round-trip can be checked end-to-end. + """ + + def __init__(self, device="mps"): + super().__init__(device=device) + self.observation_spec = Composite( + observation=Unbounded(shape=(3,), device=device), device=device + ) + self.action_spec = Unbounded(shape=(1,), device=device) + self.reward_spec = Unbounded(shape=(1,), device=device) + + def _reset(self, tensordict): + return TensorDict( + {"observation": torch.zeros(3, device=self.device)}, + batch_size=[], + device=self.device, + ) + + def _step(self, tensordict): + return TensorDict( + { + "observation": tensordict["action"].expand(3).clone(), + "reward": torch.zeros(1, device=self.device), + "done": torch.zeros(1, dtype=torch.bool, device=self.device), + }, + batch_size=[], + device=self.device, + ) + + def _set_seed(self, seed): + return seed + + def test_parallel_env_mps_sub_envs_default_warns_and_runs(self): + with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING): + env = ParallelEnv(2, self._MPSObsEnv) + try: + assert env._use_buffers is False + td = env.reset() + assert td.device.type == "mps" + assert td["observation"].device.type == "mps" + policy = RandomPolicy(env.action_spec) + rollout = env.rollout(max_steps=3, policy=policy) + assert rollout.device.type == "mps" + # the worker must have seen the actions sampled in the parent + assert (rollout["next", "observation"] == rollout["action"]).all() + finally: + env.close(raise_if_closed=False) + + def test_parallel_env_mps_sub_envs_use_buffers_true_raises(self): + with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR): + ParallelEnv(2, self._MPSObsEnv, use_buffers=True) + + @pytest.mark.parametrize("consolidate", [True, False]) + def test_parallel_env_mps_sub_envs_no_buffers_rollout(self, consolidate): + env = ParallelEnv( + 2, self._MPSObsEnv, use_buffers=False, consolidate=consolidate + ) + try: + policy = RandomPolicy(env.action_spec) + rollout = env.rollout(max_steps=3, policy=policy) + assert rollout.device.type == "mps" + assert (rollout["next", "observation"] == rollout["action"]).all() + finally: + env.close(raise_if_closed=False) + + def test_collector_parallel_env_mps_sub_envs(self): + # the setup reported in issue #3066 + with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING): + collector = Collector( + lambda: ParallelEnv(2, self._MPSObsEnv), + frames_per_batch=4, + total_frames=8, + ) + try: + for data in collector: + assert data.numel() == 4 + finally: + collector.shutdown() + + def test_serial_env_mps_sub_envs_buffers(self): + env = SerialEnv(2, self._MPSObsEnv) + try: + assert env._use_buffers is True + rollout = env.rollout(max_steps=3) + assert rollout.device.type == "mps" + finally: + env.close(raise_if_closed=False) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b1ba4d29bf8..3cf9346487b 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -325,6 +325,12 @@ class BatchedEnvBase(EnvBase): one of the environment has dynamic specs. .. note:: Learn more about dynamic specs and environments :ref:`here `. + + .. note:: MPS tensors cannot be placed in shared memory. If the environment + specs have leaves on an MPS device, :class:`~torchrl.envs.ParallelEnv` + defaults to ``use_buffers=False`` (with a warning) and the data exchanged + between processes is staged on CPU. Passing ``use_buffers=True`` in that + case raises a ``RuntimeError``. daemon (bool, optional): whether the processes should be daemonized. This is only applicable to parallel environments such as :class:`~torchrl.envs.ParallelEnv`. Defaults to ``False``. @@ -580,6 +586,7 @@ def configure_parallel( ) if use_buffers is not None: self._use_buffers = use_buffers + self._check_mps_use_buffers() if shared_memory is not None: self._share_memory = shared_memory if memmap is not None: @@ -732,6 +739,30 @@ def _validate_worker_env(env) -> None: if transform is not None: transform._check_batched_worker_compat() + def _check_mps_use_buffers(self) -> None: + """Prevents the use of shared buffers when the sub-env data lives on MPS. + + MPS storages cannot be placed in shared memory nor pickled through + multiprocessing pipes, so :class:`~torchrl.envs.ParallelEnv` must run + with ``use_buffers=False`` and stage the inter-process data on CPU. + """ + if not self._has_mps_leaves or not isinstance(self, ParallelEnv): + return + if self._use_buffers: + raise RuntimeError( + "use_buffers=True is incompatible with environments whose specs have " + "leaves on an MPS device, because MPS tensors cannot be placed in shared " + "memory. Pass use_buffers=False or move the environments to CPU." + ) + if self._use_buffers is None: + warn( + "The environment specs have leaves on an MPS device, which cannot be placed " + "in shared memory. use_buffers will default to False and the data will be " + "passed between processes through CPU memory. To silence this warning, pass " + "use_buffers=False to the ParallelEnv constructor." + ) + self._use_buffers = False + def _get_metadata( self, create_env_fn: list[Callable], create_env_kwargs: list[dict] ): @@ -745,6 +776,10 @@ def _get_metadata( self.meta_data = meta_data.expand( *(self.num_workers, *meta_data.batch_size) ) + self._has_mps_leaves = any( + device.type == "mps" for device in meta_data.device_map.values() + ) + self._check_mps_use_buffers() if self._use_buffers is not False: _use_buffers = not self.meta_data.has_dynamic_specs if self._use_buffers and not _use_buffers: @@ -776,15 +811,22 @@ def _get_metadata( "be True to accommodate non-stackable tensors." ) self.share_individual_td = share_individual_td - _use_buffers = all( - not metadata.has_dynamic_specs for metadata in self.meta_data + self._has_mps_leaves = any( + device.type == "mps" + for metadata in self.meta_data + for device in metadata.device_map.values() ) - if self._use_buffers and not _use_buffers: - warn( - "A value of use_buffers=True was passed but this is incompatible " - "with the list of environments provided. Turning use_buffers to False." + self._check_mps_use_buffers() + if self._use_buffers is not False: + _use_buffers = all( + not metadata.has_dynamic_specs for metadata in self.meta_data ) - self._use_buffers = _use_buffers + if self._use_buffers and not _use_buffers: + warn( + "A value of use_buffers=True was passed but this is incompatible " + "with the list of environments provided. Turning use_buffers to False." + ) + self._use_buffers = _use_buffers self._set_properties() @@ -1878,6 +1920,10 @@ def _step_and_maybe_reset_no_buffers( else: workers_range = range(self.num_workers) + if self._has_mps_leaves: + # MPS storages cannot be pickled through mp pipes: stage the data + # on CPU before sending it (the workers cast it back to their device) + tensordict = tensordict.to("cpu") if self.consolidate: try: td = tensordict.consolidate( @@ -2219,6 +2265,10 @@ def _step_no_buffers( else: workers_range = range(self.num_workers) + if self._has_mps_leaves: + # MPS storages cannot be pickled through mp pipes: stage the data + # on CPU before sending it (the workers cast it back to their device) + tensordict = tensordict.to("cpu") if self.consolidate: try: data = tensordict.consolidate( @@ -2242,6 +2292,10 @@ def _step_no_buffers( if data.device != env_device: if env_device is None: local_data.clear_device_() + elif env_device.type == "mps": + # MPS data cannot be sent through the pipe: the worker casts + # the data back to the env device upon reception. + pass else: local_data = local_data.to(env_device) self.parent_channels[i].send(("step", local_data)) @@ -2456,6 +2510,10 @@ def _reset_no_buffers( needs_resetting, ) -> tuple[TensorDictBase, TensorDictBase]: if is_tensor_collection(tensordict): + if self._has_mps_leaves: + # MPS storages cannot be pickled through mp pipes: stage the data + # on CPU before sending it (the workers cast it back to their device) + tensordict = tensordict.to("cpu") if self.consolidate: try: tensordict = tensordict.consolidate( @@ -3108,6 +3166,20 @@ def _run_worker_pipe_direct( event = torch.cuda.Event() else: event = None + # MPS storages cannot be pickled through pipes, so when the env specs have + # leaves on MPS the data sent to the parent is staged on CPU and the data + # received from the parent is cast back to the env device. + for spec in env.output_spec.values(True, True): + if spec.device is not None and spec.device.type == "mps": + has_mps = True + break + else: + for spec in env.input_spec.values(True, True): + if spec.device is not None and spec.device.type == "mps": + has_mps = True + break + else: + has_mps = False i = -1 import torchrl @@ -3158,6 +3230,9 @@ def _run_worker_pipe_direct( data._fast_apply( lambda x: x.clone() if x.device.type == "cuda" else x, out=data ) + if has_mps and env.device is not None and data.device != env.device: + # the parent sent CPU-staged data: cast it back to the env device + data = _td_to_device_mps_safe(data, env.device) cur_td = env.reset( tensordict=data, **reset_kwargs, @@ -3165,6 +3240,10 @@ def _run_worker_pipe_direct( if event is not None: event.record() event.synchronize() + if has_mps: + # MPS storages cannot be pickled through mp pipes: stage the data + # on CPU before sending it (the parent casts it back to its device) + cur_td = cur_td.to("cpu") if consolidate: try: cur_td = cur_td.consolidate( @@ -3187,10 +3266,17 @@ def _run_worker_pipe_direct( if not initialized: raise RuntimeError("called 'init' before step") i += 1 + if has_mps and env.device is not None and data.device != env.device: + # the parent sent CPU-staged data: cast it back to the env device + data = _td_to_device_mps_safe(data, env.device) next_td = env._step(data) if event is not None: event.record() event.synchronize() + if has_mps: + # MPS storages cannot be pickled through mp pipes: stage the data + # on CPU before sending it (the parent casts it back to its device) + next_td = next_td.to("cpu") if consolidate: try: next_td = next_td.consolidate( @@ -3218,11 +3304,17 @@ def _run_worker_pipe_direct( data._fast_apply( lambda x: x.clone() if x.device.type == "cuda" else x, out=data ) + if has_mps and env.device is not None and data.device != env.device: + # the parent sent CPU-staged data: cast it back to the env device + data = _td_to_device_mps_safe(data, env.device) td, root_next_td = env.step_and_maybe_reset(data) if event is not None: event.record() event.synchronize() + if has_mps: + td = td.to("cpu") + root_next_td = root_next_td.to("cpu") child_pipe.send((td, root_next_td)) mp_event.set() del td, root_next_td From 3c0a18e6a0796b474ea47bc23f4f80ef4c7736f1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 17 Jun 2026 17:02:58 +0100 Subject: [PATCH 2/2] [BugFix] Fix ParallelEnv no-buffer consolidation --- test/transforms/test_action_transforms.py | 1 + torchrl/envs/batched_envs.py | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/test/transforms/test_action_transforms.py b/test/transforms/test_action_transforms.py index 06fb92153e5..12d58c3cbee 100644 --- a/test/transforms/test_action_transforms.py +++ b/test/transforms/test_action_transforms.py @@ -2515,6 +2515,7 @@ def test_trailing_dim_enforced(self): with pytest.raises(ValueError, match="immediately follow"): t(TensorDict({"action": torch.randn(2, 4, 3)}, batch_size=[2, 4])) + @pytest.mark.skipif(IS_WIN, reason="windows tests do not support compile") def test_compile_build_chunk(self): t = ActionChunkTransform(chunk_size=3) action = torch.randn(2, 5, 2) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 3cf9346487b..93ecb31447e 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -2278,18 +2278,30 @@ def _step_no_buffers( inplace=False, num_threads=1, ) + except RuntimeError as err: + if "self.stride(-1) must be 1" not in str(err): + raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err + data = [ + local_data.contiguous().consolidate( + share_memory=False, + inplace=False, + num_threads=1, + ) + for local_data in tensordict.unbind(0) + ] except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err else: data = tensordict - for i, local_data in zip(workers_range, data.unbind(0)): + data_iter = data if isinstance(data, list) else data.unbind(0) + for i, local_data in zip(workers_range, data_iter): env_device = ( self.meta_data[i].device if isinstance(self.meta_data, list) else self.meta_data.device ) - if data.device != env_device: + if local_data.device != env_device: if env_device is None: local_data.clear_device_() elif env_device.type == "mps":