diff --git a/test/envs/test_auto_reset.py b/test/envs/test_auto_reset.py index edc56efbb81..50fb56d8a07 100644 --- a/test/envs/test_auto_reset.py +++ b/test/envs/test_auto_reset.py @@ -99,7 +99,7 @@ def test_native_auto_reset_step_and_maybe_reset(self): ) assert tensordict["next", "done"].all() - assert (tensordict["next", "observation"] == 0).all() + assert (tensordict["next", "observation"] == tensordict["observation"]).all() assert not tensordict_["done"].any() assert tensordict_["is_init"].all() diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 117953ac1de..7a0cc8362d1 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3693,8 +3693,9 @@ def _native_autoreset_set_invalid_next_observation( if obs is None or not isinstance(obs, torch.Tensor): continue reset_observations[obs_key] = obs.clone() + current_obs = tensordict.get(obs_key, None) if obs.is_floating_point(): - obs[done] = torch.nan + obs[done] = current_obs[done] else: obs[done] = 0 return reset_observations diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index b4ab20f40d4..bcf29404211 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -955,8 +955,8 @@ class GymWrapper(GymLikeEnv, metaclass=_GymAsyncMeta): observation returned by the environment as the next root observation instead of calling reset from :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset`. The terminal - ``"next"`` observation is still invalid and filled with ``NaN`` for - floating point observations. Defaults to ``False``. + ``"next"`` observation is replaced with the current (root) + observation. Defaults to ``False``. Attributes: available_envs (List[str]): a list of environments to build. 
diff --git a/torchrl/envs/transforms/_misc.py b/torchrl/envs/transforms/_misc.py index 296a8b35ad1..b0588dd9a50 100644 --- a/torchrl/envs/transforms/_misc.py +++ b/torchrl/envs/transforms/_misc.py @@ -408,8 +408,8 @@ class VecGymEnvTransform(Transform): native_autoreset (bool, optional): if ``True``, leaves the native auto-reset observation available to the environment wrapper so it can be cloned into the next root observation, while the terminal + floating point ``"next"`` observation is replaced with the current + (root) observation; other dtypes are zeroed. Defaults to ``False``. .. note:: In general, this class should not be handled directly. It is created whenever a vectorized environment is placed within a :class:`GymWrapper`.