From 17e0d2439f04bd10217651a7ab7f4c1b7e43467e Mon Sep 17 00:00:00 2001 From: Abhishek Sriraman Date: Wed, 17 Jun 2026 16:02:13 +0100 Subject: [PATCH 1/2] [Feature] Add chance node support to OpenSpiel wrapper Implement automatic chance node resolution in OpenSpielWrapper and OpenSpielEnv, allowing environments with stochastic outcomes (like backgammon, dice-based games) to be used without modification. Add guard in OpenSpielWrapper for batch size, as only single-batch environments are supported. --- test/envs/libs/test_openspiel_chance.py | 276 ++++++++++++++++++++++++ test/libs/test_misc.py | 12 +- torchrl/envs/libs/openspiel.py | 67 +++++- 3 files changed, 339 insertions(+), 16 deletions(-) create mode 100644 test/envs/libs/test_openspiel_chance.py diff --git a/test/envs/libs/test_openspiel_chance.py b/test/envs/libs/test_openspiel_chance.py new file mode 100644 index 00000000000..a28474a9dc1 --- /dev/null +++ b/test/envs/libs/test_openspiel_chance.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import pytest + +torch = pytest.importorskip("torch") +pyspiel = pytest.importorskip("pyspiel") + +import numpy as np +from torchrl.envs import OpenSpielEnv, OpenSpielWrapper +from torchrl.envs.utils import check_env_specs + + +class TestChanceNodeSampler: + """Test the default chance outcome sampler.""" + + def test_default_sampler_respects_probabilities(self): + """Verify default sampler samples according to probabilities.""" + env = OpenSpielEnv("backgammon") + + actions = [0, 1, 2] + probs = [0.5, 0.3, 0.2] + + # Sample many times and verify distribution roughly matches + samples = [env._default_chance_sampler(actions, probs) for _ in range(10000)] + sample_counts = {a: samples.count(a) for a in actions} + + # Check that empirical probabilities are roughly correct (within 5%) + for action, prob in zip(actions, probs): + empirical_prob = sample_counts[action] / len(samples) + assert ( + abs(empirical_prob - prob) < 0.05 + ), f"Action {action}: expected ~{prob}, got {empirical_prob}" + + def test_custom_sampler_injection(self): + """Verify custom sampler can be injected.""" + # Deterministic sampler that always picks the first action + def custom_sampler(actions, probs): + return actions[0] + + game = pyspiel.load_game("backgammon").new_initial_state() + env = OpenSpielWrapper(game, chance_sampler=custom_sampler) + + assert env._chance_sampler is custom_sampler + + def test_sampler_with_single_outcome(self): + """Verify sampler handles single outcome correctly.""" + env = OpenSpielEnv("backgammon") + + actions = [5] + probs = [1.0] + + sampled = env._default_chance_sampler(actions, probs) + assert sampled == 5 + + +class TestChanceNodeResolution: + """Test chance node resolution during reset and step.""" + + def test_backgammon_reset_resolves_chance(self): + """Verify reset resolves initial chance nodes in backgammon.""" + env = OpenSpielEnv("backgammon") + + td = env.reset() + + # After reset, we should be at a decision node (not chance node) + assert ( + not env._env.is_chance_node() + ), "After reset, environment should not be at a chance node" + + # Should have valid observation and current player + assert "agents" in td + assert td.shape == torch.Size([]) + + def test_backgammon_step_resolves_chance(self): + """Verify step resolves chance nodes that may occur after action.""" + env = OpenSpielEnv("backgammon") + + env.reset() + + # Take an action + action = env.full_action_spec.rand() + env.step(action) + + # After step, should not be at a chance node + assert ( + not env._env.is_chance_node() + ), "After step, environment should not be at a chance node" + + def test_full_rollout_with_chance_game(self): + """Verify complete rollout works with stochastic game.""" + env = OpenSpielEnv("backgammon") + + td = env.reset() + episode_length = 0 + max_steps = 100 + done = td["done"] + + while not done.item() and episode_length < max_steps: + action = env.full_action_spec.rand() + td = env.step(action) + done = td["next", "done"] + episode_length += 1 + + # Verify episode completed without errors + assert episode_length > 0 + assert not env._env.is_chance_node() + + def test_state_serialization_with_chance(self): + """Verify state serialization captures post-chance state.""" + env = OpenSpielEnv("backgammon", return_state=True) + + td1 = env.reset() + td1["state"] + + # Take a step + action = env.full_action_spec.rand() + td2 = env.step(action) + td2["next"]["state"] + + # Reset to state2 + env.reset(td2["next"]) + + # The new state should match what we captured + assert not env._env.is_chance_node() + + +class TestSpecsUnchanged: + """Test that specs remain unchanged with chance support.""" + + def test_specs_valid_for_chance_game(self): + """Verify env specs satisfy check_env_specs for chance game.""" + env = OpenSpielEnv("backgammon") + + # This should not raise + check_env_specs(env) + + def test_observation_spec_structure(self): + """Verify observation spec structure unchanged.""" + env = OpenSpielEnv("backgammon") + + spec = env.observation_spec + + # Should have agents and current_player + assert "agents" in spec + assert "current_player" in spec + + # Specs should be deterministic (same for repeated calls) + spec2 = env.observation_spec + assert str(spec) == str(spec2) + + +class TestDeterministicSampling: + """Test deterministic sampling for reproducible testing.""" + + def test_deterministic_sampler(self): + """Verify deterministic sampler produces same sequence.""" + + def seeded_sampler(seed): + rng = np.random.RandomState(seed) + + def sampler(actions, probs): + return int(rng.choice(actions, p=probs)) + + return sampler + + # Create two envs with same seed + game1 = pyspiel.load_game("backgammon").new_initial_state() + game2 = pyspiel.load_game("backgammon").new_initial_state() + + sampler1 = seeded_sampler(42) + sampler2 = seeded_sampler(42) + + env1 = OpenSpielWrapper(game1, chance_sampler=sampler1) + env2 = OpenSpielWrapper(game2, chance_sampler=sampler2) + + td1 = env1.reset() + td2 = env2.reset() + + # Observations should match (up to floating point) + if "agents" in td1: + obs1 = td1["agents"].get("observation", None) + obs2 = td2["agents"].get("observation", None) + if obs1 is not None and obs2 is not None: + assert torch.allclose(obs1, obs2, atol=1e-6) + + +class TestParallelVsSequential: + """Test chance resolution works for both game types.""" + + @pytest.mark.skipif( + not hasattr(pyspiel, "load_game"), reason="pyspiel not available" + ) + def test_sequential_game_with_chance(self): + """Verify sequential game handling.""" + env = OpenSpielEnv("backgammon") + + assert not env.parallel + + env.reset() + assert not env._env.is_chance_node() + + action = env.full_action_spec.rand() + env.step(action) + assert not env._env.is_chance_node() + + @pytest.mark.skipif( + not hasattr(pyspiel, "load_game"), reason="pyspiel not available" + ) + def test_parallel_game_basic(self): + """Verify parallel game still works (may or may not have chance).""" + # Load a parallel game (rock-paper-scissors is parallel) + try: + env = OpenSpielEnv("rock_paper_scissors") + + if env.parallel: + env.reset() + # Should handle parallel games correctly + assert not env._env.is_chance_node() + except Exception: + # Not all games available, skip if rock_paper_scissors not found + pytest.skip("rock_paper_scissors not available") + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_terminal_state_with_chance_history(self): + """Verify terminal states are handled correctly.""" + env = OpenSpielEnv("backgammon") + + td = env.reset() + + # Play until terminal + steps = 0 + done = td["done"] + while not done.item() and steps < 200: + action = env.full_action_spec.rand() + td = env.step(action) + done = td["next", "done"] + steps += 1 + + # Terminal state should be valid + assert env._env.is_terminal() + assert not env._env.is_chance_node() + + def test_repeated_resets(self): + """Verify repeated resets work correctly.""" + env = OpenSpielEnv("backgammon", return_state=True) + + for _ in range(5): + td = env.reset() + assert not env._env.is_chance_node() + assert "state" in td + + def test_batch_size_not_supported(self): + """Verify that non-empty batch_size raises an error.""" + with pytest.raises( + ValueError, + match="OpenSpielWrapper only supports single-environment mode", + ): + OpenSpielEnv("backgammon", batch_size=torch.Size([4])) + + def test_batch_size_empty_allowed(self): + """Verify that empty batch_size is accepted.""" + # This should not raise + env = OpenSpielEnv("backgammon", batch_size=torch.Size([])) + assert env.batch_size == torch.Size([]) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/libs/test_misc.py b/test/libs/test_misc.py index ad45fd7a671..1c95fe4929b 100644 --- a/test/libs/test_misc.py +++ b/test/libs/test_misc.py @@ -201,12 +201,12 @@ def test_reset_state(self, game_string, return_state, categorical_actions): td = env.reset() assert (td == td_init).all() - def test_chance_not_implemented(self): - with pytest.raises( - NotImplementedError, - match="not yet supported", - ): - OpenSpielEnv("bridge") + def test_chance_nodes_supported(self): + # Verify that games with chance nodes now load successfully + env = OpenSpielEnv("bridge") + td = env.reset() + assert not env._env.is_chance_node() + assert td.shape == torch.Size([]) # NOTE: Each of the registered envs are around 180 MB, so only test a few. diff --git a/torchrl/envs/libs/openspiel.py b/torchrl/envs/libs/openspiel.py index 38684310d1e..8f7afbfd8d4 100644 --- a/torchrl/envs/libs/openspiel.py +++ b/torchrl/envs/libs/openspiel.py @@ -7,6 +7,7 @@ import importlib.util +import numpy as np import torch from tensordict import TensorDict, TensorDictBase @@ -42,14 +43,20 @@ class OpenSpielWrapper(_EnvWrapper): Documentation: https://openspiel.readthedocs.io/en/latest/index.html + Supports games with chance nodes. Chance outcomes are automatically sampled + and resolved between player decision nodes, so agents only observe states + where they must act. + Args: env (pyspiel.State): the game to wrap. Keyword Args: device (torch.device, optional): if provided, the device on which the data is to be cast. Defaults to ``None``. - batch_size (torch.Size, optional): the batch size of the environment. - Defaults to ``torch.Size([])``. + batch_size (torch.Size, optional): must be ``torch.Size([])`` (single-env only). + This wrapper does not support batching multiple game instances. For parallel + environments, wrap multiple OpenSpielWrapper instances with ParallelEnv or + SerialEnv. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. @@ -67,6 +74,9 @@ class OpenSpielWrapper(_EnvWrapper): to :meth:`reset` to reset to that state, rather than resetting to the initial state. Defaults to ``False``. + chance_sampler (callable, optional): a callable taking (actions, probabilities) + and returning a sampled action index. If ``None``, uses numpy's + ``random.choice``. Defaults to ``None``. Attributes: available_envs: environments available to build @@ -161,14 +171,28 @@ def __init__( | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, categorical_actions: bool = False, return_state: bool = False, + chance_sampler=None, **kwargs, ): if env is not None: kwargs["env"] = env + batch_size = kwargs.get("batch_size", torch.Size([])) + if not isinstance(batch_size, torch.Size): + batch_size = torch.Size(batch_size) if batch_size else torch.Size([]) + if batch_size != torch.Size([]): + raise ValueError( + f"OpenSpielWrapper only supports single-environment mode (batch_size=torch.Size([])). " + f"Got batch_size={batch_size}. " + f"For multiple parallel environments, use ParallelEnv or SerialEnv to wrap " + f"multiple OpenSpielWrapper instances." + ) + self.group_map = group_map self.categorical_actions = categorical_actions self.return_state = return_state + self._rng = np.random.default_rng() + self._chance_sampler = chance_sampler or self._default_chance_sampler self._cached_game = None super().__init__(**kwargs) @@ -183,14 +207,27 @@ def _check_kwargs(self, kwargs: dict): if not isinstance(env, pyspiel.State): raise TypeError("env is not of type 'pyspiel.State'.") + def _default_chance_sampler(self, actions: list, probabilities: list[float]) -> int: + return int(self._rng.choice(actions, p=probabilities)) + + def _resolve_chance_nodes(self): + """Resolve all consecutive chance nodes until reaching a decision node or terminal state. + + This method automatically samples and applies chance outcomes so that agents + only interact with the environment at decision nodes. + """ + while self._env.is_chance_node(): + outcomes = self._env.chance_outcomes() + if not outcomes: + break + actions, probabilities = zip(*outcomes) + sampled_action = self._chance_sampler(list(actions), list(probabilities)) + self._env.apply_action(sampled_action) + def _build_env(self, env, requires_grad: bool = False, **kwargs): game = env.get_game() game_type = game.get_type() - if game.max_chance_outcomes() != 0: - raise NotImplementedError( - f"The game '{game_type.short_name}' has chance nodes, which are not yet supported." - ) if game_type.dynamics == self.lib.GameType.Dynamics.MEAN_FIELD: # NOTE: It is unclear from the OpenSpiel documentation what exactly # "mean field" means exactly, and there is no documentation on the @@ -335,8 +372,7 @@ def _make_specs(self, env: pyspiel.State) -> None: # noqa: F821 self.reward_spec = Composite(reward_spec) def _set_seed(self, seed: int | None) -> None: - if seed is not None: - raise NotImplementedError("This environment has no seed.") + self._rng = np.random.default_rng(seed) def current_player(self): return self._env.current_player() @@ -480,6 +516,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: else: self._step_sequential(tensordict) + self._resolve_chance_nodes() self._update_action_mask() return self._make_td_out() @@ -497,6 +534,7 @@ def _reset( new_env = game.new_initial_state() self._env = new_env + self._resolve_chance_nodes() self._update_action_mask() return self._make_td_out(exclude_reward=True) @@ -508,6 +546,10 @@ class OpenSpielEnv(OpenSpielWrapper): Documentation: https://openspiel.readthedocs.io/en/latest/index.html + Supports games with chance nodes. Chance outcomes are automatically sampled + and resolved between player decision nodes, so agents only observe states + where they must act. + Args: game_string (str): the name of the game to wrap. Must be part of :attr:`~.available_envs`. @@ -515,8 +557,10 @@ class OpenSpielEnv(OpenSpielWrapper): Keyword Args: device (torch.device, optional): if provided, the device on which the data is to be cast. Defaults to ``None``. - batch_size (torch.Size, optional): the batch size of the environment. - Defaults to ``torch.Size([])``. + batch_size (torch.Size, optional): must be ``torch.Size([])`` (single-env only). + This wrapper does not support batching multiple game instances. For parallel + environments, wrap multiple OpenSpielEnv instances with ParallelEnv or + SerialEnv. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. @@ -534,6 +578,9 @@ class OpenSpielEnv(OpenSpielWrapper): to :meth:`reset` to reset to that state, rather than resetting to the initial state. Defaults to ``False``. + chance_sampler (callable, optional): a callable taking (actions, probabilities) + and returning a sampled action index. If ``None``, uses numpy's + ``random.choice``. Defaults to ``None``. Attributes: available_envs: environments available to build From fc14a30f2d612dd86b2513f2ebb9d7944f258d50 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Jun 2026 07:44:07 +0100 Subject: [PATCH 2/2] [BugFix] Fix OpenSpiel chance-node setup tests --- test/envs/libs/test_openspiel_chance.py | 276 ------------------------ test/libs/test_misc.py | 57 ++++- torchrl/envs/libs/openspiel.py | 1 + 3 files changed, 53 insertions(+), 281 deletions(-) delete mode 100644 test/envs/libs/test_openspiel_chance.py diff --git a/test/envs/libs/test_openspiel_chance.py b/test/envs/libs/test_openspiel_chance.py deleted file mode 100644 index a28474a9dc1..00000000000 --- a/test/envs/libs/test_openspiel_chance.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import pytest - -torch = pytest.importorskip("torch") -pyspiel = pytest.importorskip("pyspiel") - -import numpy as np -from torchrl.envs import OpenSpielEnv, OpenSpielWrapper -from torchrl.envs.utils import check_env_specs - - -class TestChanceNodeSampler: - """Test the default chance outcome sampler.""" - - def test_default_sampler_respects_probabilities(self): - """Verify default sampler samples according to probabilities.""" - env = OpenSpielEnv("backgammon") - - actions = [0, 1, 2] - probs = [0.5, 0.3, 0.2] - - # Sample many times and verify distribution roughly matches - samples = [env._default_chance_sampler(actions, probs) for _ in range(10000)] - sample_counts = {a: samples.count(a) for a in actions} - - # Check that empirical probabilities are roughly correct (within 5%) - for action, prob in zip(actions, probs): - empirical_prob = sample_counts[action] / len(samples) - assert ( - abs(empirical_prob - prob) < 0.05 - ), f"Action {action}: expected ~{prob}, got {empirical_prob}" - - def test_custom_sampler_injection(self): - """Verify custom sampler can be injected.""" - # Deterministic sampler that always picks the first action - def custom_sampler(actions, probs): - return actions[0] - - game = pyspiel.load_game("backgammon").new_initial_state() - env = OpenSpielWrapper(game, chance_sampler=custom_sampler) - - assert env._chance_sampler is custom_sampler - - def test_sampler_with_single_outcome(self): - """Verify sampler handles single outcome correctly.""" - env = OpenSpielEnv("backgammon") - - actions = [5] - probs = [1.0] - - sampled = env._default_chance_sampler(actions, probs) - assert sampled == 5 - - -class TestChanceNodeResolution: - """Test chance node resolution during reset and step.""" - - def test_backgammon_reset_resolves_chance(self): - """Verify reset resolves initial chance nodes in backgammon.""" - env = OpenSpielEnv("backgammon") - - td = env.reset() - - # After reset, we should be at a decision node (not chance node) - assert ( - not env._env.is_chance_node() - ), "After reset, environment should not be at a chance node" - - # Should have valid observation and current player - assert "agents" in td - assert td.shape == torch.Size([]) - - def test_backgammon_step_resolves_chance(self): - """Verify step resolves chance nodes that may occur after action.""" - env = OpenSpielEnv("backgammon") - - env.reset() - - # Take an action - action = env.full_action_spec.rand() - env.step(action) - - # After step, should not be at a chance node - assert ( - not env._env.is_chance_node() - ), "After step, environment should not be at a chance node" - - def test_full_rollout_with_chance_game(self): - """Verify complete rollout works with stochastic game.""" - env = OpenSpielEnv("backgammon") - - td = env.reset() - episode_length = 0 - max_steps = 100 - done = td["done"] - - while not done.item() and episode_length < max_steps: - action = env.full_action_spec.rand() - td = env.step(action) - done = td["next", "done"] - episode_length += 1 - - # Verify episode completed without errors - assert episode_length > 0 - assert not env._env.is_chance_node() - - def test_state_serialization_with_chance(self): - """Verify state serialization captures post-chance state.""" - env = OpenSpielEnv("backgammon", return_state=True) - - td1 = env.reset() - td1["state"] - - # Take a step - action = env.full_action_spec.rand() - td2 = env.step(action) - td2["next"]["state"] - - # Reset to state2 - env.reset(td2["next"]) - - # The new state should match what we captured - assert not env._env.is_chance_node() - - -class TestSpecsUnchanged: - """Test that specs remain unchanged with chance support.""" - - def test_specs_valid_for_chance_game(self): - """Verify env specs satisfy check_env_specs for chance game.""" - env = OpenSpielEnv("backgammon") - - # This should not raise - check_env_specs(env) - - def test_observation_spec_structure(self): - """Verify observation spec structure unchanged.""" - env = OpenSpielEnv("backgammon") - - spec = env.observation_spec - - # Should have agents and current_player - assert "agents" in spec - assert "current_player" in spec - - # Specs should be deterministic (same for repeated calls) - spec2 = env.observation_spec - assert str(spec) == str(spec2) - - -class TestDeterministicSampling: - """Test deterministic sampling for reproducible testing.""" - - def test_deterministic_sampler(self): - """Verify deterministic sampler produces same sequence.""" - - def seeded_sampler(seed): - rng = np.random.RandomState(seed) - - def sampler(actions, probs): - return int(rng.choice(actions, p=probs)) - - return sampler - - # Create two envs with same seed - game1 = pyspiel.load_game("backgammon").new_initial_state() - game2 = pyspiel.load_game("backgammon").new_initial_state() - - sampler1 = seeded_sampler(42) - sampler2 = seeded_sampler(42) - - env1 = OpenSpielWrapper(game1, chance_sampler=sampler1) - env2 = OpenSpielWrapper(game2, chance_sampler=sampler2) - - td1 = env1.reset() - td2 = env2.reset() - - # Observations should match (up to floating point) - if "agents" in td1: - obs1 = td1["agents"].get("observation", None) - obs2 = td2["agents"].get("observation", None) - if obs1 is not None and obs2 is not None: - assert torch.allclose(obs1, obs2, atol=1e-6) - - -class TestParallelVsSequential: - """Test chance resolution works for both game types.""" - - @pytest.mark.skipif( - not hasattr(pyspiel, "load_game"), reason="pyspiel not available" - ) - def test_sequential_game_with_chance(self): - """Verify sequential game handling.""" - env = OpenSpielEnv("backgammon") - - assert not env.parallel - - env.reset() - assert not env._env.is_chance_node() - - action = env.full_action_spec.rand() - env.step(action) - assert not env._env.is_chance_node() - - @pytest.mark.skipif( - not hasattr(pyspiel, "load_game"), reason="pyspiel not available" - ) - def test_parallel_game_basic(self): - """Verify parallel game still works (may or may not have chance).""" - # Load a parallel game (rock-paper-scissors is parallel) - try: - env = OpenSpielEnv("rock_paper_scissors") - - if env.parallel: - env.reset() - # Should handle parallel games correctly - assert not env._env.is_chance_node() - except Exception: - # Not all games available, skip if rock_paper_scissors not found - pytest.skip("rock_paper_scissors not available") - - -class TestEdgeCases: - """Test edge cases and error conditions.""" - - def test_terminal_state_with_chance_history(self): - """Verify terminal states are handled correctly.""" - env = OpenSpielEnv("backgammon") - - td = env.reset() - - # Play until terminal - steps = 0 - done = td["done"] - while not done.item() and steps < 200: - action = env.full_action_spec.rand() - td = env.step(action) - done = td["next", "done"] - steps += 1 - - # Terminal state should be valid - assert env._env.is_terminal() - assert not env._env.is_chance_node() - - def test_repeated_resets(self): - """Verify repeated resets work correctly.""" - env = OpenSpielEnv("backgammon", return_state=True) - - for _ in range(5): - td = env.reset() - assert not env._env.is_chance_node() - assert "state" in td - - def test_batch_size_not_supported(self): - """Verify that non-empty batch_size raises an error.""" - with pytest.raises( - ValueError, - match="OpenSpielWrapper only supports single-environment mode", - ): - OpenSpielEnv("backgammon", batch_size=torch.Size([4])) - - def test_batch_size_empty_allowed(self): - """Verify that empty batch_size is accepted.""" - # This should not raise - env = OpenSpielEnv("backgammon", batch_size=torch.Size([])) - assert env.batch_size == torch.Size([]) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/test/libs/test_misc.py b/test/libs/test_misc.py index 1c95fe4929b..f945642b556 100644 --- a/test/libs/test_misc.py +++ b/test/libs/test_misc.py @@ -85,8 +85,7 @@ def test_robohive(self, envname, from_pixels, from_depths): # List of OpenSpiel games to test # TODO: Some of the games in `OpenSpielWrapper.available_envs` raise errors for -# a few different reasons, mostly because we do not support chance nodes yet. So -# we cannot run tests on all of them yet. +# a few different reasons, so we cannot run tests on all of them yet. _openspiel_games = [ # ---------------- # Sequential games @@ -161,9 +160,7 @@ def test_all_envs(self, game_string, return_state, categorical_actions): @pytest.mark.parametrize("return_state", [False, True]) @pytest.mark.parametrize("categorical_actions", [False, True]) def test_wrapper(self, game_string, return_state, categorical_actions): - import pyspiel - - base_env = pyspiel.load_game(game_string).new_initial_state() + base_env = OpenSpielWrapper.lib.load_game(game_string).new_initial_state() env_torchrl = OpenSpielWrapper( base_env, categorical_actions=categorical_actions, return_state=return_state ) @@ -204,10 +201,60 @@ def test_reset_state(self, game_string, return_state, categorical_actions): def test_chance_nodes_supported(self): # Verify that games with chance nodes now load successfully env = OpenSpielEnv("bridge") + assert not env._env.is_chance_node() td = env.reset() assert not env._env.is_chance_node() assert td.shape == torch.Size([]) + def test_chance_nodes_resolved_before_initial_step(self): + env = OpenSpielEnv("backgammon") + assert not env._env.is_chance_node() + + td = env.rand_step() + + assert not env._env.is_chance_node() + assert td["next", "done"].shape == torch.Size([1]) + + def test_chance_nodes_resolved_after_step(self): + env = OpenSpielEnv("backgammon") + + env.reset() + td = env.step(env.full_action_spec.rand()) + + assert not env._env.is_chance_node() + assert td["next"].shape == torch.Size([]) + + def test_custom_chance_sampler(self): + samples = [] + + def chance_sampler(actions, probabilities): + samples.append((actions, probabilities)) + return actions[0] + + base_env = OpenSpielWrapper.lib.load_game("backgammon").new_initial_state() + env = OpenSpielWrapper(base_env, chance_sampler=chance_sampler) + + assert samples + assert not env._env.is_chance_node() + + def test_seeded_chance_sampler(self): + env0 = OpenSpielEnv("backgammon", return_state=True) + env0.set_seed(0) + td0 = env0.reset() + + env1 = OpenSpielEnv("backgammon", return_state=True) + env1.set_seed(0) + td1 = env1.reset() + + assert td0["state"] == td1["state"] + + def test_chance_batch_size_not_supported(self): + with pytest.raises( + ValueError, + match="OpenSpielWrapper only supports single-environment mode", + ): + OpenSpielEnv("backgammon", batch_size=torch.Size([4])) + # NOTE: Each of the registered envs are around 180 MB, so only test a few. _mlagents_registered_envs = [ diff --git a/torchrl/envs/libs/openspiel.py b/torchrl/envs/libs/openspiel.py index 8f7afbfd8d4..8ed1fb8e810 100644 --- a/torchrl/envs/libs/openspiel.py +++ b/torchrl/envs/libs/openspiel.py @@ -240,6 +240,7 @@ def _build_env(self, env, requires_grad: bool = False, **kwargs): return env def _init_env(self): + self._resolve_chance_nodes() self._update_action_mask() def _get_game(self):