Skip to content

Commit

Permalink
Merge branch 'dev' into fix-neocarbs
Browse files Browse the repository at this point in the history
  • Loading branch information
jsuarez5341 authored Feb 17, 2025
2 parents d13630e + 451839b commit 204678c
Show file tree
Hide file tree
Showing 55 changed files with 318 additions and 235 deletions.
47 changes: 47 additions & 0 deletions .github/workflows/install.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
name: install

on:
push:
pull_request:

jobs:
test:
name: test ${{ matrix.py }} - ${{ matrix.os }} - ${{ matrix.env }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os:
- ubuntu-latest
- macos-latest
py:
- "3.11"
- "3.10"
- "3.9"
env:
- pip
- conda
steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Setup Conda
if: matrix.env == 'conda'
uses: conda-incubator/setup-miniconda@v2
with:
python-version: ${{ matrix.py }}
miniconda-version: "latest"
activate-environment: test-env
auto-update-conda: true

- name: Setup Python for pip
if: matrix.env == 'pip'
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.py }}

- name: Upgrade pip
run: python -m pip install -U pip

- name: Install pufferlib
run: pip install -e .
7 changes: 6 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,9 @@ global-include *.pxd
global-include *.h
global-include *.py
recursive-include pufferlib/resources *

recursive-exclude experiments *
recursive-exclude wandb *
recursive-exclude tests *
include raylib-5.0_linux_amd64/lib/libraylib.a
include raylib-5.0_macos/lib/libraylib.a
recursive-exclude raylib-5.0_webassembly *
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
![figure](https://pufferai.github.io/source/resource/header.png)

[![PyPI version](https://badge.fury.io/py/pufferlib.svg)](https://badge.fury.io/py/pufferlib)
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pufferlib)
![Github Actions](https://github.com/PufferAI/PufferLib/actions/workflows/install.yml/badge.svg)
[![](https://dcbadge.vercel.app/api/server/spT4huaGYV?style=plastic)](https://discord.gg/spT4huaGYV)
[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow%20%40jsuarez5341)](https://twitter.com/jsuarez5341)

Expand Down
1 change: 1 addition & 0 deletions config/default.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = None
env_name = None
vec = native
policy_name = Policy
rnn_name = None
max_suggestion_cost = 3600
Expand Down
1 change: 1 addition & 0 deletions config/ocean/connect4.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = puffer_connect4
vec = multiprocessing
policy_name = Policy
rnn_name = Recurrent

Expand Down
1 change: 1 addition & 0 deletions config/ocean/go.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = puffer_go
vec = multiprocessing
policy_name = Go
rnn_name = Recurrent

Expand Down
3 changes: 2 additions & 1 deletion config/ocean/grid.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
[base]
package = ocean
env_name = puffer_grid
policy_name = Grid
vec = multiprocessing
policy_name = Policy
rnn_name = Recurrent

[env]
Expand Down
1 change: 1 addition & 0 deletions config/ocean/moba.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = puffer_moba
vec = multiprocessing
policy_name = MOBA
rnn_name = Recurrent

Expand Down
1 change: 1 addition & 0 deletions config/ocean/nmmo3.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = puffer_nmmo3
vec = multiprocessing
policy_name = NMMO3
rnn_name = NMMO3LSTM

Expand Down
1 change: 1 addition & 0 deletions config/ocean/snake.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = puffer_snake
vec = multiprocessing
rnn_name = Recurrent

[env]
Expand Down
1 change: 1 addition & 0 deletions config/ocean/trash_pickup.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = trash_pickup puffer_trash_pickup
vec = multiprocessing
policy_name = TrashPickup
rnn_name = Recurrent

Expand Down
16 changes: 9 additions & 7 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def sweep(args, env_name, make_env, policy_cls, rnn_cls):

def train(args, make_env, policy_cls, rnn_cls, target_metric, min_eval_points=100,
elos={'model_random.pt': 1000}, vecenv=None, wandb=None, neptune=None):

if args['vec'] == 'serial':
vec = pufferlib.vector.Serial
elif args['vec'] == 'multiprocessing':
Expand All @@ -106,7 +105,7 @@ def train(args, make_env, policy_cls, rnn_cls, target_metric, min_eval_points=10
elif args['vec'] == 'native':
vec = pufferlib.environment.PufferEnv
else:
raise ValueError(f'Invalid --vector (serial/multiprocessing/ray/native).')
raise ValueError(f'Invalid --vec (serial/multiprocessing/ray/native).')

if vecenv is None:
vecenv = pufferlib.vector.make(
Expand Down Expand Up @@ -229,7 +228,10 @@ def train(args, make_env, policy_cls, rnn_cls, target_metric, min_eval_points=10

for section in p.sections():
for key in p[section]:
argparse_key = f'--{section}.{key}'.replace('_', '-')
if section == 'base':
argparse_key = f'--{key}'.replace('_', '-')
else:
argparse_key = f'--{section}.{key}'.replace('_', '-')
parser.add_argument(argparse_key, default=p[section][key])

# Late add help so you get a dynamic menu based on the env
Expand All @@ -251,7 +253,7 @@ def train(args, make_env, policy_cls, rnn_cls, target_metric, min_eval_points=10
except:
prev[subkey] = value

package = args['base']['package']
package = args['package']
module_name = f'pufferlib.environments.{package}'
if package == 'ocean':
module_name = 'pufferlib.ocean'
Expand All @@ -260,12 +262,12 @@ def train(args, make_env, policy_cls, rnn_cls, target_metric, min_eval_points=10
env_module = importlib.import_module(module_name)

make_env = env_module.env_creator(env_name)
policy_cls = getattr(env_module.torch, args['base']['policy_name'])
policy_cls = getattr(env_module.torch, args['policy_name'])

rnn_name = args['base']['rnn_name']
rnn_name = args['rnn_name']
rnn_cls = None
if rnn_name is not None:
rnn_cls = getattr(env_module.torch, args['base']['rnn_name'])
rnn_cls = getattr(env_module.torch, args['rnn_name'])

if args['baseline']:
assert args['mode'] in ('train', 'eval', 'evaluate')
Expand Down
2 changes: 1 addition & 1 deletion pufferlib/emulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def __init__(self, env=None, env_creator=None, env_args=[], buf=None, env_kwargs
self.num_agents = len(self.possible_agents)

set_buffers(self, buf)
if isinstance(self.env.observation_space, pufferlib.spaces.Box):
if isinstance(self.env_single_observation_space, pufferlib.spaces.Box):
self.obs_struct = self.observations
else:
self.obs_struct = self.observations.view(self.obs_dtype)
Expand Down
10 changes: 5 additions & 5 deletions pufferlib/ocean/breakout/breakout.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void demo() {
.brick_cols = 18,
};
allocate(&env);
reset(&env);
c_reset(&env);

Client* client = make_client(&env);

Expand All @@ -35,8 +35,8 @@ void demo() {
forward_linearlstm(net, env.observations, env.actions);
}

step(&env);
render(client, &env);
c_step(&env);
c_render(client, &env);
}
free_linearlstm(net);
free(weights);
Expand All @@ -60,13 +60,13 @@ void performance_test() {
.brick_cols = 18,
};
allocate(&env);
reset(&env);
c_reset(&env);

long start = time(NULL);
int i = 0;
while (time(NULL) - start < test_time) {
env.actions[0] = rand() % 4;
step(&env);
c_step(&env);
i++;
}
long end = time(NULL);
Expand Down
8 changes: 4 additions & 4 deletions pufferlib/ocean/breakout/breakout.h
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ void reset_round(Breakout* env) {
env->ball_vx = 0.0;
env->ball_vy = 0.0;
}
void reset(Breakout* env) {
void c_reset(Breakout* env) {
env->log = (Log){0};
env->score = 0;
env->num_balls = 5;
Expand Down Expand Up @@ -480,11 +480,11 @@ void step_frame(Breakout* env, int action) {
env->dones[0] = 1;
env->log.score = env->score;
add_log(env->log_buffer, &env->log);
reset(env);
c_reset(env);
}
}

void step(Breakout* env) {
void c_step(Breakout* env) {
env->dones[0] = 0;
env->log.episode_length += 1;
env->rewards[0] = 0.0;
Expand Down Expand Up @@ -521,7 +521,7 @@ Client* make_client(Breakout* env) {
return client;
}

void render(Client* client, Breakout* env) {
void c_render(Client* client, Breakout* env) {
if (IsKeyDown(KEY_ESCAPE)) {
exit(0);
}
Expand Down
12 changes: 6 additions & 6 deletions pufferlib/ocean/breakout/cy_breakout.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ cdef extern from "breakout.h":

Client* make_client(Breakout* env)
void close_client(Client* client)
void render(Client* client, Breakout* env)
void reset(Breakout* env)
void step(Breakout* env)
void c_render(Client* client, Breakout* env)
void c_reset(Breakout* env)
void c_step(Breakout* env)

cdef class CyBreakout:
cdef:
Expand Down Expand Up @@ -103,12 +103,12 @@ cdef class CyBreakout:
def reset(self):
cdef int i
for i in range(self.num_envs):
reset(&self.envs[i])
c_reset(&self.envs[i])

def step(self):
cdef int i
for i in range(self.num_envs):
step(&self.envs[i])
c_step(&self.envs[i])

def render(self):
cdef Breakout* env = &self.envs[0]
Expand All @@ -119,7 +119,7 @@ cdef class CyBreakout:
self.client = make_client(env)
os.chdir(cwd)

render(self.client, env)
c_render(self.client, env)

def close(self):
if self.client != NULL:
Expand Down
10 changes: 5 additions & 5 deletions pufferlib/ocean/connect4/connect4.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ void interactive() {
.piece_height = 96,
};
allocate_cconnect4(&env);
reset(&env);
c_reset(&env);

Client* client = make_client(env.width, env.height);
float observations[42] = {0};
Expand Down Expand Up @@ -43,10 +43,10 @@ void interactive() {

tick = (tick + 1) % 60;
if (env.actions[0] >= 0 && env.actions[0] <= 6) {
step(&env);
c_step(&env);
}

render(client, &env);
c_render(client, &env);
}
free_linearlstm(net);
free(weights);
Expand All @@ -63,13 +63,13 @@ void performance_test() {
.piece_height = 96,
};
allocate_cconnect4(&env);
reset(&env);
c_reset(&env);

long start = time(NULL);
int i = 0;
while (time(NULL) - start < test_time) {
env.actions[0] = rand() % 7;
step(&env);
c_step(&env);
i++;
}
long end = time(NULL);
Expand Down
8 changes: 4 additions & 4 deletions pufferlib/ocean/connect4/connect4.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ void compute_observation(CConnect4* env) {
}
}

void reset(CConnect4* env) {
void c_reset(CConnect4* env) {
env->log = (Log){0};
env->dones[0] = NOT_DONE;
env->player_pieces = 0;
Expand All @@ -294,13 +294,13 @@ void finish_game(CConnect4* env, float reward) {
compute_observation(env);
}

void step(CConnect4* env) {
void c_step(CConnect4* env) {
env->log.episode_length += 1;
env->rewards[0] = 0.0;

if (env->dones[0] == DONE) {
add_log(env->log_buffer, &env->log);
reset(env);
c_reset(env);
return;
}

Expand Down Expand Up @@ -359,7 +359,7 @@ Client* make_client(int width, int height) {
return client;
}

void render(Client* client, CConnect4* env) {
void c_render(Client* client, CConnect4* env) {
if (IsKeyDown(KEY_ESCAPE)) {
exit(0);
}
Expand Down
Loading

0 comments on commit 204678c

Please sign in to comment.