From 7b0acb6f822165aa1174661ebd2c04c54cc21fb0 Mon Sep 17 00:00:00 2001 From: RuanJohn Date: Tue, 16 Jan 2024 10:31:48 +0200 Subject: [PATCH 1/2] test: add test to ensure timestep shapes and dtypes remain consistent over reset and step --- matrax/env_test.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/matrax/env_test.py b/matrax/env_test.py index 6a3e17a..03676c4 100644 --- a/matrax/env_test.py +++ b/matrax/env_test.py @@ -58,7 +58,7 @@ def test_matrix_game__reset(matrix_game_env: MatrixGame) -> None: key1, key2 = random.PRNGKey(0), random.PRNGKey(1) state1, timestep1 = reset_fn(key1) - state2, timestep2 = reset_fn(key2) + state2, _ = reset_fn(key2) assert isinstance(timestep1, TimeStep) assert isinstance(state1, State) @@ -111,10 +111,10 @@ def test_matrix_game__step(matrix_game_env_with_state: MatrixGame) -> None: # Check that rewards have the correct number of dimensions assert jnp.ndim(timestep1.reward) == 1 - assert jnp.ndim(timestep.reward) == 0 + assert jnp.ndim(timestep.reward) == 1 # Check that discounts have the correct number of dimensions - assert jnp.ndim(timestep1.discount) == 0 - assert jnp.ndim(timestep.discount) == 0 + assert jnp.ndim(timestep1.discount) == 1 + assert jnp.ndim(timestep.discount) == 1 # Check that the state is made of DeviceArrays, this is false for the non-jitted # step function since unpacking random.split returns numpy arrays and not device arrays. 
assert_is_jax_array_tree(new_state1) @@ -157,7 +157,6 @@ def test_matrix_game__reward(matrix_game_env: MatrixGame) -> None: state, timestep = matrix_game_env.reset(state_key) state, timestep = step_fn(state, jnp.array([0, 0])) - jax.debug.print("rewards: {r}", r=timestep.reward) assert jnp.array_equal(timestep.reward, jnp.array([11, 11])) state, timestep = step_fn(state, jnp.array([1, 0])) @@ -174,3 +173,14 @@ def test_matrix_game__reward(matrix_game_env: MatrixGame) -> None: state, timestep = step_fn(state, jnp.array([2, 2])) assert jnp.array_equal(timestep.reward, jnp.array([5, 5])) + + +def test_matrix_game__timesteps_equivalent(matrix_game_env: MatrixGame) -> None: + """Validate that all timestep attributes have the same dtype and shape over reset and step.""" + step_fn = jax.jit(matrix_game_env.step) + state_key = random.PRNGKey(10) + state, init_timestep = matrix_game_env.reset(state_key) + + state, new_timestep = step_fn(state, jnp.array([0, 0])) + + chex.assert_trees_all_equal_shapes_and_dtypes(init_timestep, new_timestep) From a372c9187057a1292ed2fb7fac36a06456601bf3 Mon Sep 17 00:00:00 2001 From: RuanJohn Date: Tue, 16 Jan 2024 10:33:01 +0200 Subject: [PATCH 2/2] feat: pass shape to restart, transition and termination functions --- matrax/env.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/matrax/env.py b/matrax/env.py index 336a196..1f4f4a5 100644 --- a/matrax/env.py +++ b/matrax/env.py @@ -99,7 +99,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: agent_obs=agent_obs, step_count=state.step_count, ) - timestep = restart(observation=observation) + timestep = restart(observation=observation, shape=self.num_agents) return state, timestep def step( @@ -122,7 +122,7 @@ def compute_reward( actions: chex.Array, payoff_matrix_per_agent: chex.Array ) -> chex.Array: reward_idx = tuple(actions) - return payoff_matrix_per_agent[reward_idx] + return payoff_matrix_per_agent[reward_idx].astype(float) 
rewards = jax.vmap(functools.partial(compute_reward, actions))( self.payoff_matrix @@ -143,10 +143,12 @@ def compute_reward( timestep = jax.lax.cond( done, - termination, - transition, - rewards, - next_observation, + lambda: termination( + reward=rewards, observation=next_observation, shape=self.num_agents + ), + lambda: transition( + reward=rewards, observation=next_observation, shape=self.num_agents + ), ) # create environment state