diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..97ad3a7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,62 @@ +--- +repos: + - repo: https://github.com/python/black + rev: 23.1.0 + hooks: + - id: black + - repo: https://github.com/codespell-project/codespell + rev: v2.2.2 + hooks: + - id: codespell + args: + - --skip=*.css,*.js,*.map,*.scss,*svg + - --ignore-words-list=magent + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + args: + - '--per-file-ignores=*/__init__.py:F401 test/all_parameter_combs_test.py:F405 pettingzoo/classic/go/go.py:W605' + - --extend-ignore=E203 + - --max-complexity=205 + - --max-line-length=300 + - --show-source + - --statistics + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black"] + - repo: https://github.com/asottile/pyupgrade + rev: v3.3.1 + hooks: + - id: pyupgrade + # TODO: remove `--keep-runtime-typing` option + args: ["--py37-plus", "--keep-runtime-typing"] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: mixed-line-ending + args: ["--fix=lf"] + - repo: https://github.com/pycqa/pydocstyle + rev: 6.3.0 + hooks: + - id: pydocstyle + args: + - --source + - --explain + - --convention=google + - --count + # TODO: Remove ignoring rules D101, D102, D103, D105 + - --add-ignore=D100,D107,D101,D102,D103,D105 + exclude: "__init__.py$|^pettingzoo.test|^docs" + additional_dependencies: ["toml"] +# - repo: local +# hooks: +# - id: pyright +# name: pyright +# entry: pyright +# language: node +# pass_filenames: false +# types: [python] +# additional_dependencies: ["pyright"] diff --git a/gobblet/__init__.py b/gobblet/__init__.py index a80c3f6..9226fe7 100644 --- a/gobblet/__init__.py +++ b/gobblet/__init__.py @@ -1 +1 @@ -from . __version__ import __version__ \ No newline at end of file +from .__version__ import __version__ diff --git a/gobblet/__version__.py b/gobblet/__version__.py index ac422f1..f9e47b6 100644 --- a/gobblet/__version__.py +++ b/gobblet/__version__.py @@ -1 +1 @@ -__version__ = '1.3.4' +__version__ = "1.3.4" diff --git a/gobblet/examples/example_RLlib.py b/gobblet/examples/example_RLlib.py index 68e4cf3..e6e023d 100644 --- a/gobblet/examples/example_RLlib.py +++ b/gobblet/examples/example_RLlib.py @@ -1,7 +1,6 @@ import glob import os from typing import Tuple -from gymnasium import spaces import ray.tune from ray import init @@ -13,8 +12,8 @@ from ray.tune.registry import register_env from gobblet import gobblet_v1 -from gobblet.models.action_mask_model import TorchActionMaskModel from gobblet.game.utils import get_project_root +from gobblet.models.action_mask_model import TorchActionMaskModel torch, nn = try_import_torch() @@ -32,11 +31,6 @@ def env_creator(): # wrap the pettingzoo env in MultiAgent RLLib env = PettingZooEnv(env_creator()) - # Convert obs space and action space to gym - # observation_space = env.observation_space["observation"] - # observation_space = spaces.Box(observation_space.low, observation_space.high, observation_space.shape, observation_space.dtype) - # action_space = spaces.Discrete(env.action_space.n) - agents = ["player_1", "player_2"] custom_config = { "env": env_name, @@ -47,7 +41,7 @@ def env_creator(): # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. # "num_gpus": int(torch.cuda.device_count()), "num_gpus": 0, - "num_workers": 2, #os.cpu_count() - 1, + "num_workers": 2, # os.cpu_count() - 1, "multiagent": { "policies": { name: (None, env.observation_space, env.action_space, {}) @@ -97,13 +91,17 @@ def train_ray(ppo_config, timesteps_total: int = 10): def load_ray(path, ppo_config): - """ + """Load ray. + Load a trained RLlib agent from the specified path. Call this before testing a trained agent. - :param path: - Path pointing to the agent's saved checkpoint (only used for RLlib agents) - :param ppo_config: - dict config + + Args: + path: Path pointing to the agent's saved checkpoint (only used for RLlib agents) + ppo_config: dict config + + Returns: + trainer: RLlib trainer object """ trainer = ppo.PPOTrainer(config=ppo_config) trainer.restore(path) @@ -145,7 +143,7 @@ def sample_trainer(trainer, env): def tune_training_loop(timesteps_total=10000): - """train trainer and sample""" + """Train trainer and sample.""" trainer, env, ppo_config = prepare_train() # train trainer @@ -162,8 +160,7 @@ def tune_training_loop(timesteps_total=10000): def manual_training_loop(timesteps_total=10000): - """train trainer and sample""" - + """Train trainer and sample.""" trainer, env, ppo_config = prepare_train() trainer_trained = train(trainer, max_steps=timesteps_total) @@ -172,4 +169,4 @@ def manual_training_loop(timesteps_total=10000): if __name__ == "__main__": init(local_mode=True) - tune_training_loop() \ No newline at end of file + tune_training_loop() diff --git a/gobblet/examples/example_basic.py b/gobblet/examples/example_basic.py index 58adacd..6af71c8 100644 --- a/gobblet/examples/example_basic.py +++ b/gobblet/examples/example_basic.py @@ -1,14 +1,19 @@ -from gobblet import gobblet_v1 import argparse -import numpy as np import time +import numpy as np + +from gobblet import gobblet_v1 + def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument( - "--render_mode", type=str, default="human", choices=["human", "rgb_array", "text", "text_full"], - help="Choose the rendering mode for the game." + "--render_mode", + type=str, + default="human", + choices=["human", "rgb_array", "text", "text_full"], + help="Choose the rendering mode for the game.", ) parser.add_argument( @@ -23,10 +28,12 @@ def get_parser() -> argparse.ArgumentParser: return parser + def get_args() -> argparse.Namespace: parser = get_parser() return parser.parse_known_args()[0] + if __name__ == "__main__": # train the agent and watch its performance in a match! args = get_args() @@ -41,7 +48,6 @@ def get_args() -> argparse.Namespace: env.render() # need to render the environment before pygame can take user input for agent in env.agent_iter(): - observation, reward, termination, truncation, info = env.last() if termination or truncation: @@ -49,10 +55,14 @@ def get_args() -> argparse.Namespace: env.step(None) else: - action_mask = observation['action_mask'] - action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask)) + action_mask = observation["action_mask"] + action = np.random.choice( + np.arange(len(action_mask)), p=action_mask / np.sum(action_mask) + ) if args.render_mode == "human": - time.sleep(.5) # Wait .5 seconds between moves so the user can follow the sequence of moves + time.sleep( + 0.5 + ) # Wait .5 seconds between moves so the user can follow the sequence of moves - env.step(action) \ No newline at end of file + env.step(action) diff --git a/gobblet/examples/example_record_game.py b/gobblet/examples/example_record_game.py index 7bad479..694a11b 100644 --- a/gobblet/examples/example_record_game.py +++ b/gobblet/examples/example_record_game.py @@ -1,18 +1,32 @@ +import argparse + +import numpy as np import pygame + from gobblet.game.utils import GIFRecorder -import numpy as np -import argparse + def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument( - "--seed", type=int, default=None, help="Set random seed manually (will only affect CPU agents)" + "--seed", + type=int, + default=None, + help="Set random seed manually (will only affect CPU agents)", ) parser.add_argument( - "--cpu-players", type=int, default=1, choices=[0, 1, 2], help="Number of CPU players (options: 0, 1, 2)" + "--cpu-players", + type=int, + default=1, + choices=[0, 1, 2], + help="Number of CPU players (options: 0, 1, 2)", ) parser.add_argument( - "--player", type=int, default=0, choices=[0,1], help="Choose which player to play as: red = 0, yellow = 1" + "--player", + type=int, + default=0, + choices=[0, 1], + help="Choose which player to play as: red = 0, yellow = 1", ) parser.add_argument( "--screen-width", type=int, default=640, help="Width of pygame screen in pixels" @@ -20,10 +34,12 @@ def get_parser() -> argparse.ArgumentParser: return parser + def get_args() -> argparse.Namespace: parser = get_parser() return parser.parse_known_args()[0] + if __name__ == "__main__": from gobblet import gobblet_v1 @@ -44,7 +60,7 @@ def get_args() -> argparse.Namespace: # Record the first frame (empty board) recorder.capture_frame(env.unwrapped.screen) - manual_policy = gobblet_v1.ManualPolicy(env, recorder=recorder) + manual_policy = gobblet_v1.ManualGobbletPolicy(env, recorder=recorder) for agent in env.agent_iter(): clock.tick(env.metadata["render_fps"]) @@ -60,11 +76,13 @@ def get_args() -> argparse.Namespace: continue if agent == manual_policy.agent and args.cpu_players < 2: - action = manual_policy(observation, agent) + action = manual_policy(observation, agent) else: - action_mask = observation['action_mask'] - action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask)) + action_mask = observation["action_mask"] + action = np.random.choice( + np.arange(len(action_mask)), p=action_mask / np.sum(action_mask) + ) env.step(action) - env.render() \ No newline at end of file + env.render() diff --git a/gobblet/examples/example_tianshou_DQN.py b/gobblet/examples/example_tianshou_DQN.py index 8e72953..f11fdde 100644 --- a/gobblet/examples/example_tianshou_DQN.py +++ b/gobblet/examples/example_tianshou_DQN.py @@ -1,18 +1,8 @@ # adapted from https://github.com/Farama-Foundation/PettingZoo/blob/master/tutorials/Tianshou/3_cli_and_logging.py -""" -This is a full example of using Tianshou with MARL to train agents, complete with argument parsing (CLI) and logging. - -Author: Will (https://github.com/WillDudley) - -Python version used: 3.8.10 - -Requirements: -pettingzoo == 1.22.0 -git+https://github.com/thu-ml/tianshou -""" import argparse import os +import time from copy import deepcopy from typing import Optional, Tuple @@ -23,7 +13,7 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy +from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -31,9 +21,8 @@ from gobblet import gobblet_v1 from gobblet.game.collector_manual_policy import ManualPolicyCollector -from gobblet.game.utils import GIFRecorder from gobblet.game.greedy_policy_tianshou import GreedyPolicy -import time +from gobblet.game.utils import GIFRecorder def get_parser() -> argparse.ArgumentParser: @@ -42,7 +31,9 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-train", type=float, default=0.1) parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--lr", type=float, default=1e-4) # TODO: Changing this to 1e-5 for some reason makes it pause after 3 or 4 epochs + parser.add_argument( + "--lr", type=float, default=1e-4 + ) # TODO: Changing this to 1e-5 for some reason makes it pause after 3 or 4 epochs parser.add_argument( "--gamma", type=float, default=0.9, help="a smaller gamma favors earlier win" ) @@ -60,14 +51,53 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument("--test-num", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.1) - parser.add_argument("--render_mode", type=str, default="human", choices=["human","rgb_array", "text", "text_full"], help="Choose the rendering mode for the game.") - parser.add_argument("--debug", action="store_true", help="Flag to enable to print extra debugging info") - parser.add_argument("--self_play", action="store_true", help="Flag to enable training via self-play (as opposed to fixed opponent)") - parser.add_argument("--self_play_generations", type=int, default=5, help="Number of generations of self-play agents to train (each generation can have multiple epochs of training)") - parser.add_argument("--self_play_greedy", action="store_true", help="Flag to have self-play train against a greedy agent for the first generation") - parser.add_argument("--cpu-players", type=int, default=2, choices=[1, 2], help="Number of CPU players (options: 1, 2)") - parser.add_argument("--player", type=int, default=0, choices=[0,1], help="Choose which player to play as: red = 0, yellow = 1") - parser.add_argument("--record", action="store_true", help="Flag to save a recording of the game (game.gif)") + parser.add_argument( + "--render_mode", + type=str, + default="human", + choices=["human", "rgb_array", "text", "text_full"], + help="Choose the rendering mode for the game.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Flag to enable to print extra debugging info", + ) + parser.add_argument( + "--self_play", + action="store_true", + help="Flag to enable training via self-play (as opposed to fixed opponent)", + ) + parser.add_argument( + "--self_play_generations", + type=int, + default=5, + help="Number of generations of self-play agents to train (each generation can have multiple epochs of training)", + ) + parser.add_argument( + "--self_play_greedy", + action="store_true", + help="Flag to have self-play train against a greedy agent for the first generation", + ) + parser.add_argument( + "--cpu-players", + type=int, + default=2, + choices=[1, 2], + help="Number of CPU players (options: 1, 2)", + ) + parser.add_argument( + "--player", + type=int, + default=0, + choices=[0, 1], + help="Choose which player to play as: red = 0, yellow = 1", + ) + parser.add_argument( + "--record", + action="store_true", + help="Flag to save a recording of the game (game.gif)", + ) parser.add_argument( "--win-rate", type=float, @@ -120,12 +150,11 @@ def get_agents( env = get_env() observation_space = ( env.observation_space["observation"] - if isinstance(env.observation_space, gym.spaces.Dict) or isinstance(env.observation_space, gymnasium.spaces.Dict) + if isinstance(env.observation_space, gym.spaces.Dict) + or isinstance(env.observation_space, gymnasium.spaces.Dict) else env.observation_space ) - args.state_shape = ( - observation_space.shape or observation_space.n - ) + args.state_shape = observation_space.shape or observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if agent_learn is None: # model @@ -168,7 +197,9 @@ def get_agents( agent_opponent.load_state_dict(torch.load(args.opponent_path)) else: # agent_opponent = RandomPolicy() - agent_opponent = GreedyPolicy() # Greedy policy is a difficult opponent, should yeild much better results than random + agent_opponent = ( + GreedyPolicy() + ) # Greedy policy is a difficult opponent, should yield much better results than random if args.agent_id == 1: agents = [agent_learn, agent_opponent] @@ -203,33 +234,46 @@ def train_selfplay( env = get_env() observation_space = ( env.observation_space["observation"] - if isinstance(env.observation_space, gymnasium.spaces.Dict) or isinstance(env.observation_space, gym.spaces.Dict) + if isinstance(env.observation_space, gymnasium.spaces.Dict) + or isinstance(env.observation_space, gym.spaces.Dict) else env.observation_space ) - args.state_shape = ( - observation_space.shape or observation_space.n - ) + args.state_shape = observation_space.shape or observation_space.n args.action_shape = env.action_space.shape or env.action_space.n # ======== model setup ========= # Note: custom models can be specified using the agent_fixed and agent_learn arguments if agent_learn is None: - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device - ).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) agent_learn = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq, + ) if agent_fixed is None: - net_fixed = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device - ).to(args.device) + net_fixed = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + ).to(args.device) optim_fixed = torch.optim.SGD(net_fixed.parameters(), lr=0) agent_fixed = DQNPolicy( - net_fixed, optim_fixed, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + net_fixed, + optim_fixed, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq, + ) # Load initial opponent from file # path = os.path.join(args.logdir, 'gobblet', 'dqn', 'policy.pth') @@ -240,7 +284,6 @@ def train_selfplay( policy = MultiAgentPolicyManager(agents, env) agents_list = list(policy.policies.keys()) - # ======== collector setup ========= train_collector = Collector( policy, @@ -266,7 +309,8 @@ def save_best_fn(policy): args.logdir, "gobblet", "dqn-selfplay", "policy.pth" ) torch.save( - policy.policies[agents_list[args.agent_id - 1]].state_dict(), model_save_path + policy.policies[agents_list[args.agent_id - 1]].state_dict(), + model_save_path, ) def stop_fn(mean_rewards): @@ -287,29 +331,44 @@ def test_fn(epoch, env_step): def reward_metric(rews): return rews[:, 0] - # Self-play loop for i in range(args.self_play_generations): result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_best_fn=save_best_fn, update_per_step=args.update_per_step, - logger=logger, test_in_train=False, reward_metric=reward_metric) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + update_per_step=args.update_per_step, + logger=logger, + test_in_train=False, + reward_metric=reward_metric, + ) - print(f"Launching game between learned agent ({type(policy.policies[agents_list[0]]).__name__}) and fixed agent ({type(policy.policies[agents_list[1]]).__name__}):") + print( + f"Launching game between learned agent ({type(policy.policies[agents_list[0]]).__name__}) and fixed agent ({type(policy.policies[agents_list[1]]).__name__}):" + ) # Render a single game between the learned policy and fixed policy from last generation watch_selfplay(args, policy.policies[agents_list[0]]) # Set fixed opponent policy as to the current trained policy (updates every epoch) policy.policies[agents_list[1]] = deepcopy(policy.policies[agents_list[0]]) - print('--- SELF-PLAY GENERATION: {} ---'.format(i + 1)) + print(f"--- SELF-PLAY GENERATION: {i + 1} ---") - print(f"Launching game between learned agent ({type(policy.policies[agents_list[0]]).__name__}) and itself ({type(policy.policies[agents_list[1]]).__name__}):") + print( + f"Launching game between learned agent ({type(policy.policies[agents_list[0]]).__name__}) and itself ({type(policy.policies[agents_list[1]]).__name__}):" + ) # Render a single game between the learned policy and itself watch_selfplay(args, policy.policies[agents_list[0]]) - model_save_path = os.path.join(args.logdir, 'gobblet', 'dqn-selfplay', 'policy.pth') + model_save_path = os.path.join(args.logdir, "gobblet", "dqn-selfplay", "policy.pth") torch.save(policy.policies[agents_list[0]].state_dict(), model_save_path) return result, policy.policies[agents_list[0]] @@ -358,11 +417,10 @@ def save_best_fn(policy): if hasattr(args, "model_save_path"): model_save_path = args.model_save_path else: - model_save_path = os.path.join( - args.logdir, "gobblet", "dqn", "policy.pth" - ) + model_save_path = os.path.join(args.logdir, "gobblet", "dqn", "policy.pth") torch.save( - policy.policies[agents_list[args.agent_id - 1]].state_dict(), model_save_path + policy.policies[agents_list[args.agent_id - 1]].state_dict(), + model_save_path, ) def stop_fn(mean_rewards): @@ -416,7 +474,9 @@ def watch( policy.policies[agents_list[args.agent_id - 1]].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) - pettingzoo_env = env.workers[0].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env + pettingzoo_env = env.workers[ + 0 + ].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env if args.record: recorder = GIFRecorder() recorder.capture_frame(pettingzoo_env.unwrapped.screen) @@ -432,30 +492,39 @@ def watch( if collector.data.terminated or collector.data.truncated: rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()} [{type(policy.policies[agents_list[0]]).__name__}]") - print(f"Final reward: {rews[:, 1].mean()}, length: {lens.mean()} [{type(policy.policies[agents_list[1]]).__name__}]") + print( + f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()} [{type(policy.policies[agents_list[0]]).__name__}]" + ) + print( + f"Final reward: {rews[:, 1].mean()}, length: {lens.mean()} [{type(policy.policies[agents_list[1]]).__name__}]" + ) if recorder is not None: recorder.end_recording(pettingzoo_env.unwrapped.screen) recorder = None + def watch_selfplay( - args: argparse.Namespace = get_args(), - agent: Optional[BasePolicy] = None, + args: argparse.Namespace = get_args(), + agent: Optional[BasePolicy] = None, ) -> None: # env = get_env(render_mode=args.render_mode, args=args) env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)]) - policy = MultiAgentPolicyManager(policies=[agent, deepcopy(agent)], env=get_env(render_mode=args.render_mode, args=args)) # fixed here + policy = MultiAgentPolicyManager( + policies=[agent, deepcopy(agent)], + env=get_env(render_mode=args.render_mode, args=args), + ) # fixed here policy.eval() collector = Collector(policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=True) rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}") + # ======== allows the user to input moves and play vs a pre-trained agent ====== def play( - args: argparse.Namespace = get_args(), - agent_learn: Optional[BasePolicy] = None, - agent_opponent: Optional[BasePolicy] = None, + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, ) -> None: env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)]) # env = get_env(render_mode=args.render_mode, args=args) # Throws error because collector looks for length, could just override though since I'm using my own collector @@ -469,24 +538,34 @@ def play( # Experimental: let the CPU agent to continue training (TODO: check if this actually changes things meaningfully) # policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train) - collector = ManualPolicyCollector(policy, env, exploration_noise=True) # Collector for CPU actions + collector = ManualPolicyCollector( + policy, env, exploration_noise=True + ) # Collector for CPU actions - pettingzoo_env = env.workers[0].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env + pettingzoo_env = env.workers[ + 0 + ].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env if args.record: recorder = GIFRecorder() else: recorder = None - manual_policy = gobblet_v1.ManualPolicy(env=pettingzoo_env, agent_id=args.player, recorder=recorder) # Gobblet keyboard input requires access to raw_env (uses functions from board) + manual_policy = gobblet_v1.ManualGobbletPolicy( + env=pettingzoo_env, agent_id=args.player, recorder=recorder + ) # Gobblet keyboard input requires access to raw_env (uses functions from board) while pettingzoo_env.agents: agent_id = collector.data.obs.agent_id # If it is the players turn and there are less than 2 CPU players (at least one human player) if agent_id == pettingzoo_env.agents[args.player]: - observation = {"observation": collector.data.obs.obs.flatten(), - "action_mask": collector.data.obs.mask.flatten()} # PettingZoo expects a dict with this format + observation = { + "observation": collector.data.obs.obs.flatten(), + "action_mask": collector.data.obs.mask.flatten(), + } # PettingZoo expects a dict with this format action = manual_policy(observation, agent_id) - result = collector.collect_result(action=action.reshape(1), render=args.render) + result = collector.collect_result( + action=action.reshape(1), render=args.render + ) else: result = collector.collect(n_step=1, render=args.render) @@ -495,24 +574,31 @@ def play( if collector.data.terminated or collector.data.truncated: rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews[:, args.player].mean()}, length: {lens.mean()} [Human]") - print(f"Final reward: {rews[:, 1-args.player].mean()}, length: {lens.mean()} [{type(policy.policies[agents[1-args.player]]).__name__}]") + print( + f"Final reward: {rews[:, args.player].mean()}, length: {lens.mean()} [Human]" + ) + print( + f"Final reward: {rews[:, 1-args.player].mean()}, length: {lens.mean()} [{type(policy.policies[agents[1-args.player]]).__name__}]" + ) if recorder is not None: recorder.end_recording(pettingzoo_env.unwrapped.screen) recorder = None + if __name__ == "__main__": # train the agent and watch its performance in a match! args = get_args() if args.player == 1: - args.agent_id = 1 # Ensures trained agent is in the correct spot + args.agent_id = 1 # Ensures trained agent is in the correct spot if args.self_play: print("Training agent...") # Hard code the first fixed agent to be the greedy policy (after one generation it will be switched to a copy of the learned agent) agent_fixed = GreedyPolicy() if args.self_play_greedy else None - result, agent = train_selfplay(args=args, agent_fixed=agent_fixed) # Hard code the first fixed agent to be the greedy policy + result, agent = train_selfplay( + args=args, agent_fixed=agent_fixed + ) # Hard code the first fixed agent to be the greedy policy print("Starting game...") watch_selfplay(args, agent) @@ -524,4 +610,4 @@ def play( if args.cpu_players == 2: watch(args, agent) else: - play(args, agent) \ No newline at end of file + play(args, agent) diff --git a/gobblet/examples/example_tianshou_greedy.py b/gobblet/examples/example_tianshou_greedy.py index 005d7aa..1952617 100644 --- a/gobblet/examples/example_tianshou_greedy.py +++ b/gobblet/examples/example_tianshou_greedy.py @@ -1,31 +1,18 @@ # adapted from https://github.com/Farama-Foundation/PettingZoo/blob/master/tutorials/Tianshou/3_cli_and_logging.py -""" -This is a full example of using Tianshou with MARL to train agents, complete with argument parsing (CLI) and logging. - -Author: Will (https://github.com/WillDudley) - -Python version used: 3.8.10 - -Requirements: -pettingzoo == 1.22.0 -git+https://github.com/thu-ml/tianshou -""" import argparse -from typing import Optional, Tuple - +import time +from typing import Tuple import torch from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy - +from tianshou.policy import BasePolicy, MultiAgentPolicyManager from gobblet import gobblet_v1 from gobblet.game.collector_manual_policy import ManualPolicyCollector -from gobblet.game.utils import GIFRecorder from gobblet.game.greedy_policy_tianshou import GreedyPolicy -import time +from gobblet.game.utils import GIFRecorder def get_parser() -> argparse.ArgumentParser: @@ -34,7 +21,9 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-train", type=float, default=0.1) parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--lr", type=float, default=1e-4) # TODO: Changing this to 1e-5 for some reason makes it pause after 3 or 4 epochs + parser.add_argument( + "--lr", type=float, default=1e-4 + ) # TODO: Changing this to 1e-5 for some reason makes it pause after 3 or 4 epochs parser.add_argument( "--gamma", type=float, default=0.9, help="a smaller gamma favors earlier win" ) @@ -52,13 +41,49 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument("--test-num", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.1) - parser.add_argument("--render_mode", type=str, default="human", choices=["human","rgb_array", "text", "text_full"], help="Choose the rendering mode for the game.") - parser.add_argument("--debug", action="store_true", help="Flag to enable to print extra debugging info") - parser.add_argument("--self_play", action="store_true", help="Flag to enable training via self-play (as opposed to fixed opponent)") - parser.add_argument("--cpu-players", type=int, default=1, choices=[1, 2], help="Number of CPU players (options: 1, 2)") - parser.add_argument("--player", type=int, default=0, choices=[0,1], help="Choose which player to play as: red = 0, yellow = 1") - parser.add_argument("--record", action="store_true", help="Flag to save a recording of the game (game.gif)") - parser.add_argument("--depth", type=int, default=2, choices=[1, 2, 3], help="Search depth for greedy agent. Options: 1,2,3") + parser.add_argument( + "--render_mode", + type=str, + default="human", + choices=["human", "rgb_array", "text", "text_full"], + help="Choose the rendering mode for the game.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Flag to enable to print extra debugging info", + ) + parser.add_argument( + "--self_play", + action="store_true", + help="Flag to enable training via self-play (as opposed to fixed opponent)", + ) + parser.add_argument( + "--cpu-players", + type=int, + default=1, + choices=[1, 2], + help="Number of CPU players (options: 1, 2)", + ) + parser.add_argument( + "--player", + type=int, + default=0, + choices=[0, 1], + help="Choose which player to play as: red = 0, yellow = 1", + ) + parser.add_argument( + "--record", + action="store_true", + help="Flag to save a recording of the game (game.gif)", + ) + parser.add_argument( + "--depth", + type=int, + default=2, + choices=[1, 2, 3], + help="Search depth for greedy agent. Options: 1,2,3", + ) parser.add_argument( "--win-rate", type=float, @@ -120,7 +145,9 @@ def watch() -> None: collector = ManualPolicyCollector(policy, env) - pettingzoo_env = env.workers[0].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env + pettingzoo_env = env.workers[ + 0 + ].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env if args.record: recorder = GIFRecorder() else: @@ -134,35 +161,50 @@ def watch() -> None: if collector.data.terminated or collector.data.truncated: rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()} [{type(policy.policies[agents[0]]).__name__}]") - print(f"Final reward: {rews[:, 1].mean()}, length: {lens.mean()} [{type(policy.policies[agents[1]]).__name__}]") + print( + f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()} [{type(policy.policies[agents[0]]).__name__}]" + ) + print( + f"Final reward: {rews[:, 1].mean()}, length: {lens.mean()} [{type(policy.policies[agents[1]]).__name__}]" + ) if recorder is not None: recorder.end_recording(pettingzoo_env.unwrapped.screen) recorder = None + # ======== allows the user to input moves and play vs a greedy agent ====== def play() -> None: env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)]) policy, agents = get_agents() - collector = ManualPolicyCollector(policy, env, exploration_noise=True) # Collector for CPU actions + collector = ManualPolicyCollector( + policy, env, exploration_noise=True + ) # Collector for CPU actions - pettingzoo_env = env.workers[0].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env + pettingzoo_env = env.workers[ + 0 + ].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env if args.record: recorder = GIFRecorder() else: recorder = None - manual_policy = gobblet_v1.ManualPolicy(env=pettingzoo_env, agent_id=args.player, recorder=recorder) # Gobblet keyboard input requires access to raw_env (uses functions from board) + manual_policy = gobblet_v1.ManualGobbletPolicy( + env=pettingzoo_env, agent_id=args.player, recorder=recorder + ) # Gobblet keyboard input requires access to raw_env (uses functions from board) while pettingzoo_env.agents: agent_id = collector.data.obs.agent_id # If it is the players turn and there are less than 2 CPU players (at least one human player) if agent_id == pettingzoo_env.agents[args.player]: - observation = {"observation": collector.data.obs.obs.flatten(), - "action_mask": collector.data.obs.mask.flatten()} # PettingZoo expects a dict with this format + observation = { + "observation": collector.data.obs.obs.flatten(), + "action_mask": collector.data.obs.mask.flatten(), + } # PettingZoo expects a dict with this format action = manual_policy(observation, agent_id) - result = collector.collect_result(action=action.reshape(1), render=args.render) + result = collector.collect_result( + action=action.reshape(1), render=args.render + ) else: result = collector.collect(n_step=1, render=args.render) if recorder is not None: @@ -170,11 +212,16 @@ def play() -> None: if collector.data.terminated or collector.data.truncated: rews, lens = result["rews"], result["lens"] - print(f"\nFinal reward: {rews[:, args.player].mean()}, length: {lens.mean()} [Human]") - print(f"Final reward: {rews[:, 1-args.player].mean()}, length: {lens.mean()} [{type(policy.policies[agents[1-args.player]]).__name__}]") + print( + f"\nFinal reward: {rews[:, args.player].mean()}, length: {lens.mean()} [Human]" + ) + print( + f"Final reward: {rews[:, 1-args.player].mean()}, length: {lens.mean()} [{type(policy.policies[agents[1-args.player]]).__name__}]" + ) if recorder is not None: recorder.end_recording(pettingzoo_env.unwrapped.screen) + if __name__ == "__main__": # train the agent and watch its performance in a match! args = get_args() @@ -182,4 +229,4 @@ def play() -> None: if args.cpu_players == 2: watch() else: - play() \ No newline at end of file + play() diff --git a/gobblet/examples/example_user_input.py b/gobblet/examples/example_user_input.py index 0ba7f7f..628b751 100644 --- a/gobblet/examples/example_user_input.py +++ b/gobblet/examples/example_user_input.py @@ -1,18 +1,30 @@ -import pygame -import numpy as np import argparse +import numpy as np +import pygame + def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument( - "--seed", type=int, default=None, help="Set random seed manually (will only affect CPU agents)" + "--seed", + type=int, + default=None, + help="Set random seed manually (will only affect CPU agents)", ) parser.add_argument( - "--cpu-players", type=int, default=1, choices=[0, 1, 2], help="Number of CPU players (options: 0, 1, 2)" + "--cpu-players", + type=int, + default=1, + choices=[0, 1, 2], + help="Number of CPU players (options: 0, 1, 2)", ) parser.add_argument( - "--player", type=int, default=0, choices=[0,1], help="Choose which player to play as: red = 0, yellow = 1" + "--player", + type=int, + default=0, + choices=[0, 1], + help="Choose which player to play as: red = 0, yellow = 1", ) parser.add_argument( "--screen-width", type=int, default=640, help="Width of pygame screen in pixels" @@ -38,7 +50,7 @@ def get_args() -> argparse.Namespace: env = gobblet_v1.env(render_mode="human", args=args) env.reset() - manual_policy = gobblet_v1.ManualPolicy(env) + manual_policy = gobblet_v1.ManualGobbletPolicy(env) for agent in env.agent_iter(): clock.tick(env.metadata["render_fps"]) @@ -53,9 +65,11 @@ def get_args() -> argparse.Namespace: if agent == manual_policy.agent and args.cpu_players < 2: action = manual_policy(observation, agent) else: - action_mask = observation['action_mask'] - action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask)) + action_mask = observation["action_mask"] + action = np.random.choice( + np.arange(len(action_mask)), p=action_mask / np.sum(action_mask) + ) env.step(action) - env.render() \ No newline at end of file + env.render() diff --git a/gobblet/game/board.py b/gobblet/game/board.py index c0ef6b8..abff10c 100644 --- a/gobblet/game/board.py +++ b/gobblet/game/board.py @@ -1,5 +1,6 @@ import numpy as np + class Board: def __init__(self, squares=None): # internally self.board.squares holds a representation of the gobblet board, consisting of three stacked 3x3 boards, one for each piece size. @@ -26,8 +27,6 @@ def __init__(self, squares=None): # [0, 0, 0], # [0, 0, 6] - - # empty -- 0 # player 0 -- 1 # player 1 -- -1 # Default: 2 @@ -41,7 +40,7 @@ def setup(self): self.calculate_winners() def get_action_from_pos_piece(self, pos, piece): - if pos in range(9) and piece in range(1,7): + if pos in range(9) and piece in range(1, 7): return 9 * (piece - 1) + pos else: return -1 @@ -76,29 +75,30 @@ def get_piece_size_from_action(self, action): # Returns the index on the board [0-26] for a given action def get_index_from_action(self, action): pos = self.get_pos_from_action(action) # [0-8] - piece_size = self.get_piece_size_from_action(action) # [1-3] - return pos + 9 * (piece_size - 1) # [0-26] - + piece_size = self.get_piece_size_from_action(action) # [1-3] + return pos + 9 * (piece_size - 1) # [0-26] # Return true if an action is legal, false otherwise. def is_legal(self, action, agent_index=0): - pos = self.get_pos_from_action(action) # [0-8] - piece = self.get_piece_from_action(action) # [1-6] - piece_size = self.get_piece_size_from_action(action) # [1-3] + pos = self.get_pos_from_action(action) # [0-8] + piece = self.get_piece_from_action(action) # [1-6] + piece_size = self.get_piece_size_from_action(action) # [1-3] agent_multiplier = 1 if agent_index == 0 else -1 board = self.squares.reshape(3, 9) # Check if this piece has been placed (if the piece number occurs anywhere on the level of that piece size) - if any(board[piece_size-1] == piece * agent_multiplier): - current_loc = np.where(board[piece_size-1] == piece * agent_multiplier)[0] # Returns array of values where piece is placed + if any(board[piece_size - 1] == piece * agent_multiplier): + current_loc = np.where(board[piece_size - 1] == piece * agent_multiplier)[ + 0 + ] # Returns array of values where piece is placed if len(current_loc) > 1: raise Exception("PIECE HAS BEEN USED TWICE") - return False # Piece has been used twice (not valid) + return False # Piece has been used twice (not valid) else: - current_loc = current_loc[0] # Current location [0-27] + current_loc = current_loc[0] # Current location [0-27] # If this piece is currently covered, moving it is not a legal action - if self.check_covered().reshape(3,9)[piece_size - 1][current_loc] == 1: + if self.check_covered().reshape(3, 9)[piece_size - 1][current_loc] == 1: return False # If this piece has not been placed @@ -107,8 +107,8 @@ def is_legal(self, action, agent_index=0): if flatboard[pos] == 0: return True else: - existing_piece_number = flatboard[pos] # [1-6] - existing_piece_size = (abs(existing_piece_number) + 1) // 2 # [1-3] + existing_piece_number = flatboard[pos] # [1-6] + existing_piece_size = (abs(existing_piece_number) + 1) // 2 # [1-3] if piece_size > existing_piece_size: return True else: @@ -153,20 +153,27 @@ def calculate_winners(self): self.winning_combinations = winning_combinations def print(self): - print(self.get_flatboard().reshape(3,3).transpose()) + print(self.get_flatboard().reshape(3, 3).transpose()) # returns flattened board consisting of only top pieces (excluding pieces which are gobbled by other pieces) def get_flatboard(self): flatboard = np.zeros(9) board = self.squares.reshape(3, 9) - for i in range(9): # For every square in the 3x3 grid, find the topmost element (largest piece) - top_piece_size = (np.amax( - abs(board[:, i]))) # [-3, 2, 0] denotes a large piece gobbling a medium piece, this will return 3 + for i in range( + 9 + ): # For every square in the 3x3 grid, find the topmost element (largest piece) + top_piece_size = np.amax( + abs(board[:, i]) + ) # [-3, 2, 0] denotes a large piece gobbling a medium piece, this will return 3 top_piece_index = list(abs(board[:, i])).index( - top_piece_size) # Get the row of the top piece (have to index [0] - top_piece_color = np.sign(board[ - top_piece_index, i]) # Get the color of the top piece: 1 for player_1 and -1 for player_2, 0 for neither - flatboard[i] = top_piece_color * top_piece_size # Simplify the board into only the top elements + top_piece_size + ) # Get the row of the top piece (have to index [0] + top_piece_color = np.sign( + board[top_piece_index, i] + ) # Get the color of the top piece: 1 for player_1 and -1 for player_2, 0 for neither + flatboard[i] = ( + top_piece_color * top_piece_size + ) # Simplify the board into only the top elements return flatboard # returns: @@ -193,26 +200,38 @@ def check_game_over(self): else: return False - def check_covered(self): # Return a 27 length array indicating which positions have a piece which is covered + def check_covered( + self, + ): # Return a 27 length array indicating which positions have a piece which is covered board = self.squares.reshape(3, 9) covered = np.zeros((3, 9)) - for i in range(9): # Check small pieces - if board[0, i] != 0 and (board[1, i] != 0 or board[2, i] != 0): # If there is a small piece, and either a large or medium piece covering it + for i in range(9): # Check small pieces + if board[0, i] != 0 and ( + board[1, i] != 0 or board[2, i] != 0 + ): # If there is a small piece, and either a large or medium piece covering it covered[0, i] = 1 - for i in range(9): # Check medium pieces - if board[1, i] != 0 and board[2, i] != 0: # If there is a meidum piece and a large piece covering it + for i in range(9): # Check medium pieces + if ( + board[1, i] != 0 and board[2, i] != 0 + ): # If there is a meidum piece and a large piece covering it covered[1, i] = 1 - covered[2, :] = 0 # Large pieces can't be covered + covered[2, :] = 0 # Large pieces can't be covered # Doesn't matter about what color is covering it, you can't move that piece this turn (allows self-gobbling) return covered.flatten() -# DEBUG + # DEBUG def print_pieces(self): open_indices = [i for i in range(len(self.squares)) if self.squares[i] == 0] open_squares = [np.where(self.get_flatboard() == 0)[0]] - occupied_squares = [i % 9 for i in range(len(self.squares)) if self.squares[i] != 0] # List with entries 0-9 - movable_squares = [i % 9 for i in occupied_squares if self.check_covered()[i] == 0] # List with entries 0-9 - covered_squares = [i % 9 for i in np.where(self.check_covered() == 1)[0] ] # List with entries 0-9 + occupied_squares = [ + i % 9 for i in range(len(self.squares)) if self.squares[i] != 0 + ] # List with entries 0-9 + movable_squares = [ + i % 9 for i in occupied_squares if self.check_covered()[i] == 0 + ] # List with entries 0-9 + covered_squares = [ + i % 9 for i in np.where(self.check_covered() == 1)[0] + ] # List with entries 0-9 print("open_indices: ", open_indices) print("open_squares: ", open_squares) print("squares with pieces: ", occupied_squares) @@ -220,4 +239,4 @@ def print_pieces(self): print("squares with covered pieces: ", covered_squares) def __str__(self): - return str(self.squares.reshape(3,3,3)) + return str(self.squares.reshape(3, 3, 3)) diff --git a/gobblet/game/collector_manual_policy.py b/gobblet/game/collector_manual_policy.py index 1dc8b7e..ad8983a 100644 --- a/gobblet/game/collector_manual_policy.py +++ b/gobblet/game/collector_manual_policy.py @@ -1,20 +1,13 @@ # Extending tianshou collector class to work with manual policy (for user input) import time -import warnings -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import gym import numpy as np -import torch - -from tianshou.data import ( - Batch, - ReplayBuffer, -) -from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.policy import BasePolicy - +from tianshou.data import Batch, ReplayBuffer from tianshou.data.collector import Collector +from tianshou.env import BaseVectorEnv +from tianshou.policy import BasePolicy class ManualPolicyCollector(Collector): @@ -26,15 +19,15 @@ def __init__( preprocess_fn: Optional[Callable[..., Batch]] = None, exploration_noise: bool = False, ) -> None: - super(ManualPolicyCollector, self).__init__(policy=policy, env=env, exploration_noise=exploration_noise) + super().__init__(policy=policy, env=env, exploration_noise=exploration_noise) # Custom function to collect the result of an inputted action def collect_result( - self, - action: int = None, - render: Optional[float] = None, - no_grad: bool = True, - gym_reset_kwargs: Optional[Dict[str, Any]] = None, + self, + action: int = None, + render: Optional[float] = None, + no_grad: bool = True, + gym_reset_kwargs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Collect the results of an inputted action.. @@ -73,7 +66,7 @@ def collect_result( while True: assert len(self.data) == len(ready_env_ids) # restore the state: if the last state is None, it won't store - last_state = self.data.policy.pop("hidden_state", None) + last_state = self.data.policy.pop("hidden_state", None) # noqa: F841 # use hard coded action (rather than using a policy or randomly sampling) self.data.update(act=action) @@ -109,7 +102,7 @@ def collect_result( terminated=terminated, truncated=truncated, done=done, - info=info + info=info, ) if self.preprocess_fn: self.data.update( @@ -154,10 +147,9 @@ def collect_result( # remove surplus env id from ready_env_ids # to avoid bias in selecting environments - self.data.obs = self.data.obs_next - if (n_step and step_count >= n_step): + if n_step and step_count >= n_step: break # generate statistics @@ -167,10 +159,7 @@ def collect_result( if episode_count > 0: rews, lens, idxs = list( - map( - np.concatenate, - [episode_rews, episode_lens, episode_start_indices] - ) + map(np.concatenate, [episode_rews, episode_lens, episode_start_indices]) ) rew_mean, rew_std = rews.mean(), rews.std() len_mean, len_std = lens.mean(), lens.std() @@ -188,4 +177,4 @@ def collect_result( "len": len_mean, "rew_std": rew_std, "len_std": len_std, - } \ No newline at end of file + } diff --git a/gobblet/game/gobblet.py b/gobblet/game/gobblet.py index edc9efc..d2aaad4 100644 --- a/gobblet/game/gobblet.py +++ b/gobblet/game/gobblet.py @@ -93,15 +93,16 @@ """ +import os + import gymnasium import numpy as np -from gymnasium import spaces import pygame -import os - +from gymnasium import spaces from pettingzoo import AECEnv from pettingzoo.utils import agent_selector, wrappers from pettingzoo.utils.conversions import parallel_wrapper_fn + from .board import Board from .utils import get_image, load_chip, load_chip_preview @@ -115,8 +116,10 @@ def env(render_mode=None, args=None): env = wrappers.OrderEnforcingWrapper(env) return env + parallel_env = parallel_wrapper_fn(env) + class raw_env(AECEnv): metadata = { "render_modes": ["human", "rgb_array", "text", "text_full"], @@ -129,7 +132,7 @@ class raw_env(AECEnv): def __init__(self, render_mode=None, args=None): super().__init__() self.board = Board() - self.board_size = 3 # Will need to make a separate file for 4x4 + self.board_size = 3 # Will need to make a separate file for 4x4 self.agents = ["player_1", "player_2"] self.possible_agents = self.agents[:] @@ -141,7 +144,9 @@ def __init__(self, render_mode=None, args=None): "observation": spaces.Box( low=0, high=1, shape=(3, 3, 13), dtype=np.int8 ), - "action_mask": spaces.Box(low=0, high=1, shape=(54,), dtype=np.int8), + "action_mask": spaces.Box( + low=0, high=1, shape=(54,), dtype=np.int8 + ), } ) for i in self.agents @@ -175,22 +180,30 @@ def observe(self, agent): board = self.board.squares.reshape(3, 3, 3) if self.agents.index(agent) == 1: - board = board * -1 # Swap the signs on the board for the two different agents + board = ( + board * -1 + ) # Swap the signs on the board for the two different agents # Represent observations in the same way as pettingzoo.chess: specific channel for each color piece (e.g., two for each white small piece) layers = [] for i in range(1, 7): - layers.append(board[(i - 1) // 2] == i) # 3x3 array with an entry of 1 for squares with each white piece (1, ..., 6) + layers.append( + board[(i - 1) // 2] == i + ) # 3x3 array with an entry of 1 for squares with each white piece (1, ..., 6) for i in range(1, 7): - layers.append(board[(i - 1) // 2] == -i) # 3x3 array with an entry of 1 for squares with each black piece (-1, ..., -6) + layers.append( + board[(i - 1) // 2] == -i + ) # 3x3 array with an entry of 1 for squares with each black piece (-1, ..., -6) if self.agents.index(agent) == 1: agents_layer = np.ones((3, 3)) else: agents_layer = np.zeros((3, 3)) - layers.append(agents_layer) # Thirteenth layer encoding the current player (agent index 0 or 1) + layers.append( + agents_layer + ) # Thirteenth layer encoding the current player (agent index 0 or 1) observation = np.stack(layers, axis=2).astype(np.int8) legal_moves = self._legal_moves() if agent == self.agent_selection else [] @@ -234,7 +247,7 @@ def step(self, action): if self.board.check_game_over(): winner = self.board.check_for_winner() - if winner == 0: # NOTE: don't think ties are possible in gobblet + if winner == 0: # NOTE: don't think ties are possible in gobblet # tie pass elif winner == 1: @@ -259,7 +272,6 @@ def step(self, action): if self.render_mode in ["human", "text", "text_full", "rgb_array"]: self.render() - def reset(self, seed=None, return_info=False, options=None): # reset environment self.board = Board() @@ -288,17 +300,17 @@ def getSymbolFull(input): if input == 0: return "- " if input > 0: - return "+{}".format(int(input)) + return f"+{int(input)}" else: - return "{}".format(int(input)) + return f"{int(input)}" def getSymbol(input): if input == 0: return "- " if input > 0: - return "+{}".format(int((input + 1) // 2)) + return f"+{int((input + 1) // 2)}" else: - return "{}".format(int((input) // 2)) + return f"{int((input) // 2)}" if self.debug: self.board.print_pieces() @@ -306,48 +318,111 @@ def getSymbol(input): pos = self.action % 9 piece = (self.action // 9) + 1 piece = (piece + 1) // 2 - print(f"TURN: {self.turn}, AGENT: {self.agent_selection}, ACTION: {self.action}, POSITION: {pos}, PIECE: {piece}") + print( + f"TURN: {self.turn}, AGENT: {self.agent_selection}, ACTION: {self.action}, POSITION: {pos}, PIECE: {piece}" + ) board = list(map(getSymbol, self.board.get_flatboard())) print(" " * 7 + "|" + " " * 7 + "|" + " " * 7) - print(f" {board[0]} " + "|" + f" {board[3]} " + "|" + f" {board[6]} ") + print( + f" {board[0]} " + "|" + f" {board[3]} " + "|" + f" {board[6]} " + ) print("_" * 7 + "|" + "_" * 7 + "|" + "_" * 7) print(" " * 7 + "|" + " " * 7 + "|" + " " * 7) - print(f" {board[1]} " + "|" + f" {board[4]} " + "|" + f" {board[7]} ") + print( + f" {board[1]} " + "|" + f" {board[4]} " + "|" + f" {board[7]} " + ) print("_" * 7 + "|" + "_" * 7 + "|" + "_" * 7) print(" " * 7 + "|" + " " * 7 + "|" + " " * 7) - print(f" {board[2]} " + "|" + f" {board[5]} " + "|" + f" {board[8]} ") + print( + f" {board[2]} " + "|" + f" {board[5]} " + "|" + f" {board[8]} " + ) print(" " * 7 + "|" + " " * 7 + "|" + " " * 7) print() elif self.render_mode == "text_full": pos = self.action % 9 piece = (self.action // 9) + 1 - print(f"TURN: {self.turn}, AGENT: {self.agent_selection}, ACTION: {self.action}, POSITION: {pos}, PIECE: {piece}") + print( + f"TURN: {self.turn}, AGENT: {self.agent_selection}, ACTION: {self.action}, POSITION: {pos}, PIECE: {piece}" + ) board = list(map(getSymbolFull, self.board.squares)) - print(" " * 9 + "SMALL" + " " * 9 + " " + - " " * 10 + "MED" + " " * 10 + " " + - " " * 9 + "LARGE" + " " * 9 + " ") - top= " " * 7 + "|" + " " * 7 + "|" + " " * 7 + print( + " " * 9 + + "SMALL" + + " " * 9 + + " " + + " " * 10 + + "MED" + + " " * 10 + + " " + + " " * 9 + + "LARGE" + + " " * 9 + + " " + ) + top = " " * 7 + "|" + " " * 7 + "|" + " " * 7 bottom = "_" * 7 + "|" + "_" * 7 + "|" + "_" * 7 - top1= f" {board[0]} " + "|" + f" {board[3]} " + "|" + f" {board[6]} " - top2 = f" {board[9]} " + "|" + f" {board[12]} " + "|" + f" {board[15]} " - top3 = f" {board[18]} " + "|" + f" {board[21]} " + "|" + f" {board[24]} " + top1 = ( + f" {board[0]} " + "|" + f" {board[3]} " + "|" + f" {board[6]} " + ) + top2 = ( + f" {board[9]} " + + "|" + + f" {board[12]} " + + "|" + + f" {board[15]} " + ) + top3 = ( + f" {board[18]} " + + "|" + + f" {board[21]} " + + "|" + + f" {board[24]} " + ) print(top + " " + top + " " + top) print(top1 + " " + top2 + " " + top3) print(bottom + " " + bottom + " " + bottom) - mid1 = f" {board[1]} " + "|" + f" {board[4]} " + "|" + f" {board[7]} " - mid2 = f" {board[10]} " + "|" + f" {board[13]} " + "|" + f" {board[16]} " - mid3 = f" {board[19]} " + "|" + f" {board[22]} " + "|" + f" {board[25]} " + mid1 = ( + f" {board[1]} " + "|" + f" {board[4]} " + "|" + f" {board[7]} " + ) + mid2 = ( + f" {board[10]} " + + "|" + + f" {board[13]} " + + "|" + + f" {board[16]} " + ) + mid3 = ( + f" {board[19]} " + + "|" + + f" {board[22]} " + + "|" + + f" {board[25]} " + ) print(top + " " + top + " " + top) print(mid1 + " " + mid2 + " " + mid3) print(bottom + " " + bottom + " " + bottom) - bot1 = f" {board[2]} " + "|" + f" {board[5]} " + "|" + f" {board[8]} " - bot2 = f" {board[9+2]} " + "|" + f" {board[9+5]} " + "|" + f" {board[9+8]} " - bot3 = f" {board[18+2]} " + "|" + f" {board[18+5]} " + "|" + f" {board[18+8]} " + bot1 = ( + f" {board[2]} " + "|" + f" {board[5]} " + "|" + f" {board[8]} " + ) + bot2 = ( + f" {board[9+2]} " + + "|" + + f" {board[9+5]} " + + "|" + + f" {board[9+8]} " + ) + bot3 = ( + f" {board[18+2]} " + + "|" + + f" {board[18+5]} " + + "|" + + f" {board[18+8]} " + ) print(top + " " + top + " " + top) print(bot1 + " " + bot2 + " " + bot3) print(top + " " + top + " " + top) @@ -358,7 +433,9 @@ def getSymbol(input): if self.render_mode == "human": if self.screen is None: pygame.init() - self.screen = pygame.display.set_mode((self.screen_width, self.screen_height)) + self.screen = pygame.display.set_mode( + (self.screen_width, self.screen_height) + ) pygame.event.get() elif self.screen is None: self.screen = pygame.Surface((self.screen_width, self.screen_height)) @@ -381,14 +458,26 @@ def getSymbol(input): self.preview = {} self.preview["player_1"] = {} - self.preview["player_1"][3] = load_chip_preview(tile_size, "GobbletLargeRedPreview.png", scale_large) - self.preview["player_1"][2] = load_chip_preview(tile_size, "GobbletMedRedPreview.png", scale_med) - self.preview["player_1"][1] = load_chip_preview(tile_size, "GobbletSmallRedPreview.png", scale_small) + self.preview["player_1"][3] = load_chip_preview( + tile_size, "GobbletLargeRedPreview.png", scale_large + ) + self.preview["player_1"][2] = load_chip_preview( + tile_size, "GobbletMedRedPreview.png", scale_med + ) + self.preview["player_1"][1] = load_chip_preview( + tile_size, "GobbletSmallRedPreview.png", scale_small + ) self.preview["player_2"] = {} - self.preview["player_2"][3] = load_chip_preview(tile_size, "GobbletLargeYellowPreview.png", scale_large) - self.preview["player_2"][2] = load_chip_preview(tile_size, "GobbletMedYellowPreview.png", scale_med) - self.preview["player_2"][1] = load_chip_preview(tile_size, "GobbletSmallYellowPreview.png", scale_small) + self.preview["player_2"][3] = load_chip_preview( + tile_size, "GobbletLargeYellowPreview.png", scale_large + ) + self.preview["player_2"][2] = load_chip_preview( + tile_size, "GobbletMedYellowPreview.png", scale_med + ) + self.preview["player_2"][1] = load_chip_preview( + tile_size, "GobbletSmallYellowPreview.png", scale_small + ) # preview_chips = {self.agents[0]: self.preview["player_1"], self.agents[1]: self.preview["player_1"]} @@ -399,34 +488,50 @@ def getSymbol(input): self.screen.blit(board_img, (0, 0)) - offset = (self.screen_width * ((9+4) / 47)) # Piece is 9px wide, gap between pieces 4px, total width is 47px - offset_side = (self.screen_width * (6 / 47)) - 1 # Distance from the side of the board to the first piece is 6px - offset_centering = offset * 1/3 + \ - (5 * self.screen_width/1000 if self.screen_width > 500 else 8 * self.screen_width/1000) + offset = self.screen_width * ( + (9 + 4) / 47 + ) # Piece is 9px wide, gap between pieces 4px, total width is 47px + offset_side = ( + self.screen_width * (6 / 47) + ) - 1 # Distance from the side of the board to the first piece is 6px + offset_centering = offset * 1 / 3 + ( + 5 * self.screen_width / 1000 + if self.screen_width > 500 + else 8 * self.screen_width / 1000 + ) # Extra 5px fixed alignment issues at 1000x1000, but ratio needs to be higher for lower res (trial & error) # Blit the chips and their positions for i in range(9): for j in range(1, 4): - if self.board.squares[i + 9 * (j - 1)] == 2 * j - 1 or \ - self.board.squares[i + 9 * (j - 1)] == 2 * j: # small pieces (1,2), medium pieces (3,4), large pieces (5,6) + if ( + self.board.squares[i + 9 * (j - 1)] == 2 * j - 1 + or self.board.squares[i + 9 * (j - 1)] == 2 * j + ): # small pieces (1,2), medium pieces (3,4), large pieces (5,6) self.screen.blit( red[j], - red[j].get_rect(center = - (int(i / 3) * (offset) + offset_side + offset_centering, - (i % 3) * (offset) + offset_side + offset_centering - ) - ) + red[j].get_rect( + center=( + int(i / 3) * (offset) + + offset_side + + offset_centering, + (i % 3) * (offset) + offset_side + offset_centering, + ) + ), ) - if self.board.squares[i + 9 * (j - 1)] == -1 * (2 * j - 1) or\ - self.board.squares[i + 9 * (j - 1)] == -1 * (2 * j): + if self.board.squares[i + 9 * (j - 1)] == -1 * ( + 2 * j - 1 + ) or self.board.squares[i + 9 * (j - 1)] == -1 * (2 * j): self.screen.blit( yellow[j], - yellow[j].get_rect(center = - (int(i / 3) * (offset) + offset_side + offset_centering, - (i % 3) * (offset) + offset_side + offset_centering - ) - ) + yellow[j].get_rect( + center=( + int(i / 3) * (offset) + + offset_side + + offset_centering, + (i % 3) * (offset) + offset_side + offset_centering, + ) + ), ) # Blit the preview chips and their positions @@ -435,20 +540,26 @@ def getSymbol(input): if self.board.squares_preview[i + 9 * (j - 1)] == 1: self.screen.blit( self.preview["player_1"][j], - self.preview["player_1"][j].get_rect(center = - (int(i / 3) * (offset) + offset_side + offset_centering, - (i % 3) * (offset) + offset_side + offset_centering - ) - ) + self.preview["player_1"][j].get_rect( + center=( + int(i / 3) * (offset) + + offset_side + + offset_centering, + (i % 3) * (offset) + offset_side + offset_centering, + ) + ), ) if self.board.squares_preview[i + 9 * (j - 1)] == -1: self.screen.blit( self.preview["player_2"][j], - self.preview["player_2"][j].get_rect(center = - (int(i / 3) * (offset) + offset_side + offset_centering, - (i % 3) * (offset) + offset_side + offset_centering - ) - ) + self.preview["player_2"][j].get_rect( + center=( + int(i / 3) * (offset) + + offset_side + + offset_centering, + (i % 3) * (offset) + offset_side + offset_centering, + ) + ), ) pygame.display.update() diff --git a/gobblet/game/greedy_policy.py b/gobblet/game/greedy_policy.py index acbf540..d4ef434 100644 --- a/gobblet/game/greedy_policy.py +++ b/gobblet/game/greedy_policy.py @@ -1,14 +1,16 @@ -from typing import Optional, Any +from typing import Any, Optional import numpy as np + from gobblet.game.board import Board + class GreedyGobbletPolicy: def __init__( - self, - depth: Optional[int] = 2, - seed: Optional[int] = 0, - **kwargs: Any, + self, + depth: Optional[int] = 2, + seed: Optional[int] = 0, + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.board = None @@ -17,7 +19,9 @@ def __init__( def compute_actions_rllib(self, obs_batch): observations = obs_batch["observation"] - observations = observations.reshape(observations.shape[0], 3, 3, -1) # Infer observation dimension and batch size + observations = observations.reshape( + observations.shape[0], 3, 3, -1 + ) # Infer observation dimension and batch size masks = obs_batch["action_mask"] actions = [] for i in range(len(observations)): @@ -25,30 +29,38 @@ def compute_actions_rllib(self, obs_batch): actions.append(act) return actions - def compute_action_tianshou( - self, - obs - ): + def compute_action_tianshou(self, obs): mask = obs.mask obs = obs.obs if hasattr(obs, "obs") else obs # return self.compute_action(obs, mask) def compute_action( - self, - obs, - mask, + self, + obs, + mask, ) -> np.ndarray: - obs_player = obs[..., :6] # Index the last dimension - board_player = np.array([(i + 1) * obs_player[..., i] + (i + 2) * obs_player[..., i + 1] for i in - range(0, 6, 2)]) # Reshapes [3,3,6] to [3,3,3] + board_player = np.array( + [ + (i + 1) * obs_player[..., i] + (i + 2) * obs_player[..., i + 1] + for i in range(0, 6, 2) + ] + ) # Reshapes [3,3,6] to [3,3,3] obs_opponent = obs[..., 6:12] board_opponent = np.array( - [(i + 1) * obs_opponent[..., i] + (i + 2) * obs_opponent[..., i + 1] for i in range(0, 6, 2)]) + [ + (i + 1) * obs_opponent[..., i] + (i + 2) * obs_opponent[..., i + 1] + for i in range(0, 6, 2) + ] + ) board = np.where(board_player > board_opponent, board_player, -board_opponent) - agent_index = obs[..., 12].max() # Thirteenth layer of obs encodes agent_index (all zeros or all ones) - opponent_index = 1 - agent_index # If agent index is 1, we want 0; if agent index is 0, we want 1 + agent_index = obs[ + ..., 12 + ].max() # Thirteenth layer of obs encodes agent_index (all zeros or all ones) + opponent_index = ( + 1 - agent_index + ) # If agent index is 1, we want 0; if agent index is 0, we want 1 # If we are playing as the second agent, we need to adjust the representation of the board to reflect that if agent_index == 1: @@ -61,7 +73,9 @@ def compute_action( winner_values = [1, -1] legal_actions = mask.flatten().nonzero()[0] - actions_depth1 = list(legal_actions) # Initialize the same as legal actions, then remove ones that cause a loss + actions_depth1 = list( + legal_actions + ) # Initialize the same as legal actions, then remove ones that cause a loss chosen_action = None results = {} @@ -77,7 +91,9 @@ def compute_action( if results[action] == winner_values[agent_index]: # Win for our agent chosen_action = action break - elif results[action] == winner_values[opponent_index]: # Loss for our agent + elif ( + results[action] == winner_values[opponent_index] + ): # Loss for our agent if len(actions_depth1) > 1: actions_depth1.remove(action) else: @@ -92,20 +108,26 @@ def compute_action( depth1_board.play_turn(agent_index=agent_index, action=action) # Check what the opponent can do after this potential move - legal_actions_depth2 = [act for act in range(len(mask.flatten())) if - depth1_board.is_legal(agent_index=opponent_index, action=act)] + legal_actions_depth2 = [ + act + for act in range(len(mask.flatten())) + if depth1_board.is_legal(agent_index=opponent_index, action=act) + ] results_depth2 = {} for action_depth2 in legal_actions_depth2: depth2_board = Board() depth2_board.squares = depth1_board.squares.copy() - depth2_board.play_turn(agent_index=opponent_index, action=action_depth2) + depth2_board.play_turn( + agent_index=opponent_index, action=action_depth2 + ) results_depth2[action_depth2] = depth2_board.check_for_winner() # If the opponent can win in the next move, we don't want to do this action - if results_depth2[action_depth2] == winner_values[ - opponent_index]: # Check if it's possible for the opponent to win next turn after + if ( + results_depth2[action_depth2] == winner_values[opponent_index] + ): # Check if it's possible for the opponent to win next turn after if len(actions_depth1) > 1: if action in actions_depth1: actions_depth1.remove(action) @@ -116,15 +138,21 @@ def compute_action( # BUT this might miss us from winning the game ourselves in blcoking them, so don't exit the loop # So we only do it if there aren't deterministic wins already chosen (and we don't break, keep looking after) if self.board.is_legal(action_depth2, agent_index=agent_index): - if chosen_action == None: + if chosen_action is None: chosen_action = action_depth2 # Depth 2: if this move sets the opponent up with only moves that win the game for us, pick this move - if all(winner == winner_values[agent_index] for winner in results_depth2.values()): + if all( + winner == winner_values[agent_index] + for winner in results_depth2.values() + ): chosen_action = action break # Depth 2: if this move blocks the opponent from winning, pick this if we cannot find any guaranteed wins - if all(winner != winner_values[opponent_index] for winner in results_depth2.values()): + if all( + winner != winner_values[opponent_index] + for winner in results_depth2.values() + ): chosen_action = action # BLOCKING ACTION # Depth 3: Given that we block the opponent this way, can we win the next turn? @@ -134,13 +162,22 @@ def compute_action( depth1_board.play_turn(agent_index=agent_index, action=action) # Search over depth 2 actions where the opponent doesn't win - for action_depth2 in [key for key, value in results_depth2.items() if value == 0]: + for action_depth2 in [ + key for key, value in results_depth2.items() if value == 0 + ]: depth2_board = Board() depth2_board.squares = depth1_board.squares.copy() - depth2_board.play_turn(agent_index=agent_index, action=action_depth2) - - legal_actions_depth3 = [act for act in range(len(mask.flatten())) if - depth2_board.is_legal(agent_index=agent_index, action=act)] + depth2_board.play_turn( + agent_index=agent_index, action=action_depth2 + ) + + legal_actions_depth3 = [ + act + for act in range(len(mask.flatten())) + if depth2_board.is_legal( + agent_index=agent_index, action=act + ) + ] actions_depth3 = list(legal_actions_depth3) results_depth3 = {} @@ -148,15 +185,24 @@ def compute_action( for action_depth3 in legal_actions_depth3: depth3_board = Board() depth3_board.squares = depth2_board.squares.copy() - depth3_board.play_turn(agent_index=agent_index, action=action) + depth3_board.play_turn( + agent_index=agent_index, action=action + ) # If we can win next turn, do it - results_depth3[action_depth3] = depth3_board.check_for_winner() - if results_depth3[action_depth3] == winner_values[agent_index]: # Win for our agent + results_depth3[ + action_depth3 + ] = depth3_board.check_for_winner() + if ( + results_depth3[action_depth3] + == winner_values[agent_index] + ): # Win for our agent chosen_action = action # If we can win in depth 3, then we know this blocking action is good break - elif results_depth3[action_depth3] == winner_values[ - opponent_index]: # Loss for our agent + elif ( + results_depth3[action_depth3] + == winner_values[opponent_index] + ): # Loss for our agent if len(actions_depth3) > 1: actions_depth3.remove(action) else: diff --git a/gobblet/game/greedy_policy_rllib.py b/gobblet/game/greedy_policy_rllib.py index 6d57f05..55fa865 100644 --- a/gobblet/game/greedy_policy_rllib.py +++ b/gobblet/game/greedy_policy_rllib.py @@ -1,15 +1,11 @@ -from typing import ( - List, - Optional, - Union, -) +from typing import List, Optional, Union +import numpy as np +from ray.rllib.examples.policy.random_policy import RandomPolicy from ray.rllib.utils.annotations import override from ray.rllib.utils.typing import TensorStructType, TensorType -from ray.rllib.examples.policy.random_policy import RandomPolicy from gobblet.game.greedy_policy import GreedyGobbletPolicy -import numpy as np class GreedyPolicy(RandomPolicy): @@ -19,12 +15,12 @@ def __init__(self, *args, **kwargs): @override(RandomPolicy) def compute_actions( - self, - obs_batch: Union[List[TensorStructType], TensorStructType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, - prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, - **kwargs + self, + obs_batch: Union[List[TensorStructType], TensorStructType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, + prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, + **kwargs, ): actions = self.policy.compute_actions_rllib(obs_batch) return ( diff --git a/gobblet/game/greedy_policy_tianshou.py b/gobblet/game/greedy_policy_tianshou.py index 062a7c4..4f97db9 100644 --- a/gobblet/game/greedy_policy_tianshou.py +++ b/gobblet/game/greedy_policy_tianshou.py @@ -1,19 +1,22 @@ -from typing import Any, Dict, Optional, Union from copy import deepcopy +from typing import Any, Dict, Optional, Union + import numpy as np from numpy import ndarray - -from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as +from tianshou.data import Batch from tianshou.policy import BasePolicy from gobblet.game.greedy_policy import GreedyGobbletPolicy + + class GreedyPolicy(BasePolicy): - """ + """Greedy Policy. + Basic greedy policy which checks if a move results in a victory, and if it sets the opponent up to win (or lose) in the next turn. The depth argument controls the agent's search depth (default of 2 is a balance between computational efficiency and optimal play) * depth = 1: Agent considers moves which it can use to directly win * depth = 2: Agent also considers moves it can take to block the opponent from winning next turn - * depth = 3: Agent also considers moves which set it up to win in two moves: no matter waht opponents does in retaliation (unblockable wins) + * depth = 3: Agent also considers moves which set it up to win in two moves: no matter what opponents does in retaliation (unblockable wins) """ def __init__( @@ -65,11 +68,15 @@ def forward( if len(batch_single[input].agent_id) > 1: batch_single[input].obs = batch_single[input].obs[b, ...] batch_single[input].mask = batch_single[input].mask[b, ...] - batch_single[input].agent_id = batch_single[input].agent_id[b, ...].reshape(1) + batch_single[input].agent_id = ( + batch_single[input].agent_id[b, ...].reshape(1) + ) act = self.forward_unbatched(batch=batch_single) # array(1) - acts.append(act) # [ array(1), array(2), ... array(10) ] + acts.append(act) # [ array(1), array(2), ... array(10) ] acts = np.array(acts) - return Batch(act=acts) # Batch( act: [ array(1), array(2), ... array(10) ] ) + return Batch( + act=acts + ) # Batch( act: [ array(1), array(2), ... array(10) ] ) else: batch_single = deepcopy(full_batch) act = self.forward_unbatched(batch=batch_single) @@ -88,4 +95,4 @@ def forward_unbatched( def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: """Since a random agent learns nothing, it returns an empty dict.""" - return {} \ No newline at end of file + return {} diff --git a/gobblet/game/manual_policy.py b/gobblet/game/manual_policy.py index 354d92a..adbcea9 100644 --- a/gobblet/game/manual_policy.py +++ b/gobblet/game/manual_policy.py @@ -1,13 +1,14 @@ # Adapted from https://github.com/Farama-Foundation/PettingZoo/blob/master/pettingzoo/butterfly/knights_archers_zombies/manual_policy.py -import pygame -from .utils import GIFRecorder import sys + import numpy as np +import pygame +from .utils import GIFRecorder -class ManualPolicy: - def __init__(self, env, agent_id: int = 0, recorder: GIFRecorder = None): +class ManualGobbletPolicy: + def __init__(self, env, agent_id: int = 0, recorder: GIFRecorder = None): self.env = env self.agent_id = agent_id self.agent = self.env.agents[self.agent_id] @@ -34,7 +35,7 @@ def __call__(self, observation, agent): pygame.display.quit() sys.exit() - ''' GET MOUSE INPUT''' + """ GET MOUSE INPUT""" mousex, mousey = pygame.mouse.get_pos() width, height = pygame.display.get_surface().get_size() pos_x = 0 @@ -55,45 +56,57 @@ def __call__(self, observation, agent): agent_multiplier = 1 if agent == env.agents[0] else -1 - ''' FIND PLACED PIECES ''' - placed_pieces = env.unwrapped.board.squares[env.unwrapped.board.squares.nonzero()] - placed_pieces_agent = [a for a in placed_pieces if np.sign(a) == agent_multiplier] + """ FIND PLACED PIECES """ + placed_pieces = env.unwrapped.board.squares[ + env.unwrapped.board.squares.nonzero() + ] + placed_pieces_agent = [ + a for a in placed_pieces if np.sign(a) == agent_multiplier + ] placed_pieces_agent_abs = [abs(p) for p in placed_pieces_agent] pieces = np.arange(1, 7) unplaced = [p for p in pieces if p not in placed_pieces_agent_abs] flatboard = env.unwrapped.board.get_flatboard() # Selected piece size persists as we loop through events, allowing the user to cycle through sizes - if piece_size_selected > 0: - piece = piece - else: + if piece_size_selected == 0: if len(unplaced) > 0: piece = unplaced[-1] piece_size_selected = (piece + 1) // 2 else: piece = -1 - ''' READ KEYBOARD INPUT''' + """ READ KEYBOARD INPUT""" if event.type == pygame.KEYDOWN: if not picked_up: - if event.key == pygame.K_SPACE: # Cycle through pieces (from largest to smallest) + if ( + event.key == pygame.K_SPACE + ): # Cycle through pieces (from largest to smallest) piece_cycle += 1 cycle_choices = np.unique( - [(p + 1) // 2 for p in unplaced]) # Transform [1,2,3,4,5,6] to [1,2,3) + [(p + 1) // 2 for p in unplaced] + ) # Transform [1,2,3,4,5,6] to [1,2,3) if len(cycle_choices) > 0: - piece_size = cycle_choices[(np.amax(cycle_choices) - (piece_cycle + 1)) % len(cycle_choices)] + piece_size = cycle_choices[ + (np.amax(cycle_choices) - (piece_cycle + 1)) + % len(cycle_choices) + ] # else: # piece_size = -1 piece_size_selected = piece_size - if (piece_size * 2) - 1 in unplaced: # Check if the first of this piece size is available + if ( + piece_size * 2 + ) - 1 in unplaced: # Check if the first of this piece size is available piece = piece_size * 2 - 1 else: - piece = piece_size * 2 # Otherwise choose the second of this piece size + piece = ( + piece_size * 2 + ) # Otherwise choose the second of this piece size else: - if event.key == pygame.K_1: # Select piece size 1 + if event.key == pygame.K_1: # Select piece size 1 piece_size_selected = 1 if 1 in unplaced: piece = 1 @@ -103,7 +116,7 @@ def __call__(self, observation, agent): piece_cycle = 2 else: piece = -1 - elif event.key == pygame.K_2: # Select piece size 2 + elif event.key == pygame.K_2: # Select piece size 2 piece_size_selected = 2 if 3 in unplaced: piece = 3 @@ -113,7 +126,7 @@ def __call__(self, observation, agent): piece_cycle = 1 else: piece = -1 - elif event.key == pygame.K_3: # Select piece size 3 + elif event.key == pygame.K_3: # Select piece size 3 piece_size_selected = 3 if 5 in unplaced: piece = 5 @@ -129,31 +142,36 @@ def __call__(self, observation, agent): piece_size = (piece + 1) // 2 # Don't render a preview if both pieces of a given size have been placed - ''' GET PREVIEW ACTION ''' + """ GET PREVIEW ACTION """ # Get the action from the preview (in position the mouse cursor is currently hovering over) - action_prev = env.unwrapped.board.get_action(pos, piece_size, env.agents.index(agent)) + action_prev = env.unwrapped.board.get_action( + pos, piece_size, env.agents.index(agent) + ) - ''' CLEAR ACTION PREVIEW FOR ILLEGAL MOVES''' + """ CLEAR ACTION PREVIEW FOR ILLEGAL MOVES""" # If the hovered over position means placing a picked up piece in the same spot, mark it as illegal if pos == picked_up_pos or piece == -1: action_prev = -1 - ''' CLEAR PREVIOUSLY PREVIEWED MOVES ''' + """ CLEAR PREVIOUSLY PREVIEWED MOVES """ env.unwrapped.board.squares_preview[:] = 0 if action_prev != -1: - if not env.unwrapped.board.is_legal(action_prev, env.agents.index(agent)): # If this action is illegal + if not env.unwrapped.board.is_legal( + action_prev, env.agents.index(agent) + ): # If this action is illegal action_prev = -1 else: env.unwrapped.board.squares_preview[ - pos + 9 * (piece_size - 1)] = agent_multiplier # Preview this position + pos + 9 * (piece_size - 1) + ] = agent_multiplier # Preview this position - ''' UPDATE DISPLAY with previewed move''' + """ UPDATE DISPLAY with previewed move""" env.render() pygame.display.update() if recorder is not None: recorder.capture_frame(env.unwrapped.screen) - ''' PICK UP / PLACE A PIECE ''' + """ PICK UP / PLACE A PIECE """ if event.type == pygame.MOUSEBUTTONDOWN: # Pick up a piece (only able to if it has already been placed, and is not currently picked up) if flatboard[pos] in placed_pieces_agent and not picked_up: @@ -164,28 +182,38 @@ def __call__(self, observation, agent): # If the piece size selected is larger, clicking here should self-gobble the smaller piece if piece_size_on_board >= piece_size_selected: # Can only pick up a piece if there is a legal move to place it, other than where it was before - if not all(observation["action_mask"][9 * (piece - 1): 9 * piece] == 0): + if not all( + observation["action_mask"][9 * (piece - 1) : 9 * piece] == 0 + ): picked_up = True picked_up_pos = pos # Remove a placed piece piece_size_selected = (piece + 1) // 2 - index = np.where(env.unwrapped.board.squares == piece_to_pick_up)[0][0] + index = np.where( + env.unwrapped.board.squares == piece_to_pick_up + )[0][0] env.unwrapped.board.squares[index] = 0 # Set the only possible actions to be moving this piece to a new square - observation["action_mask"][pos + 9 * (piece - 1)] = 0 # TODO: check if this is already zero - observation["action_mask"][:9 * (piece - 1)] = 0 # Zero out all the possible actions - observation["action_mask"][9 * (piece):] = 0 + observation["action_mask"][ + pos + 9 * (piece - 1) + ] = 0 # TODO: check if this is already zero + observation["action_mask"][ + : 9 * (piece - 1) + ] = 0 # Zero out all the possible actions + observation["action_mask"][9 * (piece) :] = 0 # Place a piece (if it is legal to do so) else: if action_prev != -1: - env.unwrapped.board.squares_preview[pos + 9 * (piece_size - 1)] = 0 + env.unwrapped.board.squares_preview[ + pos + 9 * (piece_size - 1) + ] = 0 action = pos + 9 * (piece - 1) break return np.int32(action) @property def available_agents(self): - return self.env.agent_name_mapping \ No newline at end of file + return self.env.agent_name_mapping diff --git a/gobblet/game/random_admissible_policy_rllib.py b/gobblet/game/random_admissible_policy_rllib.py index 99cf7f2..3f7c308 100644 --- a/gobblet/game/random_admissible_policy_rllib.py +++ b/gobblet/game/random_admissible_policy_rllib.py @@ -1,14 +1,11 @@ -from typing import ( - List, - Optional, - Union, -) +from typing import List, Optional, Union -from ray.rllib.utils.annotations import override -from ray.rllib.utils.typing import TensorStructType, TensorType -from ray.rllib.examples.policy.random_policy import RandomPolicy import numpy as np import tree # pip install dm_tree +from ray.rllib.examples.policy.random_policy import RandomPolicy +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import TensorStructType, TensorType + class RandomAdmissiblePolicy(RandomPolicy): def __init__(self, *args, **kwargs): @@ -16,20 +13,26 @@ def __init__(self, *args, **kwargs): @override(RandomPolicy) def compute_actions( - self, - obs_batch: Union[List[TensorStructType], TensorStructType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, - prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, - **kwargs + self, + obs_batch: Union[List[TensorStructType], TensorStructType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, + prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, + **kwargs, ): if "action_mask" in obs_batch.keys(): action_masks = obs_batch["action_mask"] - actions = [np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask)) for - action_mask in action_masks] + actions = [ + np.random.choice( + np.arange(len(action_mask)), p=action_mask / np.sum(action_mask) + ) + for action_mask in action_masks + ] else: obs_batch_size = len(tree.flatten(obs_batch)[0]) - actions = [self.action_space_for_sampling.sample() for _ in range(obs_batch_size)] + actions = [ + self.action_space_for_sampling.sample() for _ in range(obs_batch_size) + ] return ( actions, [], diff --git a/gobblet/game/utils.py b/gobblet/game/utils.py index 73a439e..c42f8f5 100644 --- a/gobblet/game/utils.py +++ b/gobblet/game/utils.py @@ -1,99 +1,145 @@ # adapted from pettingzoo.classic.connect_four +import glob import os +import re +import subprocess +import time +from pathlib import Path +from typing import Union + import pygame -import sys + def get_image(path): + """Get image. + Load an image from file into a pygame object + + Returns: + sfc: Pygame surface object with the image + """ cwd = os.path.dirname(__file__) image = pygame.image.load(os.path.join(cwd, path)) sfc = pygame.Surface(image.get_size(), flags=pygame.SRCALPHA) sfc.blit(image, (0, 0)) return sfc + def load_chip(tile_size, filename, scale): + """Load chip. + + Load the image of a single chip + + Returns: + Chip: Pygame surface object containing the chip + """ chip = get_image(os.path.join("img", filename)) chip = pygame.transform.scale( chip, (int(tile_size * (scale)), int(tile_size * (scale))) ) return chip + def load_chip_preview(tile_size, filename, scale): - chip = get_image(os.path.join(os.path.join("img","preview"), filename)) + """Load chip preview. + + Load the preview image of a single chip + + Returns: + Chip: Pygame surface object containing the chip preview + """ + chip = get_image(os.path.join(os.path.join("img", "preview"), filename)) chip = pygame.transform.scale( chip, (int(tile_size * (scale)), int(tile_size * (scale))) ) return chip + # from https://github.com/michaelfeil/skyjo_rl/blob/dev/rlskyjo/utils.py -from pathlib import Path -from typing import Union -import glob -import os -import re + def get_project_root() -> Path: - """return Path to the project directory, top folder of gobblet-rl + """Get project root. + + return Path to the project directory, top folder of gobblet-rl. + Returns: Path: Path to the project directory """ return Path(__file__).parent.parent.parent.resolve() -def find_file_in_subdir(parent_dir: Union[Path, str], file_str: Union[Path, str], regex_match: str = None) -> Union[str, None]: - files = glob.glob( - os.path.join(parent_dir, "**", file_str), recursive=True - ) + +def find_file_in_subdir( + parent_dir: Union[Path, str], file_str: Union[Path, str], regex_match: str = None +) -> Union[str, None]: + """Find file in subdirectory. + + Finds the path of a file in a subdirectory + + Args: + parent_dir: directory to look for the file in + + Returns: + Path: Path to the project directory + """ + files = glob.glob(os.path.join(parent_dir, "**", file_str), recursive=True) if regex_match is not None: p = re.compile(regex_match) - files = [ s for s in files if p.match(s) ] + files = [s for s in files if p.match(s)] return sorted(files)[-1] if len(files) else None -# adapted from https://commons.wikimedia.org/wiki/File:SquareWaveFourierArrows.gif#Source_code -import time -import subprocess +# adapted from https://commons.wikimedia.org/wiki/File:SquareWaveFourierArrows.gif#Source_code class GIFRecorder: + """GIF Recorder. + + This class is used to record a PyGame surface and save it to a .gif file. """ - This class is used to record a PyGame surface and save it to a gif file. - """ - def __init__(self, out_file=f'game.gif'): - """ - Initialize the recorder - :param out_file: Output file to save the recording + + def __init__(self, out_file="game.gif"): + """Initialize the recorder. + + Args: + out_file: Output file to save the recording """ - print(f'Initializing GIF Recorder...') - print(f'Output of the recording will be saved to {out_file}.') + print("Initializing GIF Recorder...") + print(f"Output of the recording will be saved to {out_file}.") self.filename_list = [] self.frame_num = 0 self.start_time = time.time() self.path = get_project_root() - # self.path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir), os.pardir) # Root project directoy print(self.path) self.out_file = out_file self.ended = False - - def capture_frame(self, surf): - """ + """Capture frame. + Call this method every frame, pass in the pygame surface to capture. - :param surf: pygame surface to capture - :return: None - """ - """ + Note: surface must have the dimensions specified in the constructor. + + Args: + surf: pygame surface to capture - Note: surface must have the dimensions specified in the constructor. + Returns: + None """ - if not self.ended: # Stop saving frames after we have exported the recording + if not self.ended: # Stop saving frames after we have exported the recording # transform the pixels to the format used by open-cv - self.filename_list.append(os.path.join(self.path, f'temp_{time.time()}_' + str(self.frame_num) + '.png')) + self.filename_list.append( + os.path.join( + self.path, f"temp_{time.time()}_" + str(self.frame_num) + ".png" + ) + ) pygame.image.save(surf, self.filename_list[-1]) self.frame_num += 1 - # Convert indivual image files to GIF + # Convert individual image files to GIF def end_recording(self, surf): - """ + """End recording. + Call this method to stop recording. + :return: None """ if not self.ended: @@ -103,9 +149,13 @@ def end_recording(self, surf): # stop recording duration = time.time() - self.start_time - seconds_per_frame = duration/ self.frame_num + seconds_per_frame = duration / self.frame_num frame_delay = str(int(seconds_per_frame * 100)) - command_list = ['convert', '-delay', frame_delay, '-loop', '0'] + self.filename_list + [self.out_file] + command_list = ( + ["convert", "-delay", frame_delay, "-loop", "0"] + + self.filename_list + + [self.out_file] + ) # Use the "convert" command (part of ImageMagick) to build the animation subprocess.call(command_list, cwd=self.path) # Earlier, we saved an image file for each frame of the animation. Now diff --git a/gobblet/gobblet_v1.py b/gobblet/gobblet_v1.py index dceacff..c5cbbb9 100644 --- a/gobblet/gobblet_v1.py +++ b/gobblet/gobblet_v1.py @@ -1,2 +1,2 @@ -from gobblet.game.gobblet import env, raw_env, parallel_env # noqa: 401 -from gobblet.game.manual_policy import ManualPolicy \ No newline at end of file +from gobblet.game.gobblet import env, parallel_env, raw_env # noqa: 401 +from gobblet.game.manual_policy import ManualGobbletPolicy # noqa 401 diff --git a/requirements.txt b/requirements.txt index 74ced2a..a5b40ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ pytest==7.1.2 ray==2.2.0 tianshou==0.4.11 torch==1.12.1 +pre-commit==3.1.1 \ No newline at end of file diff --git a/setup.py b/setup.py index 1977cec..374f6ae 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ #!/usr/bin/env python from setuptools import setup -if __name__=="__main__": - setup() \ No newline at end of file +if __name__ == "__main__": + setup() diff --git a/tests/test_gobblet_env.py b/tests/test_gobblet_env.py index bd951e1..8834dfb 100644 --- a/tests/test_gobblet_env.py +++ b/tests/test_gobblet_env.py @@ -1,9 +1,11 @@ +import numpy as np import pettingzoo import pettingzoo.test import pytest -import numpy as np + from gobblet import gobblet_v1 + # Note: raw_env is required in order to test the board state, as env() only allows observations @pytest.fixture(scope="function") def env(): @@ -21,9 +23,7 @@ def test_reset(env): def test_reset_starting(env): "Verify that reset() sets the board state to the correct starting position" - assert ( - (env.board.squares == np.zeros(27)).all() - ) + assert (env.board.squares == np.zeros(27)).all() def test_api(env): diff --git a/tests/test_manual_policy_collector.py b/tests/test_manual_policy_collector.py index 9ac019b..d502c74 100644 --- a/tests/test_manual_policy_collector.py +++ b/tests/test_manual_policy_collector.py @@ -11,135 +11,501 @@ git+https://github.com/thu-ml/tianshou """ -from typing import Optional, Tuple +import time +from typing import Tuple -import gym import numpy as np -import torch from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import BasePolicy, MultiAgentPolicyManager - from gobblet import gobblet_v1 from gobblet.game.collector_manual_policy import ManualPolicyCollector -from gobblet.game.greedy_policy import GreedyPolicy -import time +from gobblet.game.greedy_policy import GreedyGobbletPolicy def get_agents() -> Tuple[BasePolicy, list]: env = get_env() - agents = [GreedyPolicy(), GreedyPolicy()] + agents = [GreedyGobbletPolicy(), GreedyGobbletPolicy()] policy = MultiAgentPolicyManager(agents, env) return policy, env.agents + def get_env(render_mode=None, args=None): return PettingZooEnv(gobblet_v1.env(render_mode=render_mode, args=args)) + # ======== allows the user to input moves and play vs a pre-trained agent ====== def test_collector() -> None: env = DummyVectorEnv([lambda: get_env(render_mode="human", args=None)]) policy, agents = get_agents() - collector = ManualPolicyCollector(policy, env, exploration_noise=True) # Collector for CPU actions + collector = ManualPolicyCollector( + policy, env, exploration_noise=True + ) # Collector for CPU actions pettingzoo_env = env.workers[0].env.env - output0 = np.array([[True, True, True, True, True, True, True, True, True, True, True, True, - True, True, True, True, True, True, True, True, True, True, True, True, - True, True, True, True, True, True, True, True, True, True, True, True, - True, True, True, True, True, True, True, True, True, True, True, True, - True, True, True, True, True, True]]) - assert(np.array_equal(collector.data.obs.mask, output0)) + output0 = np.array( + [ + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + ] + ] + ) + assert np.array_equal(collector.data.obs.mask, output0) - ''' PLAYER 1''' + """ PLAYER 1""" action = np.array(18) result = collector.collect_result(action=action.reshape(1), render=0.1) - output1 = np.array([[False, True, True, True, True, True, True, True, True, False, True, True, - True, True, True, True, True, True, False, True, True, True, True, True, - True, True, True, False, True, True, True, True, True, True, True, True, - True, True, True, True, True, True, True, True, True, True, True, True, - True, True, True, True, True, True]]) - assert(np.array_equal(collector.data.obs.mask, output1)) + output1 = np.array( + [ + [ + False, + True, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + ] + ] + ) + assert np.array_equal(collector.data.obs.mask, output1) - time.sleep(.25) + time.sleep(0.25) - ''' PLAYER 2 (covers it)''' + """ PLAYER 2 (covers it)""" action = np.array(36) result = collector.collect_result(action=action.reshape(1), render=0.1) - output2 = np.array([[False, True, True, True, True, True, True, True, True, - False, True, True, True, True, True, True, True, True, - False, False, False, False, False, False, False, False, False, - False, True, True, True, True, True, True, True, True, - False, True, True, True, True, True, True, True, True, - False, True, True, True, True, True, True, True, True]]) + output2 = np.array( + [ + [ + False, + True, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + ] + ] + ) assert np.array_equal(collector.data.obs.mask, output2) - time.sleep(.25) + time.sleep(0.25) - ''' PLAYER 1''' - action = np.array(27+1) + """ PLAYER 1""" + action = np.array(27 + 1) result = collector.collect_result(action=action.reshape(1), render=0.1) - output3 = np.array([[False, False, True, True, True, True, True, True, True, - False, False, True, True, True, True, True, True, True, - False, False, True, True, True, True, True, True, True, - False, False, True, True, True, True, True, True, True, - False, True, True, True, True, True, True, True, True, - False, True, True, True, True, True, True, True, True]]) - assert(np.array_equal(collector.data.obs.mask, output3)) - - time.sleep(.25) - ''' PLAYER 2 (covers it)''' - action = np.array(45+1) + output3 = np.array( + [ + [ + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + ] + ] + ) + assert np.array_equal(collector.data.obs.mask, output3) + + time.sleep(0.25) + """ PLAYER 2 (covers it)""" + action = np.array(45 + 1) result = collector.collect_result(action=action.reshape(1), render=0.1) - output4 = np.array([[False, False, True, True, True, True, True, True, True, - False, False, True, True, True, True, True, True, True, - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, True, True, True, True, True, True, True, - False, False, True, True, True, True, True, True, True]]) - assert(np.array_equal(collector.data.obs.mask, output4)) + output4 = np.array( + [ + [ + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + True, + True, + True, + True, + True, + True, + True, + ] + ] + ) + assert np.array_equal(collector.data.obs.mask, output4) - time.sleep(.25) + time.sleep(0.25) - ''' PLAYER 1 (tries to move covered piece [ILLEGAL])''' - action = np.array(27+2) + """ PLAYER 1 (tries to move covered piece [ILLEGAL])""" + action = np.array(27 + 2) # Moves 18-35 should be illegal as they are with medium pieces. (36 and 37 as well but they are with a large piece) - output6 = [2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 38, 39, 40, 41, 42, 43, 44, 47, 48, 49, 50, 51, 52, 53] + output6 = [ + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + ] legal_moves = pettingzoo_env.unwrapped._legal_moves() - assert(output6 == legal_moves) + assert output6 == legal_moves - output5 = np.array([[False, False, True, True, True, True, True, True, True, - False, False, True, True, True, True, True, True, True, - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, True, True, True, True, True, True, True, - False, False, True, True, True, True, True, True, True]]) - assert(np.array_equal(collector.data.obs.mask, output5)) + output5 = np.array( + [ + [ + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + True, + True, + True, + True, + True, + True, + True, + ] + ] + ) + assert np.array_equal(collector.data.obs.mask, output5) result = collector.collect_result(action=action.reshape(1), render=0.1) # Result should be empty because this is an illegal move - output7 = {'n/ep': 0, 'n/st': 1, 'rews': np.array([], dtype=np.float64), 'lens': np.array([], dtype=np.int64), 'idxs': np.array([], dtype=np.int64), 'rew': 0, 'len': 0, 'rew_std': 0, 'len_std': 0} - assert(str(result) == str(output7)) + output7 = { + "n/ep": 0, + "n/st": 1, + "rews": np.array([], dtype=np.float64), + "lens": np.array([], dtype=np.int64), + "idxs": np.array([], dtype=np.int64), + "rew": 0, + "len": 0, + "rew_std": 0, + "len_std": 0, + } + assert str(result) == str(output7) # Board state should be unchanged because the bot tried to execute an illegal move" - output8 = np.array([[[ 0., 0., 0.], - [ 0., 0., 0.], - [ 0., 0., 0.]], - [[ 3., 4., 0.], - [ 0., 0., 0.], - [ 0., 0., 0.]], - [[-5., -6., 0.], - [ 0., 0., 0.], - [ 0., 0., 0.]]]) - assert(np.array_equal(pettingzoo_env.unwrapped.board.squares.reshape(3,3,3), output8)) + output8 = np.array( + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[-5.0, -6.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ] + ) + assert np.array_equal( + pettingzoo_env.unwrapped.board.squares.reshape(3, 3, 3), output8 + ) if __name__ == "__main__": # train the agent and watch its performance in a match! print("Starting game...") - test_collector() \ No newline at end of file + test_collector()