Skip to content

Commit

Permalink
Added pre-commit hooks and fixed linting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottower committed Mar 9, 2023
1 parent d932f1b commit 94caa24
Show file tree
Hide file tree
Showing 23 changed files with 1,376 additions and 526 deletions.
62 changes: 62 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
---
repos:
- repo: https://github.com/python/black
rev: 23.1.0
hooks:
- id: black
- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
hooks:
- id: codespell
args:
- --skip=*.css,*.js,*.map,*.scss,*svg
- --ignore-words-list=magent
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
args:
- '--per-file-ignores=*/__init__.py:F401 test/all_parameter_combs_test.py:F405 pettingzoo/classic/go/go.py:W605'
- --extend-ignore=E203
- --max-complexity=205
- --max-line-length=300
- --show-source
- --statistics
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
# TODO: remove `--keep-runtime-typing` option
args: ["--py37-plus", "--keep-runtime-typing"]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: mixed-line-ending
args: ["--fix=lf"]
- repo: https://github.com/pycqa/pydocstyle
rev: 6.3.0
hooks:
- id: pydocstyle
args:
- --source
- --explain
- --convention=google
- --count
# TODO: Remove ignoring rules D101, D102, D103, D105
- --add-ignore=D100,D107,D101,D102,D103,D105
exclude: "__init__.py$|^pettingzoo.test|^docs"
additional_dependencies: ["toml"]
# - repo: local
# hooks:
# - id: pyright
# name: pyright
# entry: pyright
# language: node
# pass_filenames: false
# types: [python]
# additional_dependencies: ["pyright"]
2 changes: 1 addition & 1 deletion gobblet/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . __version__ import __version__
from .__version__ import __version__
2 changes: 1 addition & 1 deletion gobblet/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.3.4'
__version__ = "1.3.4"
31 changes: 14 additions & 17 deletions gobblet/examples/example_RLlib.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import glob
import os
from typing import Tuple
from gymnasium import spaces

import ray.tune
from ray import init
Expand All @@ -13,8 +12,8 @@
from ray.tune.registry import register_env

from gobblet import gobblet_v1
from gobblet.models.action_mask_model import TorchActionMaskModel
from gobblet.game.utils import get_project_root
from gobblet.models.action_mask_model import TorchActionMaskModel

torch, nn = try_import_torch()

Expand All @@ -32,11 +31,6 @@ def env_creator():
# wrap the pettingzoo env in MultiAgent RLLib
env = PettingZooEnv(env_creator())

# Convert obs space and action space to gym
# observation_space = env.observation_space["observation"]
# observation_space = spaces.Box(observation_space.low, observation_space.high, observation_space.shape, observation_space.dtype)
# action_space = spaces.Discrete(env.action_space.n)

agents = ["player_1", "player_2"]
custom_config = {
"env": env_name,
Expand All @@ -47,7 +41,7 @@ def env_creator():
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
# "num_gpus": int(torch.cuda.device_count()),
"num_gpus": 0,
"num_workers": 2, #os.cpu_count() - 1,
"num_workers": 2, # os.cpu_count() - 1,
"multiagent": {
"policies": {
name: (None, env.observation_space, env.action_space, {})
Expand Down Expand Up @@ -97,13 +91,17 @@ def train_ray(ppo_config, timesteps_total: int = 10):


def load_ray(path, ppo_config):
"""
"""Load ray.
Load a trained RLlib agent from the specified path.
Call this before testing a trained agent.
:param path:
Path pointing to the agent's saved checkpoint (only used for RLlib agents)
:param ppo_config:
dict config
Args:
path: Path pointing to the agent's saved checkpoint (only used for RLlib agents)
ppo_config: dict config
Returns:
trainer: RLlib trainer object
"""
trainer = ppo.PPOTrainer(config=ppo_config)
trainer.restore(path)
Expand Down Expand Up @@ -145,7 +143,7 @@ def sample_trainer(trainer, env):


def tune_training_loop(timesteps_total=10000):
"""train trainer and sample"""
"""Train trainer and sample."""
trainer, env, ppo_config = prepare_train()

# train trainer
Expand All @@ -162,8 +160,7 @@ def tune_training_loop(timesteps_total=10000):


def manual_training_loop(timesteps_total=10000):
"""train trainer and sample"""

"""Train trainer and sample."""
trainer, env, ppo_config = prepare_train()
trainer_trained = train(trainer, max_steps=timesteps_total)

Expand All @@ -172,4 +169,4 @@ def manual_training_loop(timesteps_total=10000):

if __name__ == "__main__":
init(local_mode=True)
tune_training_loop()
tune_training_loop()
28 changes: 19 additions & 9 deletions gobblet/examples/example_basic.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from gobblet import gobblet_v1
import argparse
import numpy as np
import time

import numpy as np

from gobblet import gobblet_v1


def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--render_mode", type=str, default="human", choices=["human", "rgb_array", "text", "text_full"],
help="Choose the rendering mode for the game."
"--render_mode",
type=str,
default="human",
choices=["human", "rgb_array", "text", "text_full"],
help="Choose the rendering mode for the game.",
)

parser.add_argument(
Expand All @@ -23,10 +28,12 @@ def get_parser() -> argparse.ArgumentParser:

return parser


def get_args() -> argparse.Namespace:
parser = get_parser()
return parser.parse_known_args()[0]


if __name__ == "__main__":
# train the agent and watch its performance in a match!
args = get_args()
Expand All @@ -41,18 +48,21 @@ def get_args() -> argparse.Namespace:
env.render() # need to render the environment before pygame can take user input

for agent in env.agent_iter():

observation, reward, termination, truncation, info = env.last()

if termination or truncation:
print(f"Agent: ({agent}), Reward: {reward}, info: {info}")
env.step(None)

else:
action_mask = observation['action_mask']
action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask))
action_mask = observation["action_mask"]
action = np.random.choice(
np.arange(len(action_mask)), p=action_mask / np.sum(action_mask)
)

if args.render_mode == "human":
time.sleep(.5) # Wait .5 seconds between moves so the user can follow the sequence of moves
time.sleep(
0.5
) # Wait .5 seconds between moves so the user can follow the sequence of moves

env.step(action)
env.step(action)
38 changes: 28 additions & 10 deletions gobblet/examples/example_record_game.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,45 @@
import argparse

import numpy as np
import pygame

from gobblet.game.utils import GIFRecorder
import numpy as np
import argparse


def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--seed", type=int, default=None, help="Set random seed manually (will only affect CPU agents)"
"--seed",
type=int,
default=None,
help="Set random seed manually (will only affect CPU agents)",
)
parser.add_argument(
"--cpu-players", type=int, default=1, choices=[0, 1, 2], help="Number of CPU players (options: 0, 1, 2)"
"--cpu-players",
type=int,
default=1,
choices=[0, 1, 2],
help="Number of CPU players (options: 0, 1, 2)",
)
parser.add_argument(
"--player", type=int, default=0, choices=[0,1], help="Choose which player to play as: red = 0, yellow = 1"
"--player",
type=int,
default=0,
choices=[0, 1],
help="Choose which player to play as: red = 0, yellow = 1",
)
parser.add_argument(
"--screen-width", type=int, default=640, help="Width of pygame screen in pixels"
)

return parser


def get_args() -> argparse.Namespace:
parser = get_parser()
return parser.parse_known_args()[0]


if __name__ == "__main__":
from gobblet import gobblet_v1

Expand All @@ -44,7 +60,7 @@ def get_args() -> argparse.Namespace:
# Record the first frame (empty board)
recorder.capture_frame(env.unwrapped.screen)

manual_policy = gobblet_v1.ManualPolicy(env, recorder=recorder)
manual_policy = gobblet_v1.ManualGobbletPolicy(env, recorder=recorder)

for agent in env.agent_iter():
clock.tick(env.metadata["render_fps"])
Expand All @@ -60,11 +76,13 @@ def get_args() -> argparse.Namespace:
continue

if agent == manual_policy.agent and args.cpu_players < 2:
action = manual_policy(observation, agent)
action = manual_policy(observation, agent)
else:
action_mask = observation['action_mask']
action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask))
action_mask = observation["action_mask"]
action = np.random.choice(
np.arange(len(action_mask)), p=action_mask / np.sum(action_mask)
)

env.step(action)

env.render()
env.render()
Loading

0 comments on commit 94caa24

Please sign in to comment.