Skip to content

Commit

Permalink
Add --fast argument to enable experimental optimizations.
Browse files Browse the repository at this point in the history
Optimizations that might break things/lower quality will be put behind
this flag first and might be enabled by default in the future.

Currently the only optimization is float8_e4m3fn matrix multiplication on
4000/ADA series Nvidia cards or later. If you have one of these cards you
will see a speed boost when using fp8_e4m3fn flux for example.
  • Loading branch information
comfyanonymous committed Aug 20, 2024
1 parent d1a6bd6 commit 9953f22
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 5 deletions.
1 change: 1 addition & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class LatentPreviewMethod(enum.Enum):

# Memory / numerics behaviour switches. All three are plain boolean flags
# (action="store_true"), so they default to False when omitted.
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
# Opt-in gate for experimental optimizations that may break things or lower
# quality; read elsewhere as `args.fast`.
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")

# Server output / CI helper flags.
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
Expand Down
5 changes: 1 addition & 4 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod

if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
Expand Down
10 changes: 10 additions & 0 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma

return False

def supports_fp8_compute(device=None):
    """Return True when ``device`` can run float8 matrix multiplications.

    The check is based on the CUDA compute capability: 8.9 or anything
    from 9.0 upwards qualifies; everything older does not.

    Args:
        device: Anything accepted by ``torch.cuda.get_device_properties``
            (``None`` means the current CUDA device). A non-CUDA device
            such as ``"cpu"`` is accepted and simply yields ``False``.

    Returns:
        bool: ``True`` if fp8 compute is expected to be usable, else ``False``.
    """
    # Querying CUDA device properties raises when no CUDA runtime is usable
    # (CPU-only builds, MPS, ...). Report "unsupported" instead of crashing
    # the caller that is merely probing for an optional fast path.
    # NOTE(review): `torch.version.cuda is None` also excludes ROCm builds,
    # on the assumption that the capability numbers below are only meaningful
    # for NVIDIA hardware -- confirm if AMD fp8 support is desired.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if device is not None and torch.device(device).type != "cuda":
        return False

    props = torch.cuda.get_device_properties(device)
    # Tuple comparison is equivalent to: major >= 9, or major == 8 and minor >= 9.
    return (props.major, props.minor) >= (8, 9)

def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
Expand Down
41 changes: 40 additions & 1 deletion comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import torch
import comfy.model_management

from comfy.cli_args import args

def cast_to(weight, dtype=None, device=None, non_blocking=False):
if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
Expand Down Expand Up @@ -242,3 +242,42 @@ class ConvTranspose1d(disable_weight_init.ConvTranspose1d):

class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True


def fp8_linear(self, input):
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None

if len(input.shape) == 3:
out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
inn = input.to(dtype)
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
for i in range(input.shape[0]):
if self.bias is not None:
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
else:
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
out[i] = o
return out
return None

class fp8_ops(manual_cast):
    """``manual_cast`` operations whose Linear layer tries fp8 matmul first."""

    class Linear(manual_cast.Linear):
        def forward_comfy_cast_weights(self, input):
            """Run the fp8 fast path, falling back to a cast + regular linear."""
            result = fp8_linear(self, input)
            if result is None:
                # Fast path declined (it returned None): cast weight/bias to
                # match the input and use the standard implementation.
                weight, bias = cast_bias_weight(self, input)
                result = torch.nn.functional.linear(input, weight, bias)
            return result


def pick_operations(weight_dtype, compute_dtype, load_device=None):
    """Choose the operations class for a model.

    Args:
        weight_dtype: dtype the weights are stored in.
        compute_dtype: dtype to compute in, or ``None`` when no manual
            casting is requested.
        load_device: device passed to the fp8 capability probe.

    Returns:
        ``disable_weight_init`` when no casting is needed, ``fp8_ops`` when
        ``--fast`` is set and the device reports fp8 support, otherwise
        ``manual_cast``.
    """
    # No casting requested, or weights already in the compute dtype.
    if compute_dtype is None:
        return disable_weight_init
    if weight_dtype == compute_dtype:
        return disable_weight_init

    # The capability probe is only evaluated when --fast was passed.
    use_fp8 = args.fast and comfy.model_management.supports_fp8_compute(load_device)
    return fp8_ops if use_fp8 else manual_cast

10 comments on commit 9953f22

@jepjoo
Copy link

@jepjoo jepjoo commented on 9953f22 Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gave it a shot, but I get an error.

RTX 4090, Win11, pytorch version: 2.3.1+cu121

RuntimeError: _scaled_mm_out_cuda is not compiled for this platform.

The whole thing:

Error occurred when executing SamplerCustom:

_scaled_mm_out_cuda is not compiled for this platform.

  File "G:\ComfyUI_windows_portable\ComfyUI\execution.py", line 316, in execute
    output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\execution.py", line 191, in get_output_data
    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\execution.py", line 168, in _map_node_over_list
    process_inputs(input_dict, i)
  File "G:\ComfyUI_windows_portable\ComfyUI\execution.py", line 157, in process_inputs
    results.append(getattr(obj, func)(**inputs))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy_extras\nodes_custom_sampler.py", line 455, in sample
    samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\custom_nodes\ComfyUI-AnimateDiff-Evolved\animatediff\sampling.py", line 218, in motion_sample
    return orig_comfy_sample(model, noise, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\sample.py", line 48, in sample_custom
    samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 729, in sample
    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 716, in sample
    output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 695, in inner_sample
    samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 600, in sample
    samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\k_diffusion\sampling.py", line 144, in sample_euler
    denoised = model(x, sigma_hat * s_in, **extra_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 299, in __call__
    out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 682, in __call__
    return self.predict_noise(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 685, in predict_noise
    return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 279, in sampling_function
    out = calc_cond_batch(model, conds, x, timestep, model_options)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 228, in calc_cond_batch
    output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\model_base.py", line 142, in apply_model
    model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\ldm\flux\model.py", line 159, in forward
    out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\ldm\flux\model.py", line 104, in forward_orig
    img = self.img_in(img)
          ^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\ops.py", line 67, in forward
    return self.forward_comfy_cast_weights(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\ops.py", line 269, in forward_comfy_cast_weights
    out = fp8_linear(self, input)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\ComfyUI_windows_portable\ComfyUI\comfy\ops.py", line 259, in fp8_linear
    o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))

@comfyanonymous
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might need a newer pytorch version on windows. I only tested it on Linux which is why it's behind the --fast argument.

@jepjoo
Copy link

@jepjoo jepjoo commented on 9953f22 Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might need a newer pytorch version on windows. I only tested it on Linux which is why it's behind the --fast argument.

Works now after running update_comfyui_and_python_dependencies.bat which updated torch to 2.4.0 and some other things.

Pretty significant speed bump, ~40%-ish!

@Michoko92
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just FYI, here is a post about further optimizations for Flux: https://www.reddit.com/r/StableDiffusion/comments/1ex64jj/comment/lj3v03m/?context=3 .

@hanggun
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tested on a 4090 using Flux with Q8.gguf, but the speed did not increase.

@jepjoo
Copy link

@jepjoo jepjoo commented on 9953f22 Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tested on a 4090 using Flux with Q8.gguf, but the speed did not increase.

The speedup applies to fp8_e4m3fn weights only.

@bananasss00
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works with SD15 and Flux models, but with SDXL models I get black images.

@De-Zoomer
Copy link

@De-Zoomer De-Zoomer commented on 9953f22 Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not working for me, getting this:

TypeError: _scaled_mm() missing 2 required positional argument: "scale_a", "scale_b"

My setup: Windows 11, 4090, non-standalone install, Python 3.12.5, pytorch 2.5.0.dev20240818+cu124, cuda_12.6.r12.6

@comfyanonymous
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update comfyui, pytorch nightly is supported now.

@De-Zoomer
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update comfyui, pytorch nightly is supported now.

Yes, it works now. Thanks!

Please sign in to comment.