Skip to content

Commit

Permalink
hotswap with bfl code
Browse files Browse the repository at this point in the history
  • Loading branch information
daanelson committed Feb 4, 2025
1 parent 729489b commit d67a42f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
1 change: 0 additions & 1 deletion cog.yaml.template
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ build:
- "tokenizers==0.19.1"
- "protobuf==5.27.2"
- "diffusers==0.32.2"
- "peft==0.14.0"
- "loguru==0.7.2"
- "pybase64==1.4.0"
- "pydash==8.0.3"
Expand Down
42 changes: 21 additions & 21 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,37 +715,34 @@ class HotswapPredictor(Predictor):
def setup(self) -> None:
self.base_setup()
shared_cache = WeightsDownloadCache()

self.bf16_dev = DiffusersFlux(FLUX_DEV, shared_cache)
shared_models = self.bf16_dev.get_models()

# hack to get around delta in vae code
bfl_ae = load_ae(FLUX_DEV)

shared_models_for_fp8 = LoadedModels(
ae=bfl_ae,
clip=PreLoadedHFEmbedder(True, 77, shared_models.tokenizer, shared_models.text_encoder),
t5=PreLoadedHFEmbedder(False, 512, shared_models.tokenizer_2, shared_models.text_encoder_2),
flow=None,
config=None
self.bf16_dev= BflBf16Predictor(
FLUX_DEV,
offload=self.should_offload(),
weights_download_cache=shared_cache,
restore_lora_from_cloned_weights=True,
)
self.fp8_dev = BflFp8Flux(
FLUX_DEV_FP8,
shared_models_for_fp8,
torch_compile=True,
loaded_models=self.bf16_dev.get_shared_models(),
torch_compile=False,
compilation_aspect_ratios=ASPECT_RATIOS,
offload=self.should_offload(),
weights_download_cache=shared_cache,
restore_lora_from_cloned_weights=True,
)

self.bf16_schnell = DiffusersFlux(FLUX_SCHNELL, shared_cache, shared_models)
shared_models_for_fp8.t5=PreLoadedHFEmbedder(False, 256, shared_models.tokenizer_2, shared_models.text_encoder_2)

self.bf16_schnell = BflBf16Predictor(
FLUX_SCHNELL,
loaded_models=self.bf16_dev.get_shared_models(),
offload=self.should_offload(),
weights_download_cache=shared_cache,
restore_lora_from_cloned_weights=True
)
self.fp8_schnell = BflFp8Flux(
FLUX_SCHNELL_FP8,
shared_models_for_fp8,
torch_compile=True,
loaded_models=self.bf16_dev.get_shared_models(),
torch_compile=False,
compilation_aspect_ratios=ASPECT_RATIOS,
offload=self.should_offload(),
weights_download_cache=shared_cache,
restore_lora_from_cloned_weights=True,
)
Expand Down Expand Up @@ -862,6 +859,9 @@ def predict(
height=height,
)

# unload loras s.t. everything fits in memory
model.handle_loras(None, 1.0, None, 1.0)

return self.postprocess(
imgs,
disable_safety_checker,
Expand Down

0 comments on commit d67a42f

Please sign in to comment.