diff --git a/examples/stable_diffusion_v2/tools/model_conversion/convert_weights.py b/examples/stable_diffusion_v2/tools/model_conversion/convert_weights.py index a189f9d961..f8bd6833d8 100644 --- a/examples/stable_diffusion_v2/tools/model_conversion/convert_weights.py +++ b/examples/stable_diffusion_v2/tools/model_conversion/convert_weights.py @@ -52,7 +52,7 @@ def _load_torch_ckpt(ckpt_file): source_data = torch.load(ckpt_file, map_location="cpu") - if ["state_dict"] in source_data: + if "state_dict" in source_data: source_data = source_data["state_dict"] return source_data