From 8357dae26a58f019b39e58a64b83d4233c6b383d Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 12:42:40 +0100 Subject: [PATCH 1/9] first draft --- .../diffusion/configuration_diffusion.py | 17 ++++++--- .../policies/diffusion/modeling_diffusion.py | 36 ++++++++++++++----- lerobot/scripts/train.py | 3 +- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index d7341c33b..272986abd 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -130,14 +130,21 @@ def __post_init__(self): raise ValueError( f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." ) + # There should only be one image key. + image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} + if len(image_keys) != 1: + raise ValueError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) + image_key = next(iter(image_keys)) if ( - self.crop_shape[0] > self.input_shapes["observation.image"][1] - or self.crop_shape[1] > self.input_shapes["observation.image"][2] + self.crop_shape[0] > self.input_shapes[image_key][1] + or self.crop_shape[1] > self.input_shapes[image_key][2] ): raise ValueError( - f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} ' - f'for `crop_shape` and {self.input_shapes["observation.image"]} for ' - '`input_shapes["observation.image"]`.' + f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} " + f"for `crop_shape` and {self.input_shapes[image_key]} for " + "`input_shapes[{image_key}]`." ) supported_prediction_types = ["epsilon", "sample"] if self.prediction_type not in supported_prediction_types: diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 91cf6dd06..e5d099da1 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -67,15 +67,31 @@ def __init__( self.diffusion = DiffusionModel(config) def reset(self): - """ - Clear observation and action queues. Should be called on `env.reset()` - """ + """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { "observation.image": deque(maxlen=self.config.n_obs_steps), "observation.state": deque(maxlen=self.config.n_obs_steps), "action": deque(maxlen=self.config.n_action_steps), } + def _preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False): + """Check that the keys can be handled by this policy and standardize the image key. + + This should be run after input normalization. + """ + assert "observation.state" in batch + # There should only be one image key. + image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")} + assert ( + len(image_keys) == 1 + ), f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + if train_mode: + assert "action" in batch + assert "action_is_pad" in batch + image_key = next(iter(image_keys)) + batch["observation.image"] = batch[image_key] + del batch[image_key] + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -98,10 +114,8 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: "horizon" may not the best name to describe what the variable actually means, because this period is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ - assert "observation.image" in batch - assert "observation.state" in batch - batch = self.normalize_inputs(batch) + self._preprocess_batch_keys(batch) self._queues = populate_queues(self._queues, batch) @@ -121,6 +135,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) + self._preprocess_batch_keys(batch) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} @@ -185,13 +200,12 @@ def conditional_sample( def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: """ - This function expects `batch` to have (at least): + This function expects `batch` to have: { "observation.state": (B, n_obs_steps, state_dim) "observation.image": (B, n_obs_steps, C, H, W) } """ - assert set(batch).issuperset({"observation.state", "observation.image"}) batch_size, n_obs_steps = batch["observation.state"].shape[:2] assert n_obs_steps == self.config.n_obs_steps @@ -315,9 +329,13 @@ def __init__(self, config: DiffusionConfig): # Set up pooling and final layers. # Use a dry run to get the feature map shape. + image_keys = {k for k in config.input_shapes if k.startswith("observation.image")} + assert len(image_keys) == 1 with torch.inference_mode(): feat_map_shape = tuple( - self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:] + self.backbone( + torch.zeros(size=(1, config.input_shapes[next(iter(image_keys))][0], *config.crop_shape)) + ).shape[1:] ) self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints) self.feature_dim = config.spatial_softmax_num_keypoints * 2 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index d5fedc843..801109ecf 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -8,6 +8,7 @@ import torch from datasets import concatenate_datasets from datasets.utils import disable_progress_bars, enable_progress_bars +from omegaconf import DictConfig from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import cycle @@ -290,7 +291,7 @@ def shift_indices(episode_index, index): sampler.num_samples = len(concat_dataset) -def train(cfg: dict, out_dir=None, job_name=None): +def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): if out_dir is None: raise NotImplementedError() if job_name is None: From c3a5cbb0b6f4daf87136ac3a7d1c8eb7130bb622 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 12:58:14 +0100 Subject: [PATCH 2/9] don't delete observation.image key --- lerobot/common/policies/diffusion/modeling_diffusion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index e5d099da1..1ed9d3914 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -89,8 +89,9 @@ def _preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = Fa assert "action" in batch assert "action_is_pad" in batch image_key = next(iter(image_keys)) - batch["observation.image"] = batch[image_key] - del batch[image_key] + if image_key != "observation.image": + batch["observation.image"] = batch[image_key] + del batch[image_key] @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: From 5580b51ee26de2adcc925ffb9194e6de4c7e5416 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 18:24:34 +0100 Subject: [PATCH 3/9] backup wip --- .../common/policies/act/configuration_act.py | 13 ++++--- lerobot/common/policies/act/modeling_act.py | 35 ++++++++++--------- .../policies/diffusion/modeling_diffusion.py | 14 ++++---- 3 files changed, 32 insertions(+), 30 deletions(-) diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index a3980b14d..d1040d52a 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -130,10 +130,9 @@ def __post_init__(self): raise ValueError( f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" ) - # Check that there is only one image. - # TODO(alexander-soare): generalize this to multiple images. - if ( - sum(k.startswith("observation.images.") for k in self.input_shapes) != 1 - or "observation.images.top" not in self.input_shapes - ): - raise ValueError('For now, only "observation.images.top" is accepted for an image input.') + # There should only be one image key. + image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} + if len(image_keys) != 1: + raise ValueError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 5ff25fea2..9ef01e573 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -47,6 +47,7 @@ def __init__( if config is None: config = ACTConfig() self.config = config + self.normalize_inputs = Normalize( config.input_shapes, config.input_normalization_modes, dataset_stats ) @@ -56,8 +57,11 @@ def __init__( self.unnormalize_outputs = Unnormalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) + self.model = ACT(config) + self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + def reset(self): """This should be called whenever the environment is reset.""" if self.config.n_action_steps is not None: @@ -71,13 +75,10 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. """ - assert "observation.images.top" in batch - assert "observation.state" in batch - self.eval() batch = self.normalize_inputs(batch) - self._stack_images(batch) + self._check_and_preprocess_batch(batch) if len(self._action_queue) == 0: # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue @@ -94,7 +95,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) - self._stack_images(batch) + self._check_and_preprocess_batch(batch, train_mode=True) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( @@ -117,20 +118,20 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: return loss_dict - def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - """Stacks all the images in a batch and puts them in a new key: "observation.images". + def _check_and_preprocess_batch(self, batch: dict[str, Tensor], train_mode: bool = False): + """Check that the keys can be handled by this policy and stack all images into one tensor. - This function expects `batch` to have (at least): - { - "observation.state": (B, state_dim) batch of robot states. - "observation.images.{name}": (B, C, H, W) tensor of images. - } + This should be run after input normalization. """ - # Stack images in the order dictated by input_shapes. - batch["observation.images"] = torch.stack( - [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")], - dim=-4, - ) + assert "observation.state" in batch + image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")} + assert image_keys == set( + self.expected_image_keys + ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}." + if train_mode: + assert "action" in batch + assert "action_is_pad" in batch + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) class ACT(nn.Module): diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 1ed9d3914..d3d1370a0 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -66,6 +66,8 @@ def __init__( self.diffusion = DiffusionModel(config) + self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + def reset(self): """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { @@ -74,7 +76,7 @@ def reset(self): "action": deque(maxlen=self.config.n_action_steps), } - def _preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False): + def _check_and_preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False): """Check that the keys can be handled by this policy and standardize the image key. This should be run after input normalization. @@ -82,9 +84,9 @@ def _preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = Fa assert "observation.state" in batch # There should only be one image key. image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")} - assert ( - len(image_keys) == 1 - ), f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + assert image_keys == set( + self.expected_image_keys + ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}." if train_mode: assert "action" in batch assert "action_is_pad" in batch @@ -116,7 +118,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ batch = self.normalize_inputs(batch) - self._preprocess_batch_keys(batch) + self._check_and_preprocess_batch_keys(batch) self._queues = populate_queues(self._queues, batch) @@ -136,7 +138,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - self._preprocess_batch_keys(batch) + self._check_and_preprocess_batch_keys(batch, train_mode=True) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} From e80fc1d7eb8443ca3eb5d12fcd873b8e32c4d1d1 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 18:43:28 +0100 Subject: [PATCH 4/9] backup wip --- lerobot/common/policies/act/modeling_act.py | 38 ++++++++++--------- .../policies/diffusion/modeling_diffusion.py | 12 ++++-- .../policies/tdmpc/configuration_tdmpc.py | 12 ++++-- .../common/policies/tdmpc/modeling_tdmpc.py | 27 +++++++++++++ 4 files changed, 65 insertions(+), 24 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 9ef01e573..bdfffd115 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -67,6 +67,25 @@ def reset(self): if self.config.n_action_steps is not None: self._action_queue = deque([], maxlen=self.config.n_action_steps) + def _check_and_preprocess_batch( + self, batch: dict[str, Tensor], train_mode: bool = False + ) -> dict[str, Tensor]: + """Check that the keys can be handled by this policy and stack all images into one tensor. + + This should be run after input normalization. + """ + batch = dict(batch) # shallow copy + assert "observation.state" in batch + image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")} + assert image_keys == set( + self.expected_image_keys + ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}." + if train_mode: + assert "action" in batch + assert "action_is_pad" in batch + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + return batch + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -78,7 +97,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: self.eval() batch = self.normalize_inputs(batch) - self._check_and_preprocess_batch(batch) + batch = self._check_and_preprocess_batch(batch) if len(self._action_queue) == 0: # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue @@ -95,7 +114,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) - self._check_and_preprocess_batch(batch, train_mode=True) + batch = self._check_and_preprocess_batch(batch, train_mode=True) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( @@ -118,21 +137,6 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: return loss_dict - def _check_and_preprocess_batch(self, batch: dict[str, Tensor], train_mode: bool = False): - """Check that the keys can be handled by this policy and stack all images into one tensor. - - This should be run after input normalization. - """ - assert "observation.state" in batch - image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")} - assert image_keys == set( - self.expected_image_keys - ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}." - if train_mode: - assert "action" in batch - assert "action_is_pad" in batch - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) - class ACT(nn.Module): """Action Chunking Transformer: The underlying neural network for ACTPolicy. diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index be5c09d76..740ed1774 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -77,11 +77,14 @@ def reset(self): "action": deque(maxlen=self.config.n_action_steps), } - def _check_and_preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False): - """Check that the keys can be handled by this policy and standardize the image key. + def _check_and_preprocess_batch_keys( + self, batch: dict[str, Tensor], train_mode: bool = False + ) -> dict[str, Tensor]: + """Check that the keys can be handled by this policy and standardizes the image key. This should be run after input normalization. """ + batch = dict(batch) # shallow copy assert "observation.state" in batch # There should only be one image key. image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")} @@ -95,6 +98,7 @@ def _check_and_preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: if image_key != "observation.image": batch["observation.image"] = batch[image_key] del batch[image_key] + return batch @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: @@ -119,7 +123,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ batch = self.normalize_inputs(batch) - self._check_and_preprocess_batch_keys(batch) + batch = self._check_and_preprocess_batch_keys(batch) self._queues = populate_queues(self._queues, batch) @@ -139,7 +143,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - self._check_and_preprocess_batch_keys(batch, train_mode=True) + batch = self._check_and_preprocess_batch_keys(batch, train_mode=True) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index 00d00913d..391c6571d 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -131,12 +131,18 @@ class TDMPCConfig: def __post_init__(self): """Input validation (not exhaustive).""" - if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]: + # There should only be one image key. + image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} + if len(image_keys) != 1: + raise ValueError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) + image_key = next(iter(image_keys)) + if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]: # TODO(alexander-soare): This limitation is solely because of code in the random shift # augmentation. It should be able to be removed. raise ValueError( - "Only square images are handled now. Got image shape " - f"{self.input_shapes['observation.image']}." + f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}." ) if self.n_gaussian_samples <= 0: raise ValueError( diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 1fba43d08..e38107de0 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -96,6 +96,8 @@ def __init__( config.output_shapes, config.output_normalization_modes, dataset_stats ) + self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + def save(self, fp): """Save state dict of TOLD model to filepath.""" torch.save(self.state_dict(), fp) @@ -118,6 +120,29 @@ def reset(self): # CEM for the next step. self._prev_mean: torch.Tensor | None = None + def _check_and_preprocess_batch_keys( + self, batch: dict[str, Tensor], train_mode: bool = False + ) -> dict[str, Tensor]: + """Check that the keys can be handled by this policy and standardizes the image key. + + This should be run after input normalization. + """ + batch = dict(batch) # shallow copy + assert "observation.state" in batch + # There should only be one image key. + image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")} + assert image_keys == set( + self.expected_image_keys + ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}." + if train_mode: + assert "action" in batch + assert "action_is_pad" in batch + image_key = next(iter(image_keys)) + if image_key != "observation.image": + batch["observation.image"] = batch[image_key] + del batch[image_key] + return batch + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]): """Select a single action given environment observations.""" @@ -125,6 +150,7 @@ def select_action(self, batch: dict[str, Tensor]): assert "observation.state" in batch batch = self.normalize_inputs(batch) + batch = self._check_and_preprocess_batch_keys(batch) self._queues = populate_queues(self._queues, batch) @@ -303,6 +329,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: device = get_device_from_parameters(self) batch = self.normalize_inputs(batch) + batch = self._check_and_preprocess_batch_keys(batch, train_mode=True) batch = self.normalize_targets(batch) info = {} From 158627d07e0b2b875a938ad5d75c6771662646da Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 9 May 2024 10:06:57 +0100 Subject: [PATCH 5/9] enable multiple images for ACT --- .../common/policies/act/configuration_act.py | 6 ------ lerobot/common/policies/act/modeling_act.py | 18 +++++++++--------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index d1040d52a..a0d23b7a7 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -130,9 +130,3 @@ def __post_init__(self): raise ValueError( f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" ) - # There should only be one image key. - image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} - if len(image_keys) != 1: - raise ValueError( - f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." - ) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index bdfffd115..839634dac 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -166,10 +166,10 @@ class ACT(nn.Module): │ encoder │ │ │ │Transf.│ │ │ │ │ │ │encoder│ │ └───▲─────┘ │ │ │ │ │ - │ │ │ └───▲───┘ │ - │ │ │ │ │ - inputs └─────┼─────┘ │ - │ │ + │ │ │ └▲──▲─▲─┘ │ + │ │ │ │ │ │ │ + inputs └─────┼──┘ │ image emb. │ + │ state emb. │ └───────────────────────┘ """ @@ -311,18 +311,18 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso all_cam_features.append(cam_features) all_cam_pos_embeds.append(cam_pos_embed) # Concatenate camera observation feature maps and positional embeddings along the width dimension. - encoder_in = torch.cat(all_cam_features, axis=3) - cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3) + encoder_in = torch.cat(all_cam_features, axis=-1) + cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1) # Get positional embeddings for robot state and latent. - robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) - latent_embed = self.encoder_latent_input_proj(latent_sample) + robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C) + latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C) # Stack encoder input and positional embeddings moving to (S, B, C). encoder_in = torch.cat( [ torch.stack([latent_embed, robot_state_embed], axis=0), - encoder_in.flatten(2).permute(2, 0, 1), + einops.rearrange(encoder_in, "b c h w -> (h w) b c"), ] ) pos_embed = torch.cat( From 4cfebf1f0a6e7dc7ac2a2d2d815abd1e8fd8c362 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 9 May 2024 11:55:10 +0100 Subject: [PATCH 6/9] ready for review --- lerobot/common/policies/act/modeling_act.py | 23 +-------- .../policies/diffusion/modeling_diffusion.py | 40 +++++----------- .../common/policies/tdmpc/modeling_tdmpc.py | 47 ++++--------------- lerobot/common/policies/utils.py | 4 ++ tests/test_policies.py | 33 +++++++++++++ 5 files changed, 58 insertions(+), 89 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 839634dac..47b0cb3a7 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -67,25 +67,6 @@ def reset(self): if self.config.n_action_steps is not None: self._action_queue = deque([], maxlen=self.config.n_action_steps) - def _check_and_preprocess_batch( - self, batch: dict[str, Tensor], train_mode: bool = False - ) -> dict[str, Tensor]: - """Check that the keys can be handled by this policy and stack all images into one tensor. - - This should be run after input normalization. - """ - batch = dict(batch) # shallow copy - assert "observation.state" in batch - image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")} - assert image_keys == set( - self.expected_image_keys - ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}." - if train_mode: - assert "action" in batch - assert "action_is_pad" in batch - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) - return batch - @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -97,7 +78,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: self.eval() batch = self.normalize_inputs(batch) - batch = self._check_and_preprocess_batch(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) if len(self._action_queue) == 0: # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue @@ -113,8 +94,8 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch = self.normalize_targets(batch) - batch = self._check_and_preprocess_batch(batch, train_mode=True) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 740ed1774..d9287f397 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -67,7 +67,12 @@ def __init__( self.diffusion = DiffusionModel(config) - self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + # Note: This check is covered in the post-init of the config but have a sanity check just in case. + assert len(image_keys) == 1 + self.input_image_key = image_keys[0] + + self.reset() def reset(self): """Clear observation and action queues. Should be called on `env.reset()`""" @@ -77,29 +82,6 @@ def reset(self): "action": deque(maxlen=self.config.n_action_steps), } - def _check_and_preprocess_batch_keys( - self, batch: dict[str, Tensor], train_mode: bool = False - ) -> dict[str, Tensor]: - """Check that the keys can be handled by this policy and standardizes the image key. - - This should be run after input normalization. - """ - batch = dict(batch) # shallow copy - assert "observation.state" in batch - # There should only be one image key. - image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")} - assert image_keys == set( - self.expected_image_keys - ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}." - if train_mode: - assert "action" in batch - assert "action_is_pad" in batch - image_key = next(iter(image_keys)) - if image_key != "observation.image": - batch["observation.image"] = batch[image_key] - del batch[image_key] - return batch - @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -123,13 +105,13 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ batch = self.normalize_inputs(batch) - batch = self._check_and_preprocess_batch_keys(batch) + batch["observation.image"] = batch[self.input_image_key] self._queues = populate_queues(self._queues, batch) if len(self._queues["action"]) == 0: # stack n latest observations from the queue - batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} actions = self.diffusion.generate_actions(batch) # TODO(rcadene): make above methods return output dictionary? @@ -143,7 +125,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - batch = self._check_and_preprocess_batch_keys(batch, train_mode=True) + batch["observation.image"] = batch[self.input_image_key] batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} @@ -352,9 +334,9 @@ def __init__(self, config: DiffusionConfig): # Use a dry run to get the feature map shape. # The dummy input should take the number of image channels from `config.input_shapes` and it should # use the height and width from `config.crop_shape`. - image_keys = {k for k in config.input_shapes if k.startswith("observation.image")} + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] assert len(image_keys) == 1 - image_key = next(iter(image_keys)) + image_key = image_keys[0] dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape)) with torch.inference_mode(): dummy_feature_map = self.backbone(dummy_input) diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index e38107de0..76336885c 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -96,15 +96,12 @@ def __init__( config.output_shapes, config.output_normalization_modes, dataset_stats ) - self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + # Note: This check is covered in the post-init of the config but have a sanity check just in case. + assert len(image_keys) == 1 + self.input_image_key = image_keys[0] - def save(self, fp): - """Save state dict of TOLD model to filepath.""" - torch.save(self.state_dict(), fp) - - def load(self, fp): - """Load a saved state dict from filepath into current agent.""" - self.load_state_dict(torch.load(fp)) + self.reset() def reset(self): """ @@ -120,37 +117,11 @@ def reset(self): # CEM for the next step. self._prev_mean: torch.Tensor | None = None - def _check_and_preprocess_batch_keys( - self, batch: dict[str, Tensor], train_mode: bool = False - ) -> dict[str, Tensor]: - """Check that the keys can be handled by this policy and standardizes the image key. - - This should be run after input normalization. - """ - batch = dict(batch) # shallow copy - assert "observation.state" in batch - # There should only be one image key. - image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")} - assert image_keys == set( - self.expected_image_keys - ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}." - if train_mode: - assert "action" in batch - assert "action_is_pad" in batch - image_key = next(iter(image_keys)) - if image_key != "observation.image": - batch["observation.image"] = batch[image_key] - del batch[image_key] - return batch - @torch.no_grad() def select_action(self, batch: dict[str, Tensor]): """Select a single action given environment observations.""" - assert "observation.image" in batch - assert "observation.state" in batch - batch = self.normalize_inputs(batch) - batch = self._check_and_preprocess_batch_keys(batch) + batch["observation.image"] = batch[self.input_image_key] self._queues = populate_queues(self._queues, batch) @@ -329,14 +300,11 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: device = get_device_from_parameters(self) batch = self.normalize_inputs(batch) - batch = self._check_and_preprocess_batch_keys(batch, train_mode=True) + batch["observation.image"] = batch[self.input_image_key] batch = self.normalize_targets(batch) info = {} - # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation. - batch_size = batch["index"].shape[0] - # (b, t) -> (t, b) for key in batch: if batch[key].ndim > 1: @@ -364,6 +332,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # Run latent rollout using the latent dynamics model and policy model. # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action # gives us a next `z`. + batch_size = batch["index"].shape[0] z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device) z_preds[0] = self.model.encode(current_observation) reward_preds = torch.empty_like(reward, device=device) diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index b23c13366..be300a390 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -4,6 +4,10 @@ def populate_queues(queues, batch): for key in batch: + # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the + # queues have the keys they want). + if key not in queues: + continue if len(queues[key]) != queues[key].maxlen: # initialize by copying the first observation several times until the queue is full while len(queues[key]) != queues[key].maxlen: diff --git a/tests/test_policies.py b/tests/test_policies.py index 12beec92c..a5862932f 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -49,6 +49,14 @@ def test_get_policy_and_config_classes(policy_name: str): "act", ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"], ), + # Note: these parameters also need custom logic in the test function for overriding the Hydra config. + ( + "aloha", + "diffusion", + ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"], + ), + # Note: these parameters also need custom logic in the test function for overriding the Hydra config. + ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]), ], ) @require_env @@ -72,6 +80,31 @@ def test_policy(env_name, policy_name, extra_overrides): + extra_overrides, ) + # Additional config override logic. + if env_name == "aloha" and policy_name == "diffusion": + for keys in [ + ("training", "delta_timestamps"), + ("policy", "input_shapes"), + ("policy", "input_normalization_modes"), + ]: + dct = dict(cfg[keys[0]][keys[1]]) + dct["observation.images.top"] = dct["observation.image"] + del dct["observation.image"] + cfg[keys[0]][keys[1]] = dct + cfg.override_dataset_stats = None + + # Additional config override logic. + if env_name == "pusht" and policy_name == "act": + for keys in [ + ("policy", "input_shapes"), + ("policy", "input_normalization_modes"), + ]: + dct = dict(cfg[keys[0]][keys[1]]) + dct["observation.image"] = dct["observation.images.top"] + del dct["observation.images.top"] + cfg[keys[0]][keys[1]] = dct + cfg.training.override_dataset_stats = None + # Check that we can make the policy object. dataset = make_dataset(cfg) policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) From 2ea8ad4dafb7da4c032a848bd639578f17ad98c7 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 9 May 2024 11:57:52 +0100 Subject: [PATCH 7/9] call self.reset() in ACT --- lerobot/common/policies/act/modeling_act.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 47b0cb3a7..942b5844e 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -62,6 +62,8 @@ def __init__( self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + self.reset() + def reset(self): """This should be called whenever the environment is reset.""" if self.config.n_action_steps is not None: From 43e1c620b7a1615b825bd047122e05db93de5b6b Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 10 May 2024 07:25:06 +0100 Subject: [PATCH 8/9] fix test --- tests/test_policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_policies.py b/tests/test_policies.py index ac44e1aaf..900ef42bb 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -103,7 +103,7 @@ def test_policy(env_name, policy_name, extra_overrides): dct["observation.image"] = dct["observation.images.top"] del dct["observation.images.top"] cfg[keys[0]][keys[1]] = dct - cfg.training.override_dataset_stats = None + cfg.override_dataset_stats = None # Check that we can make the policy object. dataset = make_dataset(cfg) From 461c48a66ce3525db789865b8e72689ffbe624e2 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 16 May 2024 13:47:58 +0100 Subject: [PATCH 9/9] revision --- lerobot/common/policies/diffusion/modeling_diffusion.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 2a2a65547..1659b68eb 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -19,6 +19,7 @@ TODO(alexander-soare): - Remove reliance on Robomimic for SpatialSoftmax. - Remove reliance on diffusers for DDPMScheduler and LR scheduler. + - Make compatible with multiple image keys. """ import math @@ -85,7 +86,10 @@ def __init__( image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] # Note: This check is covered in the post-init of the config but have a sanity check just in case. - assert len(image_keys) == 1 + if len(image_keys) != 1: + raise NotImplementedError( + f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." + ) self.input_image_key = image_keys[0] self.reset()