From 5f5f0c25c5e9f81a4135392b9d51aeda3116d3d1 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 14 Apr 2023 15:18:16 +0200 Subject: [PATCH 1/2] [requirements] Add PyArrow to ray[tune] dependencies Signed-off-by: Kai Fricke --- python/ray/air/examples/dreambooth/dataset.py | 2 +- python/setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/air/examples/dreambooth/dataset.py b/python/ray/air/examples/dreambooth/dataset.py index c9f078b42bcea..c9ab28b3c3791 100644 --- a/python/ray/air/examples/dreambooth/dataset.py +++ b/python/ray/air/examples/dreambooth/dataset.py @@ -81,7 +81,7 @@ def collate(batch, device, dtype): # of the batch. # During training, a batch will be chunked into 2 sub-batches for prior # preserving loss calculation. - images = torch.squeeze(torch.stack([batch["image"], batch["image_1"]])) + images = torch.cat([batch["image"], batch["image_1"]], dim=1) images = images.to(memory_format=torch.contiguous_format).float() prompt_ids = torch.cat([batch["prompt_ids"], batch["prompt_ids_1"]], dim=0) diff --git a/python/setup.py b/python/setup.py index 1cf991e7f8842..b8d7c45a74399 100644 --- a/python/setup.py +++ b/python/setup.py @@ -261,7 +261,7 @@ def get_packages(self): "smart_open", ], "serve": ["uvicorn", "requests", "starlette", "fastapi", "aiorwlock"], - "tune": ["pandas", "tabulate", "tensorboardX>=1.9", "requests"], + "tune": ["pandas", "tabulate", "tensorboardX>=1.9", "requests", pyarrow_dep], "k8s": ["kubernetes", "urllib3"], "observability": [ "opentelemetry-api", From c26b67c44c8ec1366b6bd704c6b7044c7bd4c1af Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 14 Apr 2023 15:18:47 +0200 Subject: [PATCH 2/2] Revert wip Signed-off-by: Kai Fricke --- python/ray/air/examples/dreambooth/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/air/examples/dreambooth/dataset.py b/python/ray/air/examples/dreambooth/dataset.py index c9ab28b3c3791..c9f078b42bcea 100644 --- a/python/ray/air/examples/dreambooth/dataset.py +++ b/python/ray/air/examples/dreambooth/dataset.py @@ -81,7 +81,7 @@ def collate(batch, device, dtype): # of the batch. # During training, a batch will be chunked into 2 sub-batches for prior # preserving loss calculation. - images = torch.cat([batch["image"], batch["image_1"]], dim=1) + images = torch.squeeze(torch.stack([batch["image"], batch["image_1"]])) images = images.to(memory_format=torch.contiguous_format).float() prompt_ids = torch.cat([batch["prompt_ids"], batch["prompt_ids_1"]], dim=0)