Make policies compatible with other/multiple image keys #149

Merged · 12 commits · May 16, 2024
17 changes: 12 additions & 5 deletions lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -130,14 +130,21 @@ def __post_init__(self):
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
+        # There should only be one image key.
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
         if (
-            self.crop_shape[0] > self.input_shapes["observation.image"][1]
-            or self.crop_shape[1] > self.input_shapes["observation.image"][2]
+            self.crop_shape[0] > self.input_shapes[image_key][1]
+            or self.crop_shape[1] > self.input_shapes[image_key][2]
         ):
             raise ValueError(
-                f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
-                f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
-                '`input_shapes["observation.image"]`.'
+                f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+                f"for `crop_shape` and {self.input_shapes[image_key]} for "
+                f"`input_shapes[{image_key}]`."
             )
         supported_prediction_types = ["epsilon", "sample"]
         if self.prediction_type not in supported_prediction_types:
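For reference, a small standalone sketch of the check the new `__post_init__` code performs (not part of the diff; the image key name below is hypothetical):

input_shapes = {
    "observation.images.top": [3, 96, 96],  # [C, H, W]; made-up key for illustration
    "observation.state": [2],
}
crop_shape = (84, 84)

image_keys = {k for k in input_shapes if k.startswith("observation.image")}
assert len(image_keys) == 1, f"only one image key is supported for now, got {image_keys}"
image_key = next(iter(image_keys))  # -> "observation.images.top"

# The crop must fit within the (H, W) of whichever image key is present.
assert crop_shape[0] <= input_shapes[image_key][1]
assert crop_shape[1] <= input_shapes[image_key][2]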
37 changes: 28 additions & 9 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -67,15 +67,32 @@ def __init__(
         self.diffusion = DiffusionModel(config)
 
     def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        """
+        """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
             "observation.image": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
             "action": deque(maxlen=self.config.n_action_steps),
         }
 
+    def _preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False):
+        """Check that the keys can be handled by this policy and standardize the image key.
+
+        This should be run after input normalization.
+        """
+        assert "observation.state" in batch
+        # There should only be one image key.
+        image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
+        assert (
+            len(image_keys) == 1
+        ), f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+        if train_mode:
+            assert "action" in batch
+            assert "action_is_pad" in batch
+        image_key = next(iter(image_keys))
+        if image_key != "observation.image":
+            batch["observation.image"] = batch[image_key]
+            del batch[image_key]

     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -98,10 +115,8 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
assert "observation.image" in batch
assert "observation.state" in batch

batch = self.normalize_inputs(batch)
self._preprocess_batch_keys(batch)

self._queues = populate_queues(self._queues, batch)
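For context, a standalone sketch of the remapping `_preprocess_batch_keys` performs when a batch arrives with a non-default image key (not part of the diff; the key name and tensor shapes are illustrative):

import torch

batch = {
    "observation.state": torch.zeros(1, 2, 2),
    "observation.images.top": torch.zeros(1, 2, 3, 96, 96),  # hypothetical image key
}
image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
assert len(image_keys) == 1
image_key = next(iter(image_keys))
if image_key != "observation.image":
    # Standardize the key so the rest of the policy can keep indexing "observation.image".
    batch["observation.image"] = batch.pop(image_key)

print(sorted(batch))  # ['observation.image', 'observation.state']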

@@ -121,6 +136,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        self._preprocess_batch_keys(batch)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
@@ -185,13 +201,12 @@ def conditional_sample(

     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
-        This function expects `batch` to have (at least):
+        This function expects `batch` to have:
         {
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.image": (B, n_obs_steps, C, H, W)
         }
         """
-        assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps
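As a reminder of the shape convention in the docstring above, an illustrative batch (made-up sizes) that satisfies it, with every entry stacked along an `n_obs_steps` dimension:

import torch

B, n_obs_steps, state_dim = 8, 2, 2
batch = {
    "observation.state": torch.zeros(B, n_obs_steps, state_dim),
    "observation.image": torch.zeros(B, n_obs_steps, 3, 84, 84),  # (B, n_obs_steps, C, H, W)
}
assert batch["observation.state"].shape[:2] == (B, n_obs_steps)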

@@ -315,9 +330,13 @@ def __init__(self, config: DiffusionConfig):

         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
+        image_keys = {k for k in config.input_shapes if k.startswith("observation.image")}
+        assert len(image_keys) == 1
         with torch.inference_mode():
             feat_map_shape = tuple(
-                self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
+                self.backbone(
+                    torch.zeros(size=(1, config.input_shapes[next(iter(image_keys))][0], *config.crop_shape))
+                ).shape[1:]
             )
         self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
         self.feature_dim = config.spatial_softmax_num_keypoints * 2
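The dry run above now probes the backbone with a zero tensor shaped like one cropped image (channels from the detected image key, spatial size from `crop_shape`) instead of the full input image. A standalone sketch of the same idea, using torchvision's resnet18 as a stand-in backbone (the numbers are illustrative):

import torch
import torchvision

# Stand-in for `self.backbone`: a resnet18 trunk with the avgpool/fc head removed.
backbone = torch.nn.Sequential(*list(torchvision.models.resnet18(weights=None).children())[:-2])
channels, crop_shape = 3, (84, 84)  # e.g. config.input_shapes[image_key][0] and config.crop_shape

with torch.inference_mode():
    feat_map_shape = tuple(backbone(torch.zeros(1, channels, *crop_shape)).shape[1:])

print(feat_map_shape)  # (512, 3, 3) for resnet18 with 84x84 crops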
3 changes: 2 additions & 1 deletion lerobot/scripts/train.py
@@ -8,6 +8,7 @@
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from omegaconf import DictConfig
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -290,7 +291,7 @@ def shift_indices(episode_index, index):
     sampler.num_samples = len(concat_dataset)
 
 
-def train(cfg: dict, out_dir=None, job_name=None):
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
         raise NotImplementedError()
     if job_name is None:
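With the stricter annotation, `train` now advertises that it expects an omegaconf `DictConfig` (such as one produced by Hydra) rather than a plain dict. A hedged usage sketch; the config keys below are illustrative, not the project's actual schema:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"seed": 1, "policy": {"name": "diffusion"}})  # made-up keys
# train(cfg, out_dir="outputs/example", job_name="example")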