diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml
index f3d47595c..d46efef09 100644
--- a/.github/workflows/ci-sharktank.yml
+++ b/.github/workflows/ci-sharktank.yml
@@ -136,15 +136,19 @@ jobs:
           pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/

       - name: Run tests
-        # TODO: unify with-t5-data and with-clip-data flags into a single flag
-        # and make it possible to run only tests that require data.
+        # TODO: unify with-*-data flags into a single flag and make it possible to run
+        # only tests that require data.
+        # We would still want the separate flags, as we may end up with data being
+        # scattered across different CI machines.
         run: |
           source ${VENV_DIR}/bin/activate
           pytest \
-            --with-clip-data \
+            --with-clip-data \
+            --with-flux-data \
             --with-t5-data \
             sharktank/tests/models/clip/clip_test.py \
             sharktank/tests/models/t5/t5_test.py \
+            sharktank/tests/models/flux/flux_test.py \
             --durations=0
diff --git a/sharktank/conftest.py b/sharktank/conftest.py
index 9d6257513..d7118893a 100644
--- a/sharktank/conftest.py
+++ b/sharktank/conftest.py
@@ -97,6 +97,15 @@ def pytest_addoption(parser):
             "code. The user is expected to provide the data"
         ),
     )
+    parser.addoption(
+        "--with-flux-data",
+        action="store_true",
+        default=False,
+        help=(
+            "Enable tests that use Flux data, such as model parameters, that are not "
+            "part of the source code. The user is expected to provide the data"
+        ),
+    )
     parser.addoption(
         "--with-t5-data",
         action="store_true",
diff --git a/sharktank/integration/models/punet/integration_test.py b/sharktank/integration/models/punet/integration_test.py
index 754a54311..2ebb9e155 100644
--- a/sharktank/integration/models/punet/integration_test.py
+++ b/sharktank/integration/models/punet/integration_test.py
@@ -67,7 +67,7 @@ def download(filename):

 @pytest.fixture(scope="module")
 def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir):
-    from sharktank.models.punet.tools import import_hf_dataset
+    from sharktank.tools import import_hf_dataset

     dataset = temp_dir / "sdxl_fp16_dataset.irpa"
     import_hf_dataset.main(
diff --git a/sharktank/sharktank/export.py b/sharktank/sharktank/export.py
index 0a1c6940d..b54978e8b 100644
--- a/sharktank/sharktank/export.py
+++ b/sharktank/sharktank/export.py
@@ -4,11 +4,14 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

-from typing import Callable, Any
+from typing import Callable, Optional, Any
 import torch
+from os import PathLike
+import iree.turbine.aot as aot
 from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder
 from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten
 from .types.tensors import ShardedTensor
+from .layers import BaseLayer
 from torch.utils._pytree import PyTree, _is_leaf
 import functools

@@ -172,3 +175,51 @@ def flat_fn(*args, **kwargs):
     )

     assert False, "TODO: implement the case when not using an FxProgramsBuilder"
+
+
+def export_static_model_mlir(
+    model: BaseLayer,
+    output_path: PathLike,
+    function_batch_size_pairs: Optional[dict[Optional[str], list[int]]] = None,
+    batch_sizes: Optional[list[int]] = None,
+):
+    """Export a model with no dynamic dimensions.
+
+    For each provided (function name, batch sizes) pair, the resulting MLIR will
+    contain function names of the form
+    ```
+    <function name>_bs<batch size>
+    ```
+
+    If only `batch_sizes` is given, a single function named "forward" is exported
+    for each batch size.
+
+    The model is required to implement the method `sample_inputs`.
+ """ + + assert not (function_batch_size_pairs is not None and batch_sizes is not None) + + if batch_sizes is not None: + function_batch_size_pairs = {None: batch_sizes} + + if function_batch_size_pairs is None and batch_sizes is None: + function_batch_size_pairs = {None: batch_sizes} + + fxb = FxProgramsBuilder(model) + + for function, batch_sizes in function_batch_size_pairs.items(): + for batch_size in batch_sizes: + args, kwargs = model.sample_inputs(batch_size, function) + + @fxb.export_program( + name=f"{function or 'forward'}_bs{batch_size}", + args=args, + kwargs=kwargs, + dynamic_shapes=None, + strict=False, + ) + def _(model, **kwargs): + return model(**kwargs) + + output = aot.export(fxb) + output.save_mlir(output_path) diff --git a/sharktank/sharktank/layers/base.py b/sharktank/sharktank/layers/base.py index 11a21f885..8f74c239d 100644 --- a/sharktank/sharktank/layers/base.py +++ b/sharktank/sharktank/layers/base.py @@ -4,15 +4,12 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict - +from typing import Dict, Optional +from collections import OrderedDict import torch import torch.nn as nn -from ..types import ( - InferenceTensor, - Theta, -) +from ..types import InferenceTensor, Theta, AnyTensor from ..utils import debugging __all__ = [ @@ -56,6 +53,21 @@ def assert_not_nan(self, *ts: torch.Tensor): if torch.isnan(t).any(): raise AssertionError(f"Tensor contains nans! {t}") + def sample_inputs( + self, batch_size: int = 1, function: Optional[str] = None + ) -> tuple[tuple[AnyTensor], OrderedDict[str, AnyTensor]]: + """Return sample inputs that can be used to run the function from the model. + If function is None then layer is treated as the callable. + E.g. + ``` + args, kwargs = model.sample_inputs() + model(*args, **kwargs) + ``` + + One purpose of this method is to standardize exportation of models to MLIR. + """ + raise NotImplementedError() + class ThetaLayer(BaseLayer): "Base class for layers that derive parameters from a Theta object." 
diff --git a/sharktank/sharktank/layers/mmdit.py b/sharktank/sharktank/layers/mmdit.py index 1557883ae..1c398f608 100644 --- a/sharktank/sharktank/layers/mmdit.py +++ b/sharktank/sharktank/layers/mmdit.py @@ -55,11 +55,15 @@ def __init__(self, theta, num_heads: int): self.add_module("img_attn_qkv", LinearLayer(theta("img_attn.qkv"))) self.add_module( "img_attn_norm_q", - RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6), + RMSNormLayer( + theta("img_attn.norm.query_norm"), weight_name="scale", epsilon=1e-6 + ), ) self.add_module( "img_attn_norm_k", - RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6), + RMSNormLayer( + theta("img_attn.norm.key_norm"), weight_name="scale", epsilon=1e-6 + ), ) self.add_module("img_attn_proj", LinearLayer(theta("img_attn.proj"))) @@ -70,11 +74,15 @@ def __init__(self, theta, num_heads: int): self.add_module("txt_attn_qkv", LinearLayer(theta("txt_attn.qkv"))) self.add_module( "txt_attn_norm_q", - RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6), + RMSNormLayer( + theta("txt_attn.norm.query_norm"), weight_name="scale", epsilon=1e-6 + ), ) self.add_module( "txt_attn_norm_k", - RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6), + RMSNormLayer( + theta("txt_attn.norm.key_norm"), weight_name="scale", epsilon=1e-6 + ), ) self.add_module("txt_attn_proj", LinearLayer(theta("txt_attn.proj"))) @@ -151,14 +159,15 @@ def __init__(self, theta, num_heads: int): super().__init__(theta) self.num_heads = num_heads - self.add_module("mod", ModulationLayer(theta("mod"), double=False)) + self.add_module("mod", ModulationLayer(theta("modulation"), double=False)) self.add_module( - "attn_norm_q", RMSNormLayer(theta("attn.norm.query_norm"), epsilon=1e-6) + "attn_norm_q", + RMSNormLayer(theta("norm.query_norm"), weight_name="scale", epsilon=1e-6), ) self.add_module( - "attn_norm_k", RMSNormLayer(theta("attn.norm.key_norm"), epsilon=1e-6) + "attn_norm_k", + RMSNormLayer(theta("norm.key_norm"), weight_name="scale", epsilon=1e-6), ) - self.add_module("attn_proj", LinearLayer(theta("attn.proj"))) self.add_module("linear1", LinearLayer(theta("linear1"))) self.add_module("linear2", LinearLayer(theta("linear2"))) diff --git a/sharktank/sharktank/layers/testing.py b/sharktank/sharktank/layers/testing.py index 74ba49624..6ea089bb7 100644 --- a/sharktank/sharktank/layers/testing.py +++ b/sharktank/sharktank/layers/testing.py @@ -65,10 +65,10 @@ def make_mmdit_double_block_random_theta( mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size) return Theta( { - "img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "img_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), - "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "img_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), "img_attn.proj.bias": DefaultPrimitiveTensor( @@ -101,10 +101,10 @@ def make_mmdit_double_block_random_theta( "img_mod.lin.weight": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) ), - "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), - "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), "txt_attn.proj.bias": DefaultPrimitiveTensor( @@ -155,10 +155,10 @@ def 
make_mmdit_single_block_random_theta( mlp_hidden_size3 = int((2 * mlp_ratio - 1) * hidden_size) return Theta( { - "attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), - "attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), "attn.proj.bias": DefaultPrimitiveTensor( @@ -179,10 +179,10 @@ def make_mmdit_single_block_random_theta( "linear2.weight": DefaultPrimitiveTensor( data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) ), - "mod.lin.bias": DefaultPrimitiveTensor( + "modulation.lin.bias": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size,), dtype=dtype) ), - "mod.lin.weight": DefaultPrimitiveTensor( + "modulation.lin.weight": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) ), } diff --git a/sharktank/sharktank/models/flux/export.py b/sharktank/sharktank/models/flux/export.py new file mode 100644 index 000000000..fae3a5362 --- /dev/null +++ b/sharktank/sharktank/models/flux/export.py @@ -0,0 +1,49 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from os import PathLike + +from ...export import export_static_model_mlir +from ...tools.import_hf_dataset import import_hf_dataset +from .flux import FluxModelV1, FluxParams +from ...types import Dataset +from ...utils.hf_datasets import get_dataset + +flux_transformer_default_batch_sizes = [4] + + +def export_flux_transformer_model_mlir( + model: FluxModelV1, + output_path: PathLike, + batch_sizes: list[int] = flux_transformer_default_batch_sizes, +): + export_static_model_mlir(model, output_path=output_path, batch_sizes=batch_sizes) + + +def export_flux_transformer_from_hugging_face( + repo_id: str, + mlir_output_path: PathLike, + parameters_output_path: PathLike, + batch_sizes: list[int] = flux_transformer_default_batch_sizes, +): + hf_dataset = get_dataset( + repo_id, + ).download() + + import_hf_dataset( + config_json_path=hf_dataset["config"][0], + param_paths=hf_dataset["parameters"], + output_irpa_file=parameters_output_path, + ) + + dataset = Dataset.load(parameters_output_path) + model = FluxModelV1( + theta=dataset.root_theta, + params=FluxParams.from_hugging_face_properties(dataset.properties), + ) + export_flux_transformer_model_mlir( + model, output_path=mlir_output_path, batch_sizes=batch_sizes + ) diff --git a/sharktank/sharktank/models/flux/flux.py b/sharktank/sharktank/models/flux/flux.py index ac63f47a0..d99b14ad4 100644 --- a/sharktank/sharktank/models/flux/flux.py +++ b/sharktank/sharktank/models/flux/flux.py @@ -9,6 +9,8 @@ https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py """ +from typing import Any, Optional +from collections import OrderedDict import math from dataclasses import dataclass import torch @@ -45,6 +47,46 @@ class FluxParams: qkv_bias: bool guidance_embed: bool + @staticmethod + def from_hugging_face_properties(properties: dict[str, Any]) -> "FluxParams": + p = properties["hparams"] + + in_channels = p["in_channels"] + out_channels = p["in_channels"] + vec_in_dim = p["pooled_projection_dim"] + context_in_dim = p["joint_attention_dim"] + mlp_ratio = 4.0 + hidden_size = vec_in_dim * int(mlp_ratio) + num_heads = 
p["num_attention_heads"] + depth = p["num_layers"] + depth_single_blocks = p["num_single_layers"] + + # TODO: figure out relation between hidden_size, num_heads and + # attention_head_dim. + # diffusers.FluxTransformer2DModel also hardcodes this. + axes_dim = [16, 56, 56] + assert sum(axes_dim) == p["attention_head_dim"] + + theta = 10_000 + qkv_bias = True + guidance_embed = p["guidance_embeds"] + + return FluxParams( + in_channels=in_channels, + out_channels=out_channels, + vec_in_dim=vec_in_dim, + context_in_dim=context_in_dim, + mlp_ratio=mlp_ratio, + hidden_size=hidden_size, + num_heads=num_heads, + depth=depth, + depth_single_blocks=depth_single_blocks, + axes_dim=axes_dim, + theta=theta, + qkv_bias=qkv_bias, + guidance_embed=guidance_embed, + ) + class FluxModelV1(ThetaLayer): """FluxModel adapted from Black Forest Lab's implementation.""" @@ -71,16 +113,12 @@ def __init__(self, theta: Theta, params: FluxParams): dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim ) self.add_module("img_in", LinearLayer(theta("img_in"))) - # TODO: Refactor this pattern to an MLPEmbedder like src implementatio - self.add_module("time_in_0", LinearLayer(theta("time_in.0"))) - self.add_module("time_in_1", LinearLayer(theta("time_in.1"))) - self.add_module("vector_in_0", LinearLayer(theta("vector_in.0"))) - self.add_module("vector_in_1", LinearLayer(theta("vector_in.1"))) + self.add_module("time_in", MLPEmbedder(theta("time_in"))) + self.add_module("vector_in", MLPEmbedder(theta("vector_in"))) self.guidance = False if params.guidance_embed: self.guidance = True - self.add_module("guidance_in_0", LinearLayer(theta("guidance_in.0"))) - self.add_module("guidance_in_1", LinearLayer(theta("guidance_in.1"))) + self.add_module("guidance_in", MLPEmbedder(theta("guidance_in"))) self.add_module("txt_in", LinearLayer(theta("txt_in"))) self.double_blocks = nn.ModuleList( @@ -104,8 +142,8 @@ def __init__(self, theta: Theta, params: FluxParams): ) self.add_module( - "last_layer", - LastLayer(theta("last_layer")), + "final_layer", + LastLayer(theta("final_layer")), ) def forward( @@ -123,23 +161,14 @@ def forward( # running on sequences img img = self.img_in(img) - time_in_0 = self.time_in_0(timestep_embedding(timesteps, 256)) - time_in_silu = ops.elementwise(F.silu, time_in_0) - vec = self.time_in_1(time_in_silu) + vec = self.time_in(timestep_embedding(timesteps, 256)) if self.guidance: if guidance is None: raise ValueError( "Didn't get guidance strength for guidance distilled model." ) - guidance_inp = timestep_embedding(guidance, 256) - guidance0 = self.guidance_in0(guidance_inp) - guidance_silu = ops.elementwise(F.silu, guidance0) - guidance_out = self.guidance_in1(guidance_silu) - vec = vec + self.guidance_in(guidance_out) - vector_in_0 = self.vector_in_0(y) - vector_in_silu = ops.elementwise(F.silu, vector_in_0) - vector_in_1 = self.vector_in_1(vector_in_silu) - vec = vec + vector_in_1 + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) txt = self.txt_in(txt) @@ -154,9 +183,36 @@ def forward( img = block(img, vec=vec, pe=pe) img = img[:, txt.shape[1] :, ...] 
- img = self.last_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img + def sample_inputs( + self, batch_size: int = 1, function: Optional[str] = None + ) -> tuple[tuple[AnyTensor], OrderedDict[str, AnyTensor]]: + if not (function is None or function == "forward"): + raise ValueError(f'Only function "forward" is supported. Got "{function}"') + + # TODO: do not hardcode these but derive the required shapes from the config. + img = torch.rand([batch_size, 1024, 64]) + img_ids = torch.rand([batch_size, 1024, 3]) + txt = torch.rand([batch_size, 512, 4096]) + txt_ids = torch.rand([batch_size, 512, 3]) + timesteps = torch.rand([batch_size]) + y = torch.rand([batch_size, 768]) + + args = tuple() + kwargs = OrderedDict( + ( + ("img", img), + ("img_ids", img_ids), + ("txt", txt), + ("txt_ids", txt_ids), + ("timesteps", timesteps), + ("y", y), + ) + ) + return args, kwargs + ################################################################################ # Layers @@ -216,6 +272,18 @@ def rope(pos: AnyTensor, dim: int, theta: int) -> AnyTensor: return out.float() +class MLPEmbedder(ThetaLayer): + def __init__(self, theta: Theta): + super().__init__(theta) + self.in_layer = LinearLayer(theta("in_layer")) + self.out_layer = LinearLayer(theta("out_layer")) + + def forward(self, x: AnyTensor) -> AnyTensor: + x = self.in_layer(x) + x = ops.elementwise(torch.nn.functional.silu, x) + return self.out_layer(x) + + class EmbedND(torch.nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() @@ -239,13 +307,15 @@ def __init__( theta: Theta, ): super().__init__(theta) - self.add_module("outlinear", LinearLayer(theta("outlinear"))) - self.add_module("ada_linear", LinearLayer(theta("ada_linear"))) + self.add_module( + "adaLN_modulation_linear", LinearLayer(theta("adaLN_modulation.1")) + ) + self.add_module("linear", LinearLayer(theta("linear"))) def forward(self, x: AnyTensor, vec: AnyTensor) -> AnyTensor: silu = ops.elementwise(F.silu, vec) - lin = self.ada_linear(silu) + lin = self.adaLN_modulation_linear(silu) shift, scale = lin.chunk(2, dim=1) x = (1 + scale[:, None, :]) * layer_norm(x) + shift[:, None, :] - x = self.outlinear(x) + x = self.linear(x) return x diff --git a/sharktank/sharktank/models/punet/tools/import_hf_dataset.py b/sharktank/sharktank/tools/import_hf_dataset.py similarity index 61% rename from sharktank/sharktank/models/punet/tools/import_hf_dataset.py rename to sharktank/sharktank/tools/import_hf_dataset.py index 0afa5222d..8b8feed9f 100644 --- a/sharktank/sharktank/models/punet/tools/import_hf_dataset.py +++ b/sharktank/sharktank/tools/import_hf_dataset.py @@ -14,21 +14,31 @@ Usage: python -m sharktank.models.punet.import_hf_dataset \ --output-irpa-file ~/models/punet/punet_fp16.irpa \ - --config-json ~/models/stable-diffusion-xl-base-1.0/unet/config.json + --config-json ~/models/stable-diffusion-xl-base-1.0/unet/config.json \ + --params diffusion_pytorch_model.fp16.safetensors The resulting dataset has all tensors as nested in the original model. Properties are separated into a "meta" dict (for "_" prefixed props) and an "hparams" dict. 
""" +from typing import Optional +from os import PathLike import json from pathlib import Path import sys +import logging -from ....types import * +from ..types import * +logger = logging.getLogger(__name__) -def import_hf_config(config_json_path: Path, params_path: Path) -> Dataset: + +def import_hf_dataset( + config_json_path: PathLike, + param_paths: list[PathLike], + output_irpa_file: Optional[PathLike] = None, +) -> Optional[Dataset]: import safetensors with open(config_json_path, "rb") as f: @@ -37,22 +47,28 @@ def import_hf_config(config_json_path: Path, params_path: Path) -> Dataset: meta_params = {k: v for k, v in config_json.items() if k.startswith("_")} hparams = {k: v for k, v in config_json.items() if not k.startswith("_")} - with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: - tensors = [ - DefaultPrimitiveTensor(name=name, data=st.get_tensor(name)) - for name in st.keys() - ] + for params_path in param_paths: + with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: + tensors = [ + DefaultPrimitiveTensor(name=name, data=st.get_tensor(name)) + for name in st.keys() + ] theta = Theta(tensors) props = { "meta": meta_params, "hparams": hparams, } - return Dataset(props, theta) + dataset = Dataset(props, theta) + + if output_irpa_file is None: + return dataset + + dataset.save(output_irpa_file, io_report_callback=logger.info) -def main(argv): - from ....utils import cli +def main(argv: list[str]): + from ..utils import cli parser = cli.create_parser() cli.add_output_dataset_options(parser) @@ -62,18 +78,22 @@ def main(argv): parser.add_argument( "--params", type=Path, + nargs="+", default=Path("diffusion_pytorch_model.fp16.safetensors"), - help="Parameter file name, relative to config.json", + help="Parameter file name(s), relative to config.json", ) args = cli.parse(parser, args=argv) config_json_path: Path = args.config_json - params_path: Path = args.params - if not params_path.is_absolute(): - params_path = config_json_path.parent / params_path - - dataset = import_hf_config(config_json_path, params_path) - dataset.save(args.output_irpa_file, io_report_callback=print) + param_paths: list[Path] = args.params + param_paths = [ + path if path.is_absolute() else config_json_path.parent / path + for path in param_paths + ] + + import_hf_dataset( + config_json_path, param_paths, output_irpa_file=args.output_irpa_file + ) if __name__ == "__main__": diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index 9fefeb66f..b4b405dca 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -100,7 +100,7 @@ def add_tokenizer_options(parser: argparse.ArgumentParser): ) -def get_input_data_files(args) -> Optional[dict[str, Path]]: +def get_input_data_files(args) -> Optional[dict[str, list[Path]]]: """Gets data files given the input arguments. 
Keys may contain: @@ -112,9 +112,9 @@ def get_input_data_files(args) -> Optional[dict[str, Path]]: dataset = hf_datasets.get_dataset(args.hf_dataset).download() return dataset elif args.gguf_file is not None: - return {"gguf": args.gguf_file} + return {"gguf": [args.gguf_file]} elif args.irpa_file is not None: - return {"irpa": args.irpa_file} + return {"irpa": [args.irpa_file]} def get_input_dataset(args) -> Dataset: @@ -124,10 +124,10 @@ def get_input_dataset(args) -> Dataset: """ data_files = get_input_data_files(args) if "gguf" in data_files: - return Dataset.load(data_files["gguf"], file_type="gguf") + return Dataset.load(data_files["gguf"][0], file_type="gguf") if "irpa" in data_files: - return Dataset.load(data_files["irpa"], file_type="irpa") + return Dataset.load(data_files["irpa"][0], file_type="irpa") raise ValueError(f'Dataset format unsupported. Must be "gguf" or "irpa".') @@ -142,7 +142,7 @@ def get_tokenizer(args) -> tokenizer.InferenceTokenizer: return tokenizer.fake_tokenizer() if args.tokenizer_config_json is not None: - data_files = {"tokenizer_config.json": args.tokenizer_config_json} + data_files = {"tokenizer_config.json": [args.tokenizer_config_json]} else: data_files = get_input_data_files(args) @@ -150,7 +150,7 @@ def get_tokenizer(args) -> tokenizer.InferenceTokenizer: if tokenizer_type is None: if "tokenizer_config.json" in data_files: return tokenizer.load_tokenizer( - data_files["tokenizer_config.json"].parent, + data_files["tokenizer_config.json"][0].parent, tokenizer_type="transformers", ) else: diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py index c6a799404..6893b637a 100644 --- a/sharktank/sharktank/utils/hf_datasets.py +++ b/sharktank/sharktank/utils/hf_datasets.py @@ -33,16 +33,26 @@ class RemoteFile: filename: str extra_filenames: Sequence[str] = () - def download(self, *, local_dir: Optional[Path] = None) -> Path: - for extra_filename in self.extra_filenames: - hf_hub_download( - repo_id=self.repo_id, filename=extra_filename, local_dir=local_dir - ) - return Path( - hf_hub_download( - repo_id=self.repo_id, filename=self.filename, local_dir=local_dir + def download(self, *, local_dir: Optional[Path] = None) -> list[Path]: + res = [] + res.append( + Path( + hf_hub_download( + repo_id=self.repo_id, filename=self.filename, local_dir=local_dir + ) ) ) + for extra_filename in self.extra_filenames: + res.append( + Path( + hf_hub_download( + repo_id=self.repo_id, + filename=extra_filename, + local_dir=local_dir, + ) + ) + ) + return res @dataclass @@ -59,7 +69,7 @@ def alias_to(self, to_name: str) -> "Dataset": alias_dataset(self.name, to_name) return self - def download(self, *, local_dir: Optional[Path] = None) -> Dict[str, Path]: + def download(self, *, local_dir: Optional[Path] = None) -> Dict[str, list[Path]]: return {f.file_id: f.download(local_dir=local_dir) for f in self.files} @@ -363,6 +373,54 @@ def alias_dataset(from_name: str, to_name: str): ), ) +# The Flux transformer is in 2 formats. 
+# This is used in diffusers.FluxTransformer2DModel +Dataset( + "black-forest-labs/FLUX.1-schnell/transformer", + ( + RemoteFile( + "config", + "black-forest-labs/FLUX.1-schnell", + "transformer/config.json", + ), + RemoteFile( + "parameters", + "black-forest-labs/FLUX.1-schnell", + "transformer/diffusion_pytorch_model-00001-of-00003.safetensors", + extra_filenames=[ + "transformer/diffusion_pytorch_model-00002-of-00003.safetensors", + "transformer/diffusion_pytorch_model-00003-of-00003.safetensors", + ], + ), + RemoteFile( + "parameters-index", + "black-forest-labs/FLUX.1-schnell", + "transformer/diffusion_pytorch_model.safetensors.index.json", + ), + ), +) + +# The Flux transformer is in 2 formats. +# This is used in the Black Forest's Flux repo. +# https://github.com/black-forest-labs/flux +# We have based our implementation on that. +Dataset( + "black-forest-labs/FLUX.1-schnell/black-forest-labs-transformer", + ( + RemoteFile( + "config", + "black-forest-labs/FLUX.1-schnell", + "transformer/config.json", + ), + RemoteFile( + "parameters", + "black-forest-labs/FLUX.1-schnell", + "flux1-schnell.safetensors", + ), + ), +) + + ################################################################################ # Tool entrypoint ################################################################################ @@ -386,8 +444,8 @@ def main(): for dataset_name in args.dataset_name: print(f"Downloading dataset {dataset_name}") ds = get_dataset(dataset_name).download(local_dir=args.local_dir) - for key, path in ds.items(): - print(f" {key}: {path}") + for key, paths in ds.items(): + print(f" {key}: {paths}") if __name__ == "__main__": diff --git a/sharktank/tests/models/flux/flux_test.py b/sharktank/tests/models/flux/flux_test.py index ea80c7b42..fc4d23251 100644 --- a/sharktank/tests/models/flux/flux_test.py +++ b/sharktank/tests/models/flux/flux_test.py @@ -5,18 +5,17 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import logging - -logging.basicConfig(level=logging.DEBUG) - import unittest - import torch - -from iree.turbine import aot +import pytest from sharktank.models.flux.flux import ( FluxModelV1, FluxParams, ) +from sharktank.models.flux.export import ( + export_flux_transformer_model_mlir, + export_flux_transformer_from_hugging_face, +) import sharktank.ops as ops from sharktank.layers.testing import ( make_rand_torch, @@ -24,11 +23,14 @@ from sharktank.types.tensors import DefaultPrimitiveTensor from sharktank.types.theta import Dataset, Theta from sharktank.utils.testing import TempDirTestBase +from sharktank.utils.hf_datasets import get_dataset + +logging.basicConfig(level=logging.DEBUG) +with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')") # TODO: Refactor this to a function that generates random toy weights, possibly # to another file -dtype = torch.float32 in_channels = 64 in_channels2 = 128 hidden_size = 3072 @@ -45,7 +47,7 @@ out_channels = 64 -def make_random_theta(): +def make_random_theta(dtype: torch.dtype): return Theta( { "img_in.weight": DefaultPrimitiveTensor( # @@ -60,34 +62,34 @@ def make_random_theta(): "txt_in.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size,), dtype=dtype) ), - "time_in.0.weight": DefaultPrimitiveTensor( # + "time_in.in_layer.weight": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size, time_dim), dtype=dtype) ), - "time_in.0.bias": DefaultPrimitiveTensor( # + "time_in.in_layer.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size,), dtype=dtype) ), - 
"time_in.1.weight": DefaultPrimitiveTensor( # + "time_in.out_layer.weight": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) ), - "time_in.1.bias": DefaultPrimitiveTensor( # + "time_in.out_layer.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size,), dtype=dtype) ), - "vector_in.0.weight": DefaultPrimitiveTensor( # + "vector_in.in_layer.weight": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size, vec_dim), dtype=dtype) ), - "vector_in.0.bias": DefaultPrimitiveTensor( # + "vector_in.in_layer.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size,), dtype=dtype) ), - "vector_in.1.weight": DefaultPrimitiveTensor( # + "vector_in.out_layer.weight": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) ), - "vector_in.1.bias": DefaultPrimitiveTensor( # + "vector_in.out_layer.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size,), dtype=dtype) ), - "double_blocks.0.img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "double_blocks.0.img_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), - "double_blocks.0.img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "double_blocks.0.img_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), "double_blocks.0.img_attn.proj.bias": DefaultPrimitiveTensor( @@ -120,10 +122,10 @@ def make_random_theta(): "double_blocks.0.img_mod.lin.weight": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) ), - "double_blocks.0.txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "double_blocks.0.txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), - "double_blocks.0.txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "double_blocks.0.txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), "double_blocks.0.txt_attn.proj.bias": DefaultPrimitiveTensor( @@ -156,10 +158,10 @@ def make_random_theta(): "double_blocks.0.txt_mod.lin.weight": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) ), - "single_blocks.0.attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "single_blocks.0.norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), - "single_blocks.0.attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "single_blocks.0.norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), "single_blocks.0.attn.proj.bias": DefaultPrimitiveTensor( @@ -180,26 +182,26 @@ def make_random_theta(): "single_blocks.0.linear2.weight": DefaultPrimitiveTensor( data=make_rand_torch((hidden_size, mlp_hidden_size4), dtype=dtype) ), - "single_blocks.0.mod.lin.bias": DefaultPrimitiveTensor( + "single_blocks.0.modulation.lin.bias": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size,), dtype=dtype) ), - "single_blocks.0.mod.lin.weight": DefaultPrimitiveTensor( + "single_blocks.0.modulation.lin.weight": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) ), - "last_layer.outlinear.weight": DefaultPrimitiveTensor( # + "final_layer.linear.weight": DefaultPrimitiveTensor( # data=make_rand_torch( (patch_size * patch_size * out_channels, hidden_size), dtype=dtype ) ), - "last_layer.outlinear.bias": 
DefaultPrimitiveTensor( # + "final_layer.linear.bias": DefaultPrimitiveTensor( # data=make_rand_torch( (patch_size * patch_size * out_channels,), dtype=dtype ) ), - "last_layer.ada_linear.weight": DefaultPrimitiveTensor( # + "final_layer.adaLN_modulation.1.weight": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size * 2, hidden_size), dtype=dtype) ), - "last_layer.ada_linear.bias": DefaultPrimitiveTensor( # + "final_layer.adaLN_modulation.1.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size * 2,), dtype=dtype) ), } @@ -214,7 +216,8 @@ def setUp(self): self.num_heads = 24 self.batch_size = 5 - def testExport(self): + def testExportBfloat16SingleLayer(self): + dtype = torch.bfloat16 params = FluxParams( in_channels=64, out_channels=64, @@ -230,32 +233,26 @@ def testExport(self): qkv_bias=True, guidance_embed=False, ) - theta = make_random_theta() + theta = make_random_theta(dtype) theta = self.save_load_theta(theta) flux = FluxModelV1( theta=theta, params=params, ) - img = torch.rand([self.batch_size, 1024, 64]) - img_ids = torch.rand([self.batch_size, 1024, 3]) - txt = torch.rand([self.batch_size, 512, 4096]) - txt_ids = torch.rand([self.batch_size, 512, 3]) - timesteps = torch.rand([self.batch_size]) - y = torch.rand([self.batch_size, 768]) - - flux.forward(img, img_ids, txt, txt_ids, timesteps, y) - fxb = aot.FxProgramsBuilder(flux) - - @fxb.export_program( - name="flux", args=(img, img_ids, txt, txt_ids, timesteps, y), strict=False + export_flux_transformer_model_mlir( + flux, + output_path=self._temp_dir / "model.mlir", + batch_sizes=[self.batch_size], ) - def _(model, img, img_ids, txt, txt_ids, timesteps, y) -> torch.Tensor: - return model.forward(img, img_ids, txt, txt_ids, timesteps, y) - output = aot.export(fxb) - output.verify() - asm = str(output.mlir_module) + @with_flux_data + def testExportSchnellFromHuggingFace(self): + export_flux_transformer_from_hugging_face( + "black-forest-labs/FLUX.1-schnell/black-forest-labs-transformer", + mlir_output_path=self._temp_dir / "model.mlir", + parameters_output_path=self._temp_dir / "parameters.irpa", + ) def save_load_theta(self, theta: Theta): # Roundtrip to disk to avoid treating parameters as constants that would appear diff --git a/sharktank/tests/models/llama/prefill_tests.py b/sharktank/tests/models/llama/prefill_tests.py index 093ecdfc9..f7b456389 100644 --- a/sharktank/tests/models/llama/prefill_tests.py +++ b/sharktank/tests/models/llama/prefill_tests.py @@ -86,7 +86,7 @@ def setUp(self): self.data_files = hf_datasets.get_dataset( default_arguments["hf_dataset"] ).download(local_dir=Path(".")) - self.dataset = Dataset.load(self.data_files["gguf"], file_type="gguf") + self.dataset = Dataset.load(self.data_files["gguf"][0], file_type="gguf") self.tokenizer_config = tokenizer.load_tokenizer( default_arguments["tokenizer-config-json"].parent, tokenizer_type="transformers", @@ -138,7 +138,7 @@ def setUp(self): self.data_files = hf_datasets.get_dataset( default_arguments["hf_dataset"] ).download(local_dir=Path(".")) - self.dataset = Dataset.load(self.data_files["gguf"], file_type="gguf") + self.dataset = Dataset.load(self.data_files["gguf"][0], file_type="gguf") self.tokenizer_config = tokenizer.load_tokenizer( default_arguments["tokenizer-config-json"].parent, tokenizer_type="transformers",
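
Reviewer note: a sketch of the end-to-end path the new `testExportSchnellFromHuggingFace` test exercises, plus the derivation performed by `FluxParams.from_hugging_face_properties`. The output paths and the concrete config values in the comments are illustrative assumptions, not values read from the actual checkpoint; the dataset name comes from `hf_datasets.py` above.

```python
# Sketch only: output paths are placeholders; config numbers are illustrative.
from sharktank.models.flux.export import export_flux_transformer_from_hugging_face

# Downloads the Black Forest Labs safetensors file, imports it into an IRPA
# dataset, then exports static-shape MLIR at the default batch sizes ([4]).
export_flux_transformer_from_hugging_face(
    "black-forest-labs/FLUX.1-schnell/black-forest-labs-transformer",
    mlir_output_path="/tmp/flux_schnell.mlir",
    parameters_output_path="/tmp/flux_schnell.irpa",
)

# The arithmetic inside FluxParams.from_hugging_face_properties, worked through
# for an assumed config with pooled_projection_dim=768, num_attention_heads=24,
# attention_head_dim=128:
mlp_ratio = 4.0
hidden_size = int(768 * mlp_ratio)  # 3072
assert hidden_size // 24 == 128     # hidden_size / num_heads == attention_head_dim
assert sum([16, 56, 56]) == 128     # hardcoded axes_dim must sum to attention_head_dim
```
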