Add Flux transformer export for easier use outside of tests (nod-ai#700)
Adapt the model to accept parameters as structured in the HF repo.

Generalize the import of Punet parameters from HF so it can serve other
models as well.

When downloading a dataset from Hugging Face, return the local locations of
all downloaded files, including extras, not just the "leading" file.

Add a sample_inputs method to the BaseLayer interface to help standardize
model export.

Introduce a standard export function for static-sized models.
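
In rough terms, the flow these pieces enable looks like this sketch (not part of the commit; `model` is assumed to be a layer implementing the new interface):

```python
from sharktank.export import export_static_model_mlir

# `model` is any BaseLayer subclass that implements the new sample_inputs method.
args, kwargs = model.sample_inputs(batch_size=1)
model(*args, **kwargs)  # sanity-check eager execution with the sample inputs

# Export one static-shaped MLIR function per batch size, reusing sample_inputs.
export_static_model_mlir(model, output_path="model.mlir", batch_sizes=[1])
```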
sogartar authored Dec 17, 2024
1 parent b151ffa commit 83437b5
Showing 14 changed files with 417 additions and 138 deletions.
10 changes: 7 additions & 3 deletions .github/workflows/ci-sharktank.yml
@@ -136,15 +136,19 @@ jobs:
pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
- name: Run tests
# TODO: unify with-t5-data and with-clip-data flags into a single flag
# and make it possible to run only tests that require data.
# TODO: unify with-*-data flags into a single flag and make it possible to run
# only tests that require data.
# We would still want the separate flags as we may end up with data being
# scattered on different CI machines.
run: |
source ${VENV_DIR}/bin/activate
pytest \
--with-clip-data \
--with-flux-data \
--with-t5-data \
sharktank/tests/models/clip/clip_test.py \
sharktank/tests/models/t5/t5_test.py \
sharktank/tests/models/flux/flux_test.py \
--durations=0
9 changes: 9 additions & 0 deletions sharktank/conftest.py
@@ -97,6 +97,15 @@ def pytest_addoption(parser):
"code. The user is expected to provide the data"
),
)
parser.addoption(
"--with-flux-data",
action="store_true",
default=False,
help=(
"Enable tests that use Flux data like models that is not a part of the source "
"code. The user is expected to provide the data"
),
)
parser.addoption(
"--with-t5-data",
action="store_true",
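
Tests can then gate on the new flag; below is a minimal sketch of how a test might consume it (illustrative only, not part of this diff):

```python
import pytest


@pytest.fixture
def flux_data_enabled(request) -> bool:
    # Reads the option added above; defaults to False when the flag is absent.
    return request.config.getoption("--with-flux-data")


def test_flux_transformer(flux_data_enabled):
    if not flux_data_enabled:
        pytest.skip("Requires user-provided Flux data (--with-flux-data)")
    ...  # exercise the model against the externally provided data
```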
2 changes: 1 addition & 1 deletion sharktank/integration/models/punet/integration_test.py
@@ -67,7 +67,7 @@ def download(filename):

@pytest.fixture(scope="module")
def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir):
from sharktank.models.punet.tools import import_hf_dataset
from sharktank.tools import import_hf_dataset

dataset = temp_dir / "sdxl_fp16_dataset.irpa"
import_hf_dataset.main(
53 changes: 52 additions & 1 deletion sharktank/sharktank/export.py
@@ -4,11 +4,14 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable, Any
from typing import Callable, Optional, Any
import torch
from os import PathLike
import iree.turbine.aot as aot
from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder
from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten
from .types.tensors import ShardedTensor
from .layers import BaseLayer
from torch.utils._pytree import PyTree, _is_leaf
import functools

@@ -172,3 +175,51 @@ def flat_fn(*args, **kwargs):
)

assert False, "TODO: implement the case when not using an FxProgramsBuilder"


def export_static_model_mlir(
    model: BaseLayer,
    output_path: PathLike,
    function_batch_size_pairs: Optional[dict[Optional[str], list[int]]] = None,
    batch_sizes: Optional[list[int]] = None,
):
    """Export a model with no dynamic dimensions.

    For each provided (function name, batch sizes) pair, the resulting MLIR will
    contain function names of the form

    ```
    <function_name>_bs<batch_size>
    ```

    If `batch_sizes` is given instead, a single function named "forward" is
    exported. The model is required to implement the `sample_inputs` method.
    """

    assert not (function_batch_size_pairs is not None and batch_sizes is not None)

    if batch_sizes is not None:
        function_batch_size_pairs = {None: batch_sizes}

    if function_batch_size_pairs is None and batch_sizes is None:
        # Default to a single "forward" function with batch size 1.
        function_batch_size_pairs = {None: [1]}

    fxb = FxProgramsBuilder(model)

    for function, batch_sizes in function_batch_size_pairs.items():
        for batch_size in batch_sizes:
            args, kwargs = model.sample_inputs(batch_size, function)

            @fxb.export_program(
                name=f"{function or 'forward'}_bs{batch_size}",
                args=args,
                kwargs=kwargs,
                dynamic_shapes=None,
                strict=False,
            )
            def _(model, **kwargs):
                return model(**kwargs)

    output = aot.export(fxb)
    output.save_mlir(output_path)
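
A sketch of how this helper might be invoked for a model exposing two functions (function names and the output path are illustrative):

```python
# Produces encode_bs1, encode_bs4, and decode_bs1 in a single MLIR module,
# assuming model.sample_inputs supports both function names.
export_static_model_mlir(
    model,
    output_path="model.mlir",
    function_batch_size_pairs={"encode": [1, 4], "decode": [1]},
)
```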
24 changes: 18 additions & 6 deletions sharktank/sharktank/layers/base.py
@@ -4,15 +4,12 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Dict

from typing import Dict, Optional
from collections import OrderedDict
import torch
import torch.nn as nn

from ..types import (
InferenceTensor,
Theta,
)
from ..types import InferenceTensor, Theta, AnyTensor
from ..utils import debugging

__all__ = [
@@ -56,6 +53,21 @@ def assert_not_nan(self, *ts: torch.Tensor):
if torch.isnan(t).any():
raise AssertionError(f"Tensor contains nans! {t}")

    def sample_inputs(
        self, batch_size: int = 1, function: Optional[str] = None
    ) -> tuple[tuple[AnyTensor], OrderedDict[str, AnyTensor]]:
        """Return sample inputs that can be used to run the function from the model.

        If function is None then the layer itself is treated as the callable. E.g.

        ```
        args, kwargs = model.sample_inputs()
        model(*args, **kwargs)
        ```

        One purpose of this method is to standardize exporting models to MLIR.
        """
        raise NotImplementedError()


class ThetaLayer(BaseLayer):
"Base class for layers that derive parameters from a Theta object."
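
As an illustration of the contract, a hypothetical layer could implement `sample_inputs` like this (not part of this change):

```python
from collections import OrderedDict
from typing import Optional

import torch

from sharktank.layers import BaseLayer


class ToyLayer(BaseLayer):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    def sample_inputs(self, batch_size: int = 1, function: Optional[str] = None):
        # Only the default callable (function=None) is supported in this sketch.
        assert function is None
        return tuple(), OrderedDict((("x", torch.empty(batch_size, 16)),))
```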
25 changes: 17 additions & 8 deletions sharktank/sharktank/layers/mmdit.py
@@ -55,11 +55,15 @@ def __init__(self, theta, num_heads: int):
self.add_module("img_attn_qkv", LinearLayer(theta("img_attn.qkv")))
self.add_module(
"img_attn_norm_q",
RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6),
RMSNormLayer(
theta("img_attn.norm.query_norm"), weight_name="scale", epsilon=1e-6
),
)
self.add_module(
"img_attn_norm_k",
RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6),
RMSNormLayer(
theta("img_attn.norm.key_norm"), weight_name="scale", epsilon=1e-6
),
)
self.add_module("img_attn_proj", LinearLayer(theta("img_attn.proj")))

@@ -70,11 +74,15 @@ def __init__(self, theta, num_heads: int):
self.add_module("txt_attn_qkv", LinearLayer(theta("txt_attn.qkv")))
self.add_module(
"txt_attn_norm_q",
RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6),
RMSNormLayer(
theta("txt_attn.norm.query_norm"), weight_name="scale", epsilon=1e-6
),
)
self.add_module(
"txt_attn_norm_k",
RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6),
RMSNormLayer(
theta("txt_attn.norm.key_norm"), weight_name="scale", epsilon=1e-6
),
)
self.add_module("txt_attn_proj", LinearLayer(theta("txt_attn.proj")))

@@ -151,14 +159,15 @@ def __init__(self, theta, num_heads: int):
super().__init__(theta)

self.num_heads = num_heads
self.add_module("mod", ModulationLayer(theta("mod"), double=False))
self.add_module("mod", ModulationLayer(theta("modulation"), double=False))
self.add_module(
"attn_norm_q", RMSNormLayer(theta("attn.norm.query_norm"), epsilon=1e-6)
"attn_norm_q",
RMSNormLayer(theta("norm.query_norm"), weight_name="scale", epsilon=1e-6),
)
self.add_module(
"attn_norm_k", RMSNormLayer(theta("attn.norm.key_norm"), epsilon=1e-6)
"attn_norm_k",
RMSNormLayer(theta("norm.key_norm"), weight_name="scale", epsilon=1e-6),
)
self.add_module("attn_proj", LinearLayer(theta("attn.proj")))

self.add_module("linear1", LinearLayer(theta("linear1")))
self.add_module("linear2", LinearLayer(theta("linear2")))
16 changes: 8 additions & 8 deletions sharktank/sharktank/layers/testing.py
@@ -65,10 +65,10 @@ def make_mmdit_double_block_random_theta(
mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size)
return Theta(
{
"img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
"img_attn.norm.key_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels,), dtype=dtype)
),
"img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
"img_attn.norm.query_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels,), dtype=dtype)
),
"img_attn.proj.bias": DefaultPrimitiveTensor(
@@ -101,10 +101,10 @@ def make_mmdit_double_block_random_theta(
"img_mod.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
),
"txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
"txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels,), dtype=dtype)
),
"txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
"txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels,), dtype=dtype)
),
"txt_attn.proj.bias": DefaultPrimitiveTensor(
@@ -155,10 +155,10 @@ def make_mmdit_single_block_random_theta(
mlp_hidden_size3 = int((2 * mlp_ratio - 1) * hidden_size)
return Theta(
{
"attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
"norm.key_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels,), dtype=dtype)
),
"attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
"norm.query_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels,), dtype=dtype)
),
"attn.proj.bias": DefaultPrimitiveTensor(
@@ -179,10 +179,10 @@
"linear2.weight": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
),
"mod.lin.bias": DefaultPrimitiveTensor(
"modulation.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
),
"mod.lin.weight": DefaultPrimitiveTensor(
"modulation.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
),
}
49 changes: 49 additions & 0 deletions sharktank/sharktank/models/flux/export.py
@@ -0,0 +1,49 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from os import PathLike

from ...export import export_static_model_mlir
from ...tools.import_hf_dataset import import_hf_dataset
from .flux import FluxModelV1, FluxParams
from ...types import Dataset
from ...utils.hf_datasets import get_dataset

flux_transformer_default_batch_sizes = [4]


def export_flux_transformer_model_mlir(
    model: FluxModelV1,
    output_path: PathLike,
    batch_sizes: list[int] = flux_transformer_default_batch_sizes,
):
    export_static_model_mlir(model, output_path=output_path, batch_sizes=batch_sizes)


def export_flux_transformer_from_hugging_face(
    repo_id: str,
    mlir_output_path: PathLike,
    parameters_output_path: PathLike,
    batch_sizes: list[int] = flux_transformer_default_batch_sizes,
):
    hf_dataset = get_dataset(
        repo_id,
    ).download()

    import_hf_dataset(
        config_json_path=hf_dataset["config"][0],
        param_paths=hf_dataset["parameters"],
        output_irpa_file=parameters_output_path,
    )

    dataset = Dataset.load(parameters_output_path)
    model = FluxModelV1(
        theta=dataset.root_theta,
        params=FluxParams.from_hugging_face_properties(dataset.properties),
    )
    export_flux_transformer_model_mlir(
        model, output_path=mlir_output_path, batch_sizes=batch_sizes
    )
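
Outside of tests, both artifacts can then be produced in one call; a minimal sketch (the repo id and output paths are placeholders):

```python
from sharktank.models.flux.export import export_flux_transformer_from_hugging_face

# The repo id must name a dataset registered in sharktank.utils.hf_datasets
# whose downloaded files are keyed as "config" and "parameters".
export_flux_transformer_from_hugging_face(
    repo_id="<flux-transformer-dataset-name>",  # placeholder, not a real dataset id
    mlir_output_path="/tmp/flux_transformer.mlir",
    parameters_output_path="/tmp/flux_transformer.irpa",
)
```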