Make policies compatible with other/multiple image keys #149
Conversation
@@ -130,10 +130,3 @@ def __post_init__(self):
             raise ValueError(
                 f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
             )
-        # Check that there is only one image.
FYI: This check is deleted as ACT can handle multiple images.
[diff fragment of the ASCII architecture diagram in this file; layout adjustments only]
FYI: all changes here and below in this file were just me making the code more clear.
@@ -72,6 +80,31 @@ def test_policy(env_name, policy_name, extra_overrides):
+        extra_overrides,
     )

+    # Additional config override logic.
FYI: This (and in general the idea of testing policies x datasets) scales as n². Not nice. I think we need to think of an O(N) solution like validating that datasets have a certain data key format, and policies can handle this format.
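A rough sketch of that O(N) idea, purely illustrative (EXPECTED_PREFIXES, validate_dataset_keys, and validate_policy_input_shapes are hypothetical names, not part of this PR):

# Hypothetical O(N) check: validate each dataset and each policy against a shared
# key schema instead of exercising every (policy, dataset) pair.
EXPECTED_PREFIXES = ("observation.image", "observation.state", "action")

def validate_dataset_keys(dataset_keys: list[str]) -> None:
    # Every key a dataset exposes must match one of the agreed-upon prefixes.
    for key in dataset_keys:
        assert key.startswith(EXPECTED_PREFIXES), f"Unexpected dataset key: {key}"

def validate_policy_input_shapes(input_shapes: dict[str, list[int]]) -> None:
    # Every input a policy declares must use the same key format.
    for key in input_shapes:
        assert key.startswith(EXPECTED_PREFIXES), f"Unexpected policy input key: {key}"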
Force-pushed from 0e4553b to 4ae8d61.
Nice!
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
assert len(image_keys) == 1
self.input_image_key = image_keys[0]
Suggested change:

image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
# TODO(alexander-soare): make diffusion compatible with multiple image keys
assert len(image_keys) == 1
self.input_image_key = image_keys[0]
Added in main docstring.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
assert len(image_keys) == 1
self.input_image_key = image_keys[0]
Suggested change:

image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
# TODO(alexander-soare): make diffusion compatible with multiple image keys
assert len(image_keys) == 1
self.input_image_key = image_keys[0]
or raise NotImplementedError
Done.
Thanks a lot, @Cadene @alexander-soare. There is still a problem with make_dataset for Aloha_xxx while using the diffusion policy: the key 'observation.image' in diffusion.yaml is not consistent with the keys in the dataset.
script:
error:
I'd like to offer improvements, but I'm new to the code. I'll try my best, and hope you experts can provide a nice improvement.
@Joeland4 my next task is to create an example/tutorial on how to adapt the config. Here's one that should work for you in the meantime (as in it won't raise an exception):

# @package _global_

# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
# https://github.com/huggingface/lerobot/pull/134 for more details.

seed: 100000
dataset_repo_id: lerobot/aloha_sim_transfer_cube_human

training:
  offline_steps: 200000
  online_steps: 0
  eval_freq: 5000
  save_freq: 5000
  log_freq: 250
  save_model: true

  batch_size: 64
  grad_clip_norm: 10
  lr: 1.0e-4
  lr_scheduler: cosine
  lr_warmup_steps: 500
  adam_betas: [0.95, 0.999]
  adam_eps: 1.0e-8
  adam_weight_decay: 1.0e-6
  online_steps_between_rollouts: 1

  delta_timestamps:
    observation.images.top: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
    observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"

eval:
  n_episodes: 50
  batch_size: 50

policy:
  name: diffusion

  # Input / output structure.
  n_obs_steps: 2
  horizon: 16
  n_action_steps: 8

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.images.top: [3, 480, 640]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.top: mean_std
    observation.state: min_max
  output_normalization_modes:
    action: min_max

  # Architecture / modeling.
  # Vision backbone.
  vision_backbone: resnet18
  crop_shape: [420, 420]
  crop_is_random: True
  pretrained_backbone_weights: null
  use_group_norm: True
  spatial_softmax_num_keypoints: 32
  # Unet.
  down_dims: [512, 1024, 2048]
  kernel_size: 5
  n_groups: 8
  diffusion_step_embed_dim: 128
  use_film_scale_modulation: True
  # Noise scheduler.
  num_train_timesteps: 100
  beta_schedule: squaredcos_cap_v2
  beta_start: 0.0001
  beta_end: 0.02
  prediction_type: epsilon # epsilon / sample
  clip_sample: True
  clip_sample_range: 1.0

  # Inference
  num_inference_steps: 100

  # Loss computation
  do_mask_loss_for_padding: false

And here's the script I use to run it.

DATASET=aloha_sim_transfer_cube_human
NAME=diffusion_$DATASET
python lerobot/scripts/train.py \
hydra.run.dir=outputs/train/$NAME \
hydra.job.name=$NAME \
env=aloha \
env.task=AlohaTransferCube-v0 \
dataset_repo_id=lerobot/$DATASET \
policy=diffusion_aloha \
training.save_model=true \
training.offline_steps=200000 \
training.save_freq=20000 \
training.eval_freq=10000 \
eval.n_episodes=50 \
wandb.enable=false \
wandb.disable_artifact=true \
  device=cuda
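For reference, here is roughly what the delta_timestamps expressions in the config above resolve to, assuming fps=50 for the ALOHA env together with the n_obs_steps=2 and horizon=16 values above (illustration only, not code from the repo):

fps, n_obs_steps, horizon = 50, 2, 16  # assumed ALOHA fps plus the config values above

obs_deltas = [i / fps for i in range(1 - n_obs_steps, 1)]
action_deltas = [i / fps for i in range(1 - n_obs_steps, 1 - n_obs_steps + horizon)]

print(obs_deltas)     # [-0.02, 0.0]: the current frame plus the frame 20 ms earlier
print(action_deltas)  # [-0.02, 0.0, 0.02, ..., 0.28]: 16 action timestamps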
How did you fuse multiple images? Is there any technical reference for this?
What this does
Makes all policies compatible with any key starting with "observation.image" (e.g. Diffusion Policy does not work with the ALOHA datasets because they use "observation.images.top" as their image key).

Here's the design approach:
- Use the input_shapes policy configuration parameters as the "source of truth" for the expected inputs.
- Each policy parses the image keys from input_shapes into an image_keys attribute.
- image_keys is used to unpack the batch in forward and select_action.

As a side effect I'm also able to enable multiple image handling in ACT.
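A minimal sketch of the unpacking idea, assuming hypothetical batch contents and a hypothetical stack_images helper (this is not the PR's actual code):

import torch

def stack_images(batch: dict[str, torch.Tensor], image_keys: list[str]) -> torch.Tensor:
    # Gather every camera view named in image_keys into one (batch, num_cameras, C, H, W) tensor.
    return torch.stack([batch[key] for key in image_keys], dim=1)

batch = {
    "observation.images.top": torch.rand(8, 3, 480, 640),
    "observation.images.front": torch.rand(8, 3, 480, 640),  # hypothetical second camera
}
images = stack_images(batch, ["observation.images.top", "observation.images.front"])
assert images.shape == (8, 2, 3, 480, 640)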
TODO: Update available_policies_per_env
How it was tested
CI tests were added for ACTPolicy/PushT and DiffusionPolicy/ALOHA. I did not add a test for TD-MPC as that also needs a "next.reward" key.
For ACT I also tried stacking two of the same image in ACTPolicy._check_and_preprocess_batch to make sure it can handle multiple images.
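A hedged sketch of that manual check (key names and shapes are assumptions, not the PR's code): duplicate one camera stream under a second key so the policy sees two images.

import torch

batch = {
    "observation.state": torch.rand(2, 14),
    "observation.images.top": torch.rand(2, 3, 480, 640),
}
# Pretend a second camera exists by copying the first one.
batch["observation.images.top_copy"] = batch["observation.images.top"].clone()
# With image_keys = ["observation.images.top", "observation.images.top_copy"],
# the policy's forward pass should then process two image streams.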