diff --git a/matrax/__init__.py b/matrax/__init__.py index a22e141..d493a6d 100644 --- a/matrax/__init__.py +++ b/matrax/__init__.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jumanji.registration import make, register -from jumanji.version import __version__ +from jumanji.registration import register -from matrax.env import MatrixGame from matrax.games import climbing_game, conflict_games, no_conflict_games, penalty_games -from matrax.types import Observation, State """Environment Registration""" @@ -92,7 +89,7 @@ ) register( f"Conflict-{_id}-stateful-v0", - entry_point="matrix:MatrixGame", + entry_point="matrax:MatrixGame", kwargs={ "payoff_matrix": payoff_matrix, "keep_state": True, diff --git a/matrax/env_test.py b/matrax/env_test.py index 6a3e17a..934380b 100644 --- a/matrax/env_test.py +++ b/matrax/env_test.py @@ -21,8 +21,9 @@ from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep -from matrax import MatrixGame, State +from matrax.env import MatrixGame from matrax.games import climbing_game +from matrax.types import State @pytest.fixture