
Make policies compatible with other/multiple image keys #149

Merged
12 commits merged from policy_compatibility into huggingface:main on May 16, 2024

Conversation

alexander-soare
Contributor

@alexander-soare alexander-soare commented May 8, 2024

What this does

Makes all policies compatible with any key starting with "observation.image" (e.g. Diffusion Policy does not currently work with the ALOHA datasets because they use "observation.images.top" as their image key).

Here's the design approach:

  1. Use the input_shapes policy configuration parameters as the "source of truth" for the expected inputs.
  2. Implicit logic: any key starting with "observation.image" is an image key.
  3. Policies use 1 and 2 to create an image_keys attribute.
  4. image_keys is used to unpack the batch in forward and select_action (see the sketch below).

As a side effect I'm also able to enable multiple image handling in ACT.
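
A minimal sketch of this pattern, assuming a PyTorch policy whose config exposes an input_shapes dict (the class and attribute names here are illustrative, not the exact lerobot API):

import torch
from torch import Tensor


class ExamplePolicy(torch.nn.Module):
    """Illustrative sketch only: derives image keys from the config's input_shapes."""

    def __init__(self, config):
        super().__init__()
        # Steps 1 & 2: `input_shapes` is the source of truth, and any key starting
        # with "observation.image" is treated as an image key.
        self.image_keys = [
            k for k in config.input_shapes if k.startswith("observation.image")
        ]

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        # Steps 3 & 4: use `image_keys` to unpack the batch. Stacking along a new
        # dimension yields (batch, num_cameras, channels, height, width).
        images = torch.stack([batch[k] for k in self.image_keys], dim=-4)
        # ... feed `images` to the vision backbone and decode an action ...
        return images.new_zeros(images.shape[0], 1)  # placeholder action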

TODO: Update available_policies_per_env

How it was tested

CI tests were added for ACTPolicy/PushT and DiffusionPolicy/ALOHA. I did not add a test for TD-MPC, as that also needs a "next.reward" key.

For ACT I also tried stacking two of the same image in ACTPolicy._check_and_preprocess_batch to make sure it can handle multiple images.
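
Roughly what that manual check could look like (the exact code inside ACTPolicy._check_and_preprocess_batch is not shown here, so this snippet is an assumption):

import torch

# Hypothetical sanity check: repeat a single camera view along a "num_cameras"
# dimension so the downstream code has to handle more than one image.
batch = {"observation.images.top": torch.zeros(8, 3, 480, 640)}  # dummy batch
images = batch["observation.images.top"]        # (batch, channels, height, width)
images = torch.stack([images, images], dim=-4)  # (batch, 2, channels, height, width)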

@alexander-soare alexander-soare requested a review from Cadene May 8, 2024 11:48
@alexander-soare alexander-soare marked this pull request as draft May 8, 2024 11:56
@alexander-soare alexander-soare changed the title Make Diffusion Policy compatible with other image keys Make policies compatible with other/multiple image keys May 9, 2024
@@ -130,10 +130,3 @@ def __post_init__(self):
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
# Check that there is only one image.
Contributor Author

FYI: This check is deleted as ACT can handle multiple images.

(fragment of an ASCII architecture diagram from the diff context)
Contributor Author

FYI: all changes here and below in this file were just me making the code more clear.

@@ -72,6 +80,31 @@ def test_policy(env_name, policy_name, extra_overrides):
+ extra_overrides,
)

# Additional config override logic.
Contributor Author
@alexander-soare alexander-soare May 9, 2024

FYI: This (and in general the idea of testing policies × datasets) scales as O(n²). Not nice. I think we need an O(n) solution, like validating that datasets follow a certain data key format and that policies can handle this format.
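
A rough sketch of what such an O(n) validation could look like (the function name and the allowed prefixes are assumptions, not an existing lerobot API):

ALLOWED_PREFIXES = ("observation.image", "observation.state", "action", "next.reward")


def validate_data_keys(keys: list[str]) -> None:
    """Check that every dataset key follows the shared naming convention."""
    unexpected = [k for k in keys if not k.startswith(ALLOWED_PREFIXES)]
    if unexpected:
        raise ValueError(f"Keys do not follow the expected format: {unexpected}")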

@alexander-soare alexander-soare marked this pull request as ready for review May 9, 2024 11:02
Collaborator
@Cadene Cadene left a comment

Nice!

Comment on lines 86 to 89
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
assert len(image_keys) == 1
self.input_image_key = image_keys[0]
Collaborator

Suggested change
  image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
  # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+ # TODO(alexander-soare): make diffusion compatible with multiple image keys
  assert len(image_keys) == 1
  self.input_image_key = image_keys[0]

Contributor Author

Added in main docstring.

Comment on lines +115 to +118
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
assert len(image_keys) == 1
self.input_image_key = image_keys[0]
Collaborator

Suggested change
  image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
  # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+ # TODO(alexander-soare): make diffusion compatible with multiple image keys
  assert len(image_keys) == 1
  self.input_image_key = image_keys[0]

Collaborator

or raise NotImplementedError

Contributor Author

Done.

@alexander-soare alexander-soare merged commit 68c1b13 into huggingface:main May 16, 2024
5 checks passed
@alexander-soare alexander-soare deleted the policy_compatibility branch May 16, 2024 12:51
@Joeland4

Nice!

Thanks a lot, @Cadene @alexander-soare. There is still a problem when calling make_dataset for Aloha_xxx datasets while using the diffusion policy; it seems the key 'observation.image' in diffusion.yaml is not consistent with the keys in the datasets.

script:
python lerobot/scripts/train.py policy=diffusion env=aloha env.task=AlohaInsertion-v0 dataset_repo_id=lerobot/aloha_sim_insertion_human

error:
s/lerobot/common/datasets/factory.py", line 53, in make_dataset
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
KeyError: 'observation.image'

I'd like to offer improvements, but I'm new to the code. I'll try my best and hope you experts can provide a nice improvement.

@alexander-soare
Contributor Author

@Joeland4 my next task is to create an example/tutorial on how to adapt the config. Here's one that should work for you in the meantime (as in it won't raise an exception):

# @package _global_

# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
#       https://github.com/huggingface/lerobot/pull/134 for more details.

seed: 100000
dataset_repo_id: lerobot/aloha_sim_transfer_cube_human

training:
  offline_steps: 200000
  online_steps: 0
  eval_freq: 5000
  save_freq: 5000
  log_freq: 250
  save_model: true

  batch_size: 64
  grad_clip_norm: 10
  lr: 1.0e-4
  lr_scheduler: cosine
  lr_warmup_steps: 500
  adam_betas: [0.95, 0.999]
  adam_eps: 1.0e-8
  adam_weight_decay: 1.0e-6
  online_steps_between_rollouts: 1
  
  delta_timestamps:
    observation.images.top: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
    observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"

eval:
  n_episodes: 50
  batch_size: 50


policy:
  name: diffusion

  # Input / output structure.
  n_obs_steps: 2
  horizon: 16
  n_action_steps: 8

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.images.top: [3, 480, 640]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.top: mean_std
    observation.state: min_max
  output_normalization_modes:
    action: min_max

  # Architecture / modeling.
  # Vision backbone.
  vision_backbone: resnet18
  crop_shape: [420, 420]
  crop_is_random: True
  pretrained_backbone_weights: null
  use_group_norm: True
  spatial_softmax_num_keypoints: 32
  # Unet.
  down_dims: [512, 1024, 2048]
  kernel_size: 5
  n_groups: 8
  diffusion_step_embed_dim: 128
  use_film_scale_modulation: True
  # Noise scheduler.
  num_train_timesteps: 100
  beta_schedule: squaredcos_cap_v2
  beta_start: 0.0001
  beta_end: 0.02
  prediction_type: epsilon # epsilon / sample
  clip_sample: True
  clip_sample_range: 1.0

  # Inference
  num_inference_steps: 100

  # Loss computation
  do_mask_loss_for_padding: false

And here's the script I use to run it.

DATASET=aloha_sim_transfer_cube_human
NAME=diffusion_$DATASET

python lerobot/scripts/train.py \
    hydra.run.dir=outputs/train/$NAME \
    hydra.job.name=$NAME \
    env=aloha \
    env.task=AlohaTransferCube-v0 \
    dataset_repo_id=lerobot/$DATASET \
    policy=diffusion_aloha \
    training.save_model=true \
    training.offline_steps=200000 \
    training.save_freq=20000 \
    training.eval_freq=10000 \
    eval.n_episodes=50 \
    wandb.enable=false \
    wandb.disable_artifact=true \
    device=cuda

@nnop

nnop commented Nov 12, 2024

How did you fuse multiple images? Is there any technical reference for this?
