feat: Update Jumanji version and the env specs #8

Merged: 3 commits, Nov 7, 2024
README.md: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

# Interact with the (jit-able) environment
-action = env.action_spec().generate_value() # Action selection (dummy value here)
+action = env.action_spec.generate_value() # Action selection (dummy value here)
state, timestep = jax.jit(env.step)(state, action) # Take a step and observe the next state and time step
```

matrax/__init__.py: 1 addition & 1 deletion
@@ -13,11 +13,11 @@
# limitations under the License.

from jumanji.registration import make, register
-from jumanji.version import __version__

from matrax.env import MatrixGame
from matrax.games import climbing_game, conflict_games, no_conflict_games, penalty_games
from matrax.types import Observation, State
+from matrax.version import __version__

"""Environment Registration"""

matrax/env.py: 10 additions & 10 deletions
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import functools
+from functools import cached_property, partial
from typing import Tuple

import chex
@@ -25,7 +25,7 @@
from matrax.types import Observation, State


-class MatrixGame(Environment[State]):
+class MatrixGame(Environment[State, specs.MultiDiscreteArray, Observation]):
"""JAX implementation of the 2-player matrix game environment:
https://github.com/uoe-agents/matrix-games

@@ -42,7 +42,7 @@ class MatrixGame(Environment[State]):
env = MatrixGame(payoff_matrix)
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
-action = env.action_spec().generate_value()
+action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
```
"""
@@ -92,9 +92,9 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
dummy_actions = jnp.ones((self.num_agents,), int) * -1

# collect first observations and create timestep
-agent_obs = jax.vmap(
-    functools.partial(self._make_agent_observation, dummy_actions)
-)(jnp.arange(self.num_agents))
+agent_obs = jax.vmap(partial(self._make_agent_observation, dummy_actions))(
+    jnp.arange(self.num_agents)
+)
observation = Observation(
agent_obs=agent_obs,
step_count=state.step_count,
@@ -124,16 +124,14 @@ def compute_reward(
reward_idx = tuple(actions)
return payoff_matrix_per_agent[reward_idx].astype(float)

-rewards = jax.vmap(functools.partial(compute_reward, actions))(
-    self.payoff_matrix
-)
+rewards = jax.vmap(partial(compute_reward, actions))(self.payoff_matrix)

# construct timestep and check environment termination
steps = state.step_count + 1
done = steps >= self.time_limit

# compute next observation
-agent_obs = jax.vmap(functools.partial(self._make_agent_observation, actions))(
+agent_obs = jax.vmap(partial(self._make_agent_observation, actions))(
jnp.arange(self.num_agents)
)
next_observation = Observation(
@@ -169,6 +167,7 @@ def _make_agent_observation(
lambda: jnp.zeros(self.num_agents, int),
)

+@cached_property
def observation_spec(self) -> specs.Spec[Observation]:
"""Specification of the observation of the MatrixGame environment.
Returns:
@@ -190,6 +189,7 @@ def observation_spec(self) -> specs.Spec[Observation]:
step_count=step_count,
)

+@cached_property
def action_spec(self) -> specs.MultiDiscreteArray:
"""Returns the action spec.
Since this is a multi-agent environment, the environment expects an array of actions.
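For reference, a minimal end-to-end sketch of the updated spec API, following the class docstring above. It assumes `MatrixGame` accepts a per-agent payoff tensor of shape `(num_agents, num_actions, num_actions)` and a Jumanji release where `action_spec` and `observation_spec` are cached properties rather than methods; the payoff values are illustrative and not taken from `matrax.games`.

```python
import jax
import jax.numpy as jnp

from matrax import MatrixGame

# Illustrative payoffs: one 2x2 matrix per agent (made-up values).
payoff_matrix = jnp.array(
    [
        [[1.0, 0.0], [0.0, 1.0]],  # agent 0
        [[1.0, 0.0], [0.0, 1.0]],  # agent 1
    ]
)

env = MatrixGame(payoff_matrix)
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

# Specs are now attributes (cached properties), so there are no call parentheses.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
```

Because the specs are `cached_property` attributes, each spec is built once per environment instance instead of being reconstructed on every access.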
matrax/env_test.py: 2 additions & 2 deletions
@@ -43,8 +43,8 @@ def matrix_game_env_with_state() -> MatrixGame:

def test_matrix_game__specs(matrix_game_env: MatrixGame) -> None:
"""Validate environment specs conform to the expected shapes and values"""
-action_spec = matrix_game_env.action_spec()
-observation_spec = matrix_game_env.observation_spec()
+action_spec = matrix_game_env.action_spec
+observation_spec = matrix_game_env.observation_spec

assert observation_spec.agent_obs.shape == (2, 2) # type: ignore
assert action_spec.num_values.shape[0] == matrix_game_env.num_agents
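Not part of this PR, but a possible follow-up test sketch: since the specs are now `cached_property` attributes, repeated access should return the very same object, which can be checked with the existing `matrix_game_env` fixture.

```python
def test_matrix_game__specs_are_cached(matrix_game_env: MatrixGame) -> None:
    """Cached spec properties should be built once and then reused."""
    assert matrix_game_env.action_spec is matrix_game_env.action_spec
    assert matrix_game_env.observation_spec is matrix_game_env.observation_spec
```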
matrax/version.py: 1 addition & 1 deletion
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.0.4"
__version__ = "0.0.5"
requirements/requirements.txt: 1 addition & 1 deletion
@@ -1 +1 @@
-jumanji==0.3.1
+jumanji