Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into policy_compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-soare committed May 10, 2024
2 parents 2ea8ad4 + b187942 commit 4ae8d61
Show file tree
Hide file tree
Showing 15 changed files with 90 additions and 26 deletions.
7 changes: 3 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ test-end-to-end:
${MAKE} test-act-ete-eval
${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval
# TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc
# ${MAKE} test-tdmpc-ete-train
# ${MAKE} test-tdmpc-ete-eval
${MAKE} test-tdmpc-ete-train
${MAKE} test-tdmpc-ete-eval
${MAKE} test-default-ete-eval

test-act-ete-train:
Expand Down Expand Up @@ -80,7 +79,7 @@ test-tdmpc-ete-train:
policy=tdmpc \
env=xarm \
env.task=XarmLift-v0 \
dataset_repo_id=lerobot/xarm_lift_medium_replay \
dataset_repo_id=lerobot/xarm_lift_medium \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=2 \
Expand Down
27 changes: 27 additions & 0 deletions lerobot/common/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import os.path as osp
import random
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Generator

import hydra
import numpy as np
Expand Down Expand Up @@ -39,6 +41,31 @@ def set_global_seed(seed):
torch.cuda.manual_seed_all(seed)


@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
    """Set the seed when entering a context, and restore the prior random state at exit.

    The states of Python's `random`, NumPy's global RNG, torch's CPU RNG and (only when CUDA is
    available) the RNGs of all CUDA devices are snapshotted on entry and restored on exit. The
    restore also happens when the body of the `with` block raises.

    Example usage:

    ```
    a = random.random()  # produces some random number
    with seeded_context(1337):
        b = random.random()  # produces some other random number
    c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
    ```

    Args:
        seed: Seed passed to `set_global_seed` for the duration of the context.

    Yields:
        None. The context only has side effects on the global RNG states.
    """
    random_state = random.getstate()
    np_random_state = np.random.get_state()
    torch_random_state = torch.random.get_rng_state()
    # Querying the CUDA RNG initialises CUDA and raises on machines without a usable GPU, so only
    # touch it when CUDA is available. All devices are captured because `set_global_seed` reseeds
    # every device, not just the current one.
    torch_cuda_random_states = torch.cuda.random.get_rng_state_all() if torch.cuda.is_available() else None
    try:
        set_global_seed(seed)
        yield None
    finally:
        # Restore unconditionally so an exception inside the context cannot leave the process
        # running on the temporary seed.
        random.setstate(random_state)
        np.random.set_state(np_random_state)
        torch.random.set_rng_state(torch_random_state)
        if torch_cuda_random_states is not None:
            torch.cuda.random.set_rng_state_all(torch_cuda_random_states)


def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
Expand Down
12 changes: 6 additions & 6 deletions lerobot/configs/policy/act.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
seed: 1000
dataset_repo_id: lerobot/aloha_sim_insertion_human

override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)

training:
offline_steps: 80000
online_steps: 0
Expand All @@ -18,12 +24,6 @@ training:
grad_clip_norm: 10
online_steps_between_rollouts: 1

override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)

delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"

Expand Down
28 changes: 14 additions & 14 deletions lerobot/configs/policy/diffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,20 @@
seed: 100000
dataset_repo_id: lerobot/pusht

override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
observation.image:
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
# from the original codebase, but we should remove these and train our own pretrained model
observation.state:
min: [13.456424, 32.938293]
max: [496.14618, 510.9579]
action:
min: [12.0, 25.0]
max: [511.0, 511.0]

training:
offline_steps: 200000
online_steps: 0
Expand Down Expand Up @@ -34,20 +48,6 @@ eval:
n_episodes: 50
batch_size: 50

override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
observation.image:
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
# from the original codebase, but we should remove these and train our own pretrained model
observation.state:
min: [13.456424, 32.938293]
max: [496.14618, 510.9579]
action:
min: [12.0, 25.0]
max: [511.0, 511.0]

policy:
name: diffusion

Expand Down
2 changes: 1 addition & 1 deletion lerobot/configs/policy/tdmpc.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _global_

seed: 1
dataset_repo_id: lerobot/xarm_lift_medium_replay
dataset_repo_id: lerobot/xarm_lift_medium

training:
offline_steps: 25000
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize(
"env_name, policy_name, extra_overrides",
[
# ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
("xarm", "tdmpc", []),
(
"pusht",
"diffusion",
Expand Down
38 changes: 38 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import random
from typing import Callable

import numpy as np
import pytest
import torch

from lerobot.common.utils.utils import seeded_context, set_global_seed


@pytest.mark.parametrize(
    "rand_fn",
    [
        random.random,
        np.random.random,
        lambda: torch.rand(1).item(),
    ]
    # The parentheses matter: a conditional expression binds more loosely than `+`, so without them
    # the whole parameter list becomes `[]` (and the test is silently skipped) when CUDA is absent.
    + ([lambda: torch.rand(1, device="cuda").item()] if torch.cuda.is_available() else []),
)
def test_seeding(rand_fn: Callable[[], float]):
    """Check that `set_global_seed` and `seeded_context` give reproducible, non-interfering RNG streams.

    Args:
        rand_fn: Zero-argument callable drawing one number from the RNG under test.
    """
    set_global_seed(0)
    a = rand_fn()
    with seeded_context(1337):
        c = rand_fn()
    b = rand_fn()
    set_global_seed(0)
    a_ = rand_fn()
    b_ = rand_fn()
    # Check that `set_global_seed` lets us reproduce a and b.
    assert a_ == a
    # Additionally, check that the `seeded_context` didn't interrupt the global RNG.
    assert b_ == b
    set_global_seed(1337)
    c_ = rand_fn()
    # Check that `seeded_context` and `global_seed` give the same reproducibility.
    assert c_ == c

0 comments on commit 4ae8d61

Please sign in to comment.