From 5c439888626145f94db1fdb00f5787ad27b64602 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 10 Sep 2022 10:02:43 -0400
Subject: [PATCH] reduce VRAM memory usage by half during model loading

* This moves the call to half() before model.to(device) to avoid GPU copy
  of full model. Improves speed and reduces memory usage dramatically

* This fix contributed by @mh-dm (Mihai)
---
 ldm/generate.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index 27f89bb4d6e..05e2c6a4403 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -536,9 +536,6 @@ def _load_model_from_config(self, config, ckpt):
         sd = pl_sd['state_dict']
         model = instantiate_from_config(config.model)
         m, u = model.load_state_dict(sd, strict=False)
-        model.to(self.device)
-        model.eval()
-
         if self.full_precision:
             print(
                 '>> Using slower but more accurate full-precision math (--full_precision)'
@@ -549,6 +546,8 @@ def _load_model_from_config(self, config, ckpt):
                 '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
             )
             model.half()
+        model.to(self.device)
+        model.eval()
 
         # usage statistics
         toc = time.time()