Skip to content

Commit

Permalink
[air/release] Fix batch format in dreambooth example (#37102) (#37189)
Browse files Browse the repository at this point in the history
This fixes an error caused by the default batch format of Ray Data changing to numpy. We need to manually specify pandas.

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
  • Loading branch information
justinvyu authored Jul 7, 2023
1 parent a4bc5b2 commit 47ec25b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
2 changes: 1 addition & 1 deletion python/ray/air/examples/dreambooth/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_train_dataset(args, image_resolution=512):
# prior preserving loss in one pass.
dup_times = class_dataset.count() // instance_dataset.count()
instance_dataset = instance_dataset.map_batches(
lambda df: pd.concat([df] * dup_times)
lambda df: pd.concat([df] * dup_times), batch_format="pandas"
)

# Load tokenizer for tokenizing the image prompts.
Expand Down
17 changes: 9 additions & 8 deletions python/ray/air/examples/dreambooth/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
accelerate==0.15.0
bitsandbytes
diffusers==0.11.1
flax==0.6.4
huggingface_hub
numpy==1.21
torchvision
transformers>=4.25.1
accelerate==0.20.3
bitsandbytes==0.39.1
diffusers==0.17.1
flax==0.6.11
huggingface_hub==0.16.2
numpy==1.24.4
torch==2.0.1
torchvision==0.15.2
transformers==4.30.2
6 changes: 5 additions & 1 deletion release/air_examples/dreambooth/dreambooth_env.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
base_image: {{ env["RAY_IMAGE_ML_NIGHTLY_GPU"] | default("anyscale/ray-ml:nightly-py37-gpu") }}
# NOTE:
# - This test runs with py38 (see the entry in release_tests.yaml)
# - This test installs dependencies on top of a base ray image
# instead of using the default ray-ml image. See dreambooth/requirements.txt.
base_image: "anyscale/ray:nightly-py38-cu118"
env_vars: {}
debian_packages:
- curl
Expand Down

0 comments on commit 47ec25b

Please sign in to comment.