Skip to content

Commit

Permalink
Align PT and Flax API - allow loading checkpoint from PyTorch configs (
Browse files Browse the repository at this point in the history
…#827)

* up

* finish

* add more tests

* up

* up

* finish
  • Loading branch information
patrickvonplaten authored Oct 13, 2022
1 parent 78db11d commit 7c22626
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 31 deletions.
55 changes: 32 additions & 23 deletions src/diffusers/pipeline_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,27 @@ def register_modules(self, **kwargs):
from diffusers import pipelines

for name, module in kwargs.items():
# retrieve library
library = module.__module__.split(".")[0]
if module is None:
register_dict = {name: (None, None)}
else:
# retrieve library
library = module.__module__.split(".")[0]

# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir

# retrieve class_name
class_name = module.__class__.__name__
# retrieve class_name
class_name = module.__class__.__name__

register_dict = {name: (library, class_name)}
register_dict = {name: (library, class_name)}

# save model index config
self.register_to_config(**register_dict)
Expand Down Expand Up @@ -320,6 +323,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
class_name = (
config_dict["_class_name"]
if config_dict["_class_name"].startswith("Flax")
else "Flax" + config_dict["_class_name"]
)
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

# some modules can be passed directly to the init
Expand All @@ -342,6 +350,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True

# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
Expand All @@ -362,6 +371,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
Expand All @@ -372,25 +387,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
if from_pt:
class_obj = import_flax_or_no_model(pipeline_module, class_name)
else:
class_obj = getattr(pipeline_module, class_name)
class_obj = import_flax_or_no_model(pipeline_module, class_name)

importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
if from_pt:
class_obj = import_flax_or_no_model(library, class_name)
else:
class_obj = getattr(library, class_name)
class_obj = import_flax_or_no_model(library, class_name)

importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

if loaded_sub_model is None:
if loaded_sub_model is None and sub_model_should_be_defined:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
from ...utils import logging
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
Expand Down Expand Up @@ -60,6 +64,16 @@ def __init__(
super().__init__()
self.dtype = dtype

if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand Down Expand Up @@ -265,10 +279,23 @@ def __call__(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)

safety_params = params["safety_checker"]
images = (images * 255).round().astype("uint8")
images = np.asarray(images).reshape(-1, height, width, 3)
images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit)
if self.safety_checker is not None:
safety_params = params["safety_checker"]
images_uint8_casted = (images * 255).round().astype("uint8")
num_devices, batch_size = images.shape[:2]

images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
images = np.asarray(images)

# block images
if any(has_nsfw_concept):
for i, is_nsfw in enumerate(has_nsfw_concept):
images[i] = np.asarray(images_uint8_casted[i])

images = images.reshape(num_devices, batch_size, height, width, 3)
else:
has_nsfw_concept = False

if not return_dict:
return (images, has_nsfw_concept)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(

if safety_checker is None:
logger.warn(
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(

if safety_checker is None:
logger.warn(
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(

if safety_checker is None:
logger.warn(
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
Expand Down
100 changes: 99 additions & 1 deletion tests/test_pipelines_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

if is_flax_available():
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
Expand All @@ -34,7 +35,7 @@
class FlaxPipelineTests(unittest.TestCase):
def test_dummy_all_tpus(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe"
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)

prompt = (
Expand All @@ -57,6 +58,103 @@ def test_dummy_all_tpus(self):
prompt_ids = shard(prompt_ids)

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

assert images.shape == (8, 1, 64, 64, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.151474)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 1e-2

images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

assert len(images_pil) == 8

def test_stable_diffusion_v1_4(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
)

prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
for i, image in enumerate(images_pil):
image.save(f"/home/patrick/images/flax-test-{i}_fp32.png")

assert images.shape == (8, 1, 512, 512, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 1e-2

def test_stable_diffusion_v1_4_bfloat_16(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None
)

prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

assert images.shape == (8, 1, 512, 512, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2

def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
)

prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images

assert images.shape == (8, 1, 512, 512, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2

0 comments on commit 7c22626

Please sign in to comment.