diff --git a/test/libs/test_misc.py b/test/libs/test_misc.py index ad45fd7a671..f945642b556 100644 --- a/test/libs/test_misc.py +++ b/test/libs/test_misc.py @@ -85,8 +85,7 @@ def test_robohive(self, envname, from_pixels, from_depths): # List of OpenSpiel games to test # TODO: Some of the games in `OpenSpielWrapper.available_envs` raise errors for -# a few different reasons, mostly because we do not support chance nodes yet. So -# we cannot run tests on all of them yet. +# a few different reasons, so we cannot run tests on all of them yet. _openspiel_games = [ # ---------------- # Sequential games @@ -161,9 +160,7 @@ def test_all_envs(self, game_string, return_state, categorical_actions): @pytest.mark.parametrize("return_state", [False, True]) @pytest.mark.parametrize("categorical_actions", [False, True]) def test_wrapper(self, game_string, return_state, categorical_actions): - import pyspiel - - base_env = pyspiel.load_game(game_string).new_initial_state() + base_env = OpenSpielWrapper.lib.load_game(game_string).new_initial_state() env_torchrl = OpenSpielWrapper( base_env, categorical_actions=categorical_actions, return_state=return_state ) @@ -201,12 +198,62 @@ def test_reset_state(self, game_string, return_state, categorical_actions): td = env.reset() assert (td == td_init).all() - def test_chance_not_implemented(self): + def test_chance_nodes_supported(self): + # Verify that games with chance nodes now load successfully + env = OpenSpielEnv("bridge") + assert not env._env.is_chance_node() + td = env.reset() + assert not env._env.is_chance_node() + assert td.shape == torch.Size([]) + + def test_chance_nodes_resolved_before_initial_step(self): + env = OpenSpielEnv("backgammon") + assert not env._env.is_chance_node() + + td = env.rand_step() + + assert not env._env.is_chance_node() + assert td["next", "done"].shape == torch.Size([1]) + + def test_chance_nodes_resolved_after_step(self): + env = OpenSpielEnv("backgammon") + + env.reset() + td = env.step(env.full_action_spec.rand()) + + assert not env._env.is_chance_node() + assert td["next"].shape == torch.Size([]) + + def test_custom_chance_sampler(self): + samples = [] + + def chance_sampler(actions, probabilities): + samples.append((actions, probabilities)) + return actions[0] + + base_env = OpenSpielWrapper.lib.load_game("backgammon").new_initial_state() + env = OpenSpielWrapper(base_env, chance_sampler=chance_sampler) + + assert samples + assert not env._env.is_chance_node() + + def test_seeded_chance_sampler(self): + env0 = OpenSpielEnv("backgammon", return_state=True) + env0.set_seed(0) + td0 = env0.reset() + + env1 = OpenSpielEnv("backgammon", return_state=True) + env1.set_seed(0) + td1 = env1.reset() + + assert td0["state"] == td1["state"] + + def test_chance_batch_size_not_supported(self): with pytest.raises( - NotImplementedError, - match="not yet supported", + ValueError, + match="OpenSpielWrapper only supports single-environment mode", ): - OpenSpielEnv("bridge") + OpenSpielEnv("backgammon", batch_size=torch.Size([4])) # NOTE: Each of the registered envs are around 180 MB, so only test a few. diff --git a/torchrl/envs/libs/openspiel.py b/torchrl/envs/libs/openspiel.py index 38684310d1e..8ed1fb8e810 100644 --- a/torchrl/envs/libs/openspiel.py +++ b/torchrl/envs/libs/openspiel.py @@ -7,6 +7,7 @@ import importlib.util +import numpy as np import torch from tensordict import TensorDict, TensorDictBase @@ -42,14 +43,20 @@ class OpenSpielWrapper(_EnvWrapper): Documentation: https://openspiel.readthedocs.io/en/latest/index.html + Supports games with chance nodes. Chance outcomes are automatically sampled + and resolved between player decision nodes, so agents only observe states + where they must act. + Args: env (pyspiel.State): the game to wrap. Keyword Args: device (torch.device, optional): if provided, the device on which the data is to be cast. Defaults to ``None``. - batch_size (torch.Size, optional): the batch size of the environment. - Defaults to ``torch.Size([])``. + batch_size (torch.Size, optional): must be ``torch.Size([])`` (single-env only). + This wrapper does not support batching multiple game instances. For parallel + environments, wrap multiple OpenSpielWrapper instances with ParallelEnv or + SerialEnv. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. @@ -67,6 +74,9 @@ class OpenSpielWrapper(_EnvWrapper): to :meth:`reset` to reset to that state, rather than resetting to the initial state. Defaults to ``False``. + chance_sampler (callable, optional): a callable taking (actions, probabilities) + and returning a sampled action index. If ``None``, uses numpy's + ``random.choice``. Defaults to ``None``. Attributes: available_envs: environments available to build @@ -161,14 +171,28 @@ def __init__( | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, categorical_actions: bool = False, return_state: bool = False, + chance_sampler=None, **kwargs, ): if env is not None: kwargs["env"] = env + batch_size = kwargs.get("batch_size", torch.Size([])) + if not isinstance(batch_size, torch.Size): + batch_size = torch.Size(batch_size) if batch_size else torch.Size([]) + if batch_size != torch.Size([]): + raise ValueError( + f"OpenSpielWrapper only supports single-environment mode (batch_size=torch.Size([])). " + f"Got batch_size={batch_size}. " + f"For multiple parallel environments, use ParallelEnv or SerialEnv to wrap " + f"multiple OpenSpielWrapper instances." + ) + self.group_map = group_map self.categorical_actions = categorical_actions self.return_state = return_state + self._rng = np.random.default_rng() + self._chance_sampler = chance_sampler or self._default_chance_sampler self._cached_game = None super().__init__(**kwargs) @@ -183,14 +207,27 @@ def _check_kwargs(self, kwargs: dict): if not isinstance(env, pyspiel.State): raise TypeError("env is not of type 'pyspiel.State'.") + def _default_chance_sampler(self, actions: list, probabilities: list[float]) -> int: + return int(self._rng.choice(actions, p=probabilities)) + + def _resolve_chance_nodes(self): + """Resolve all consecutive chance nodes until reaching a decision node or terminal state. + + This method automatically samples and applies chance outcomes so that agents + only interact with the environment at decision nodes. + """ + while self._env.is_chance_node(): + outcomes = self._env.chance_outcomes() + if not outcomes: + break + actions, probabilities = zip(*outcomes) + sampled_action = self._chance_sampler(list(actions), list(probabilities)) + self._env.apply_action(sampled_action) + def _build_env(self, env, requires_grad: bool = False, **kwargs): game = env.get_game() game_type = game.get_type() - if game.max_chance_outcomes() != 0: - raise NotImplementedError( - f"The game '{game_type.short_name}' has chance nodes, which are not yet supported." - ) if game_type.dynamics == self.lib.GameType.Dynamics.MEAN_FIELD: # NOTE: It is unclear from the OpenSpiel documentation what exactly # "mean field" means exactly, and there is no documentation on the @@ -203,6 +240,7 @@ def _build_env(self, env, requires_grad: bool = False, **kwargs): return env def _init_env(self): + self._resolve_chance_nodes() self._update_action_mask() def _get_game(self): @@ -335,8 +373,7 @@ def _make_specs(self, env: pyspiel.State) -> None: # noqa: F821 self.reward_spec = Composite(reward_spec) def _set_seed(self, seed: int | None) -> None: - if seed is not None: - raise NotImplementedError("This environment has no seed.") + self._rng = np.random.default_rng(seed) def current_player(self): return self._env.current_player() @@ -480,6 +517,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: else: self._step_sequential(tensordict) + self._resolve_chance_nodes() self._update_action_mask() return self._make_td_out() @@ -497,6 +535,7 @@ def _reset( new_env = game.new_initial_state() self._env = new_env + self._resolve_chance_nodes() self._update_action_mask() return self._make_td_out(exclude_reward=True) @@ -508,6 +547,10 @@ class OpenSpielEnv(OpenSpielWrapper): Documentation: https://openspiel.readthedocs.io/en/latest/index.html + Supports games with chance nodes. Chance outcomes are automatically sampled + and resolved between player decision nodes, so agents only observe states + where they must act. + Args: game_string (str): the name of the game to wrap. Must be part of :attr:`~.available_envs`. @@ -515,8 +558,10 @@ class OpenSpielEnv(OpenSpielWrapper): Keyword Args: device (torch.device, optional): if provided, the device on which the data is to be cast. Defaults to ``None``. - batch_size (torch.Size, optional): the batch size of the environment. - Defaults to ``torch.Size([])``. + batch_size (torch.Size, optional): must be ``torch.Size([])`` (single-env only). + This wrapper does not support batching multiple game instances. For parallel + environments, wrap multiple OpenSpielEnv instances with ParallelEnv or + SerialEnv. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. @@ -534,6 +579,9 @@ class OpenSpielEnv(OpenSpielWrapper): to :meth:`reset` to reset to that state, rather than resetting to the initial state. Defaults to ``False``. + chance_sampler (callable, optional): a callable taking (actions, probabilities) + and returning a sampled action index. If ``None``, uses numpy's + ``random.choice``. Defaults to ``None``. Attributes: available_envs: environments available to build