diff --git a/shark/examples/shark_eager/squeezenet_lockstep.py b/shark/examples/shark_eager/squeezenet_lockstep.py new file mode 100644 index 0000000000000..f5fcf42d33a3b --- /dev/null +++ b/shark/examples/shark_eager/squeezenet_lockstep.py @@ -0,0 +1,73 @@ +import torch +import numpy as np + +model = torch.hub.load( + "pytorch/vision:v0.10.0", "squeezenet1_0", pretrained=True +) +model.eval() + +# from PIL import Image +# from torchvision import transforms +# import urllib +# +# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg") +# try: urllib.URLopener().retrieve(url, filename) +# except: urllib.request.urlretrieve(url, filename) +# +# +# input_image = Image.open(filename) +# preprocess = transforms.Compose([ +# transforms.Resize(256), +# transforms.CenterCrop(224), +# transforms.ToTensor(), +# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +# ]) +# input_tensor = preprocess(input_image) +# input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model +# print(input_batch.shape) # size = [1, 3, 224, 224] + +# The above is code for generating sample inputs from an image. We can just use +# random values for accuracy testing though +input_batch = torch.randn(1, 3, 224, 224) + + +# Focus on CPU for now +if False and torch.cuda.is_available(): + input_batch = input_batch.to("cuda") + model.to("cuda") + +with torch.no_grad(): + output = model(input_batch) +# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes +golden_confidences = output[0] +# The output has unnormalized scores. To get probabilities, you can run a softmax on it. +golden_probabilities = torch.nn.functional.softmax( + golden_confidences, dim=0 +).numpy() + +golden_confidences = golden_confidences.numpy() + +from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor + +input_detached_clone = input_batch.clone() +eager_input_batch = TorchMLIRLockstepTensor(input_detached_clone) + +print("getting torch-mlir result") + +output = model(eager_input_batch) + +static_output = output.elem +confidences = static_output[0] +probabilities = torch.nn.functional.softmax( + torch.from_numpy(confidences), dim=0 +).numpy() + +print("The obtained result via shark is: ", confidences) +print("The golden result is:", golden_confidences) + +np.testing.assert_allclose( + golden_confidences, confidences, rtol=1e-02, atol=1e-03 +) +np.testing.assert_allclose( + golden_probabilities, probabilities, rtol=1e-02, atol=1e-03 +) diff --git a/shark/torch_mlir_lockstep_tensor.py b/shark/torch_mlir_lockstep_tensor.py new file mode 100644 index 0000000000000..b5c829e57644c --- /dev/null +++ b/shark/torch_mlir_lockstep_tensor.py @@ -0,0 +1,206 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. 
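+
+# Intended usage (a minimal sketch, mirroring the SqueezeNet example above;
+# `model` and `input_batch` here stand for any torch.nn.Module and a matching
+# input tensor): wrap the inputs in TorchMLIRLockstepTensor and run the model
+# as usual. Each dispatched aten op is compiled through Torch-MLIR/IREE and its
+# result is checked against PyTorch eager, with a warning emitted on mismatch.
+#
+#   from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor
+#
+#   lockstep_input = TorchMLIRLockstepTensor(input_batch.clone())
+#   with torch.no_grad():
+#       lockstep_output = model(lockstep_input)
+#   result = lockstep_output.elem  # backing device buffer of the final output
+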
+import contextlib +import re +import traceback +import warnings +from typing import Any +import numpy as np + +import torch +from torch.utils._pytree import tree_map + +from torch_mlir.eager_mode.ir_building import build_mlir_module +from torch_mlir.eager_mode.torch_mlir_dispatch import ( + UnsupportedByTorchMlirEagerMode, + normalize_args_kwargs, + check_get_aliased_arg, +) +from torch_mlir.eager_mode import EAGER_MODE_DEBUG +from torch_mlir.eager_mode.torch_mlir_tensor import ( + TorchMLIRTensor, + check_requires_grad, + make_wrapper_subclass_from_torch_tensor, + make_bare_wrapper_subclass, + UNSUPPORTED_OPS, + no_dispatch, +) +from torch_mlir.eager_mode import torch_mlir_tensor +from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend + + +backend = EagerModeIREELinalgOnTensorsBackend("cpu") +torch_mlir_tensor.backend = backend +rtol = 1e-04 +atol = 1e-05 + + +class TorchMLIRLockstepTensor(TorchMLIRTensor): + """This class overrides the dispatching for TorchMLIRTensor to allow for an op-by-op numerical comparison between PyTorch and the Torch-MLIR -> IREE backend compilation pipeline. This only supports the IREE backend and focuses on op-by-op level verification. + + TODO: Extend this to do a cumulative trace with summary statistics at the end. Possibly requires a wrapper environment to store full trace info. + """ + + def __new__(cls, elem, **kwargs): + if kwargs.get("constructing_from_device_tensor", False): + tensor_meta_data = backend.get_torch_metadata(elem, kwargs) + r = make_bare_wrapper_subclass( + cls=cls, + size=tensor_meta_data.size, + strides=tensor_meta_data.strides, + storage_offset=tensor_meta_data.storage_offset, + dtype=tensor_meta_data.dtype, + layout=tensor_meta_data.layout, + device=tensor_meta_data.device, + requires_grad=tensor_meta_data.requires_grad, + ) + r.elem = elem + elif isinstance(elem, torch.nn.Parameter): + r = make_wrapper_subclass_from_torch_tensor( + cls, elem.data, **kwargs + ) + # This is a hack to handle non-contiguous data through IREE-backend + nt = elem.detach().data.numpy() + if not nt.flags["C_CONTIGUOUS"]: + nt = np.ascontiguousarray(nt, dtype=nt.dtype) + r.elem = backend.transfer_from_torch_to_device(torch.Tensor(nt)) + elif isinstance(elem, torch.Tensor): + r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs) + # Ditto TODO: Find a better way to handle this + nt = elem.numpy() + if not nt.flags["C_CONTIGUOUS"]: + nt = np.ascontiguousarray(nt, dtype=nt.dtype) + r.elem = backend.transfer_from_torch_to_device(torch.Tensor(nt)) + # This branch handles the case when a python scalar is passed to some op + # or is returned from some aten op, such as _local_scalar_dense. + elif isinstance(elem, (int, float, bool)): + return elem + else: + raise ValueError(f"Unknown element type: {type(elem)}") + return r + + def __repr__(self): + if self.grad_fn: + return f"TorchMLIRLockstepTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})" + else: + return f"TorchMLIRLockstepTensor({self.elem}, backend={backend.__class__.__name__})" + + """This does essentially the same dispatch as TorchMLIRTensor but operates as if debug mode is enabled. 
The numeric verification happens after the Torch-MLIR result is obtained by comparing against the result of running the same op through PyTorch eager; a warning with debug information is emitted whenever the two results differ beyond rtol/atol.
+    """
+
+    @classmethod
+    def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
+        requires_grad = check_requires_grad(*args, **kwargs)
+        try:
+            with no_dispatch():
+                if hasattr(func, "op_name"):
+                    op_name = func.op_name
+                elif hasattr(func, "__name__"):
+                    # Handle builtin_function_or_method.
+                    op_name = func.__name__
+                else:
+                    raise RuntimeError(f"op {func} has no name")
+
+                if UNSUPPORTED_OPS.match(op_name):
+                    raise UnsupportedByTorchMlirEagerMode(op_name)
+
+                if not hasattr(func, "_schema"):
+                    raise RuntimeError(f"op {func} has no schema.")
+
+                normalized_kwargs = normalize_args_kwargs(func, args, kwargs)
+
+                if "layout" in normalized_kwargs and normalized_kwargs[
+                    "layout"
+                ] not in {0, None}:
+                    raise UnsupportedByTorchMlirEagerMode(
+                        f"{normalized_kwargs['layout']} layout not supported."
+                    )
+                if "memory_format" in normalized_kwargs and normalized_kwargs[
+                    "memory_format"
+                ] not in {0, None}:
+                    raise UnsupportedByTorchMlirEagerMode(
+                        f"{normalized_kwargs['memory_format']} memory format not supported."
+                    )
+            eager_module = build_mlir_module(func, normalized_kwargs)
+            device_tensor_args = [
+                kwarg.elem
+                for _, kwarg in normalized_kwargs.items()
+                if isinstance(kwarg, cls)
+            ]
+            assert len(eager_module.body.operations[0].arguments) == len(
+                device_tensor_args
+            ), "Number of parameters and number of arguments differ."
+            op_mlir_backend_callable = backend.compile(eager_module)
+            out = op_mlir_backend_callable(*device_tensor_args)
+            out = tree_map(
+                lambda x: cls(
+                    x,
+                    requires_grad=requires_grad,
+                    constructing_from_device_tensor=True,
+                ),
+                out,
+            )
+
+            # Numeric verification; the value for comparison comes from PyTorch eager.
+            with no_dispatch():
+                unwrapped_args = tree_map(cls.unwrap, args)
+                unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
+                native_out = func(*unwrapped_args, **unwrapped_kwargs)
+
+            native_out = tree_map(
+                lambda x: cls(x, requires_grad=requires_grad), native_out
+            ).elem
+            tmp_out = out.elem
+
+            try:
+                np.testing.assert_allclose(
+                    native_out.to_host(),
+                    tmp_out.to_host(),
+                    rtol=rtol,
+                    atol=atol,
+                )
+            except Exception as e:
+                shaped_args = [
+                    arg.shape if torch.is_tensor(arg) else arg
+                    for arg in unwrapped_args
+                ]
+                shaped_kwargs = [
+                    kwarg.shape if torch.is_tensor(kwarg) else kwarg
+                    for kwarg in unwrapped_kwargs
+                ]
+                warnings.warn(
+                    f"Lockstep accuracy verification failed with error: *{str(e)}*; "
+                    f"Dispatched function name: *{str(func)}*; "
+                    f"Dispatched function args: *{str(shaped_args)}*; "
+                    f"Dispatched function kwargs: *{str(shaped_kwargs)}*; "
+                )
+        except Exception as e:
+            warnings.warn(traceback.format_exc())
+            if isinstance(e, UnsupportedByTorchMlirEagerMode):
+                warnings.warn(
+                    f"Couldn't use TorchMLIR eager because of a current incompatibility: *{str(e)}*; running through PyTorch eager."
+                )
+            else:
+                warnings.warn(
+                    f"Couldn't use TorchMLIR eager because of an error: *{str(e)}*; "
+                    f"running through PyTorch eager."
+                )
+
+            with no_dispatch():
+                unwrapped_args = tree_map(cls.unwrap, args)
+                unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
+                out = func(*unwrapped_args, **unwrapped_kwargs)
+
+        out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)
+
+        maybe_aliased_arg_name = check_get_aliased_arg(func)
+        if maybe_aliased_arg_name is not None:
+            warnings.warn(
+                f"Found aliased arg, but didn't copy tensor contents. 
This could lead to incorrect results for E2E model execution but doesn't affect the validity of the lockstep op verification." + ) + # TODO: Find a way to handle argument aliasing for IREE backend + # backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem) + + return out diff --git a/tank/pytorch/v_diffusion_pytorch/cfg_sample_eager.py b/tank/pytorch/v_diffusion_pytorch/cfg_sample_eager.py new file mode 100755 index 0000000000000..2d0e390a57d14 --- /dev/null +++ b/tank/pytorch/v_diffusion_pytorch/cfg_sample_eager.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 + +"""Classifier-free guidance sampling from a diffusion model.""" + +import argparse +from functools import partial +from pathlib import Path + +from PIL import Image +import torch +from torch import nn +from torch.nn import functional as F +from torchvision import transforms +from torchvision.transforms import functional as TF +from tqdm import trange + +from shark.shark_inference import SharkInference +from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor + +import sys + +sys.path.append("v-diffusion-pytorch") +from CLIP import clip +from diffusion import get_model, get_models, sampling, utils + +MODULE_DIR = Path(__file__).resolve().parent + + +def parse_prompt(prompt, default_weight=3.0): + if prompt.startswith("http://") or prompt.startswith("https://"): + vals = prompt.rsplit(":", 2) + vals = [vals[0] + ":" + vals[1], *vals[2:]] + else: + vals = prompt.rsplit(":", 1) + vals = vals + ["", default_weight][len(vals) :] + return vals[0], float(vals[1]) + + +def resize_and_center_crop(image, size): + fac = max(size[0] / image.size[0], size[1] / image.size[1]) + image = image.resize( + (int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS + ) + return TF.center_crop(image, size[::-1]) + + +# def main(): +p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter +) +p.add_argument( + "prompts", type=str, default=[], nargs="*", help="the text prompts to use" +) +p.add_argument( + "--images", + type=str, + default=[], + nargs="*", + metavar="IMAGE", + help="the image prompts", +) +p.add_argument( + "--batch-size", + "-bs", + type=int, + default=1, + help="the number of images per batch", +) +p.add_argument("--checkpoint", type=str, help="the checkpoint to use") +p.add_argument("--device", type=str, help="the device to use") +p.add_argument( + "--eta", + type=float, + default=0.0, + help="the amount of noise to add during sampling (0-1)", +) +p.add_argument("--init", type=str, help="the init image") +p.add_argument( + "--method", + type=str, + default="plms", + choices=["ddpm", "ddim", "prk", "plms", "pie", "plms2", "iplms"], + help="the sampling method to use", +) +p.add_argument( + "--model", + type=str, + default="cc12m_1_cfg", + choices=["cc12m_1_cfg"], + help="the model to use", +) +p.add_argument( + "-n", type=int, default=1, help="the number of images to sample" +) +p.add_argument("--seed", type=int, default=0, help="the random seed") +p.add_argument("--size", type=int, nargs=2, help="the output image size") +p.add_argument( + "--starting-timestep", + "-st", + type=float, + default=0.9, + help="the timestep to start at (used with init images)", +) +p.add_argument("--steps", type=int, default=50, help="the number of timesteps") +args = p.parse_args() + +if args.device: + device = torch.device(args.device) +else: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print("Using device:", device) + +model = 
get_model(args.model)() +_, side_y, side_x = model.shape +if args.size: + side_x, side_y = args.size +checkpoint = args.checkpoint +if not checkpoint: + checkpoint = MODULE_DIR / f"checkpoints/{args.model}.pth" +model.load_state_dict(torch.load(checkpoint, map_location="cpu")) +if device.type == "cuda": + model = model.half() +model = model.to(device).eval().requires_grad_(False) +clip_model_name = ( + model.clip_model if hasattr(model, "clip_model") else "ViT-B/16" +) +clip_model = clip.load(clip_model_name, jit=False, device=device)[0] +clip_model.eval().requires_grad_(False) +normalize = transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], +) + +if args.init: + init = Image.open(utils.fetch(args.init)).convert("RGB") + init = resize_and_center_crop(init, (side_x, side_y)) + init = ( + utils.from_pil_image(init).to(device)[None].repeat([args.n, 1, 1, 1]) + ) + +zero_embed = torch.zeros([1, clip_model.visual.output_dim], device=device) +target_embeds, weights = [zero_embed], [] + +for prompt in args.prompts: + txt, weight = parse_prompt(prompt) + target_embeds.append( + clip_model.encode_text(clip.tokenize(txt).to(device)).float() + ) + weights.append(weight) + +for prompt in args.images: + path, weight = parse_prompt(prompt) + img = Image.open(utils.fetch(path)).convert("RGB") + clip_size = clip_model.visual.input_resolution + img = resize_and_center_crop(img, (clip_size, clip_size)) + batch = TF.to_tensor(img)[None].to(device) + embed = F.normalize( + clip_model.encode_image(normalize(batch)).float(), dim=-1 + ) + target_embeds.append(embed) + weights.append(weight) + +weights = torch.tensor([1 - sum(weights), *weights], device=device) + +torch.manual_seed(args.seed) + + +def cfg_model_fn(x, t): + n = x.shape[0] + n_conds = len(target_embeds) + x_in = x.repeat([n_conds, 1, 1, 1]) + t_in = t.repeat([n_conds]) + clip_embed_in = torch.cat([*target_embeds]).repeat([n, 1]) + vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]]) + v = vs.mul(weights[:, None, None, None, None]).sum(0) + return v + + +x = torch.randn([args.n, 3, side_y, side_x], device=device) +t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1] +steps = utils.get_spliced_ddpm_cosine_schedule(t) +min_batch_size = min(args.n, args.batch_size) +x_in = x[0:min_batch_size, :, :, :] +ts = x_in.new_ones([x_in.shape[0]]) +t_in = t[0] * ts + +from torch.fx.experimental.proxy_tensor import make_fx +from torch._decomp import get_decompositions +import torch_mlir + +fx_g = make_fx( + cfg_model_fn, + decomposition_table=get_decompositions( + [ + torch.ops.aten.embedding_dense_backward, + torch.ops.aten.native_layer_norm_backward, + torch.ops.aten.slice_backward, + torch.ops.aten.select_backward, + torch.ops.aten.norm.ScalarOpt_dim, + torch.ops.aten.native_group_norm, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes, + ] + ), +)(x_in, t_in) + +fx_g.graph.set_codegen(torch.fx.graph.CodeGen()) +fx_g.recompile() + + +def strip_overloads(gm): + """ + Modifies the target of graph nodes in :attr:`gm` to strip overloads. 
+ Args: + gm(fx.GraphModule): The input Fx graph module to be modified + """ + for node in gm.graph.nodes: + if isinstance(node.target, torch._ops.OpOverload): + node.target = node.target.overloadpacket + gm.recompile() + + +strip_overloads(fx_g) + +ts_g = torch.jit.script(fx_g) + +# module = torch_mlir.compile( +# ts_g, +# [x_in, t_in], +# torch_mlir.OutputType.LINALG_ON_TENSORS, +# use_tracing=False, +# ) +# +# mlir_model = module +# func_name = "forward" +# +# shark_module = SharkInference( +# mlir_model, func_name, device="gpu", mlir_dialect="linalg" +# ) +# shark_module.compile() + + +def compiled_cfg_model_fn(x, t): + x_in_eager = TorchMLIRLockstepTensor(x.clone()) + t_in_eager = TorchMLIRLockstepTensor(t.clone()) + return ts_g(x_in_eager, t_in_eager) + + +def run(x, steps): + if args.method == "ddpm": + return sampling.sample(compiled_cfg_model_fn, x, steps, 1.0, {}) + if args.method == "ddim": + return sampling.sample(compiled_cfg_model_fn, x, steps, args.eta, {}) + if args.method == "prk": + return sampling.prk_sample(compiled_cfg_model_fn, x, steps, {}) + if args.method == "plms": + return sampling.plms_sample(compiled_cfg_model_fn, x, steps, {}) + if args.method == "pie": + return sampling.pie_sample(compiled_cfg_model_fn, x, steps, {}) + if args.method == "plms2": + return sampling.plms2_sample(compiled_cfg_model_fn, x, steps, {}) + if args.method == "iplms": + return sampling.iplms_sample(compiled_cfg_model_fn, x, steps, {}) + assert False + + +def run_all(x, t, steps, n, batch_size): + x = torch.randn([n, 3, side_y, side_x], device=device) + t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1] + steps = utils.get_spliced_ddpm_cosine_schedule(t) + if args.init: + steps = steps[steps < args.starting_timestep] + alpha, sigma = utils.t_to_alpha_sigma(steps[0]) + x = init * alpha + x * sigma + for i in trange(0, n, batch_size): + cur_batch_size = min(n - i, batch_size) + outs = run(x[i : i + cur_batch_size], steps) + for j, out in enumerate(outs): + utils.to_pil_image(out).save(f"out_{i + j:05}.png") + + +steps = 1 + +run_all(x, t, steps, args.n, args.batch_size)
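+
+# Optional end-to-end sanity check (a hedged sketch, commented out and not part
+# of the sampling run above): compare one forward pass of the scripted graph
+# against PyTorch eager, on top of the per-op checks TorchMLIRLockstepTensor
+# already performs. It reuses `x_in`, `t_in`, `cfg_model_fn`, and
+# `compiled_cfg_model_fn` defined earlier; the `.elem` attribute is assumed to
+# hold the IREE device buffer, as in the SqueezeNet lockstep example.
+#
+#   import numpy as np
+#
+#   with torch.no_grad():
+#       eager_out = cfg_model_fn(x_in, t_in)
+#       lockstep_out = compiled_cfg_model_fn(x_in, t_in)
+#   lockstep_np = np.asarray(
+#       lockstep_out.elem if hasattr(lockstep_out, "elem") else lockstep_out
+#   )
+#   np.testing.assert_allclose(
+#       lockstep_np, eager_out.cpu().float().numpy(), rtol=1e-2, atol=1e-3
+#   )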