Basic torch_directml support. Use --directml to use it.
comfyanonymous committed Apr 28, 2023
1 parent ab9a9de commit 3baded9
Showing 2 changed files with 27 additions and 1 deletion.
1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -10,6 +10,7 @@
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
 parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
 parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
+parser.add_argument("--directml", action="store_true", help="Use torch-directml.")
 
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
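The new switch is a plain argparse store_true flag, so enabling the feature is just a matter of passing --directml at launch (e.g. python main.py --directml). A minimal stand-alone sketch of the flag's behavior, mirroring the added line above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--directml", action="store_true", help="Use torch-directml.")

args = parser.parse_args(["--directml"])
print(args.directml)  # True; comfy.model_management reads this at import time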
27 changes: 26 additions & 1 deletion comfy/model_management.py
@@ -20,6 +20,13 @@ class VRAMState(Enum):
 accelerate_enabled = False
 xpu_available = False
 
+directml_enabled = False
+if args.directml:
+    import torch_directml
+    print("Using directml")
+    directml_enabled = True
+    # torch_directml.disable_tiled_resources(True)
+
 try:
     import torch
     try:
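For context, a hedged sketch of what the guarded import provides once the torch-directml package is installed (pip install torch-directml); device(), device_count() and device_name() are part of torch_directml's public API:

import torch_directml

print(torch_directml.device_count())  # number of DirectML adapters found
print(torch_directml.device_name(0))  # name of the first adapter
dev = torch_directml.device()         # default adapter as a torch.device
print(dev)                            # e.g. privateuseone:0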
@@ -217,6 +224,9 @@ def unload_if_low_vram(model):
 
 def get_torch_device():
     global xpu_available
+    global directml_enabled
+    if directml_enabled:
+        return torch_directml.device()
     if vram_state == VRAMState.MPS:
         return torch.device("mps")
     if vram_state == VRAMState.CPU:
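Once get_torch_device() hands back the DirectML device, tensors and modules move to it through the usual torch API. A small sketch, again assuming torch-directml is installed and --directml was passed:

import torch
import torch_directml

dev = torch_directml.device()  # what get_torch_device() now returns
x = torch.randn(4, 4).to(dev)  # tensor allocated on the DirectML adapter
y = (x @ x).cpu()              # run the matmul there, copy back to the CPU
print(y.shape)                 # torch.Size([4, 4])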
@@ -234,8 +244,14 @@ def get_autocast_device(dev):
 
 
 def xformers_enabled():
+    global xpu_available
+    global directml_enabled
     if vram_state == VRAMState.CPU:
         return False
+    if xpu_available:
+        return False
+    if directml_enabled:
+        return False
     return XFORMERS_IS_AVAILABLE
 
 
@@ -251,14 +267,18 @@ def pytorch_attention_enabled():
 
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
+    global directml_enabled
     if dev is None:
         dev = get_torch_device()
 
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        if xpu_available:
+        if directml_enabled:
+            mem_free_total = 1024 * 1024 * 1024 #TODO
+            mem_free_torch = mem_free_total
+        elif xpu_available:
             mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
             mem_free_torch = mem_free_total
         else:
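DirectML exposes no free-VRAM query through torch, hence the hard-coded 1 GiB placeholder and the #TODO. A sketch of the observable effect, assuming ComfyUI was started with --directml and that get_free_memory returns a (total, torch) pair when torch_free_too=True, as elsewhere in this file:

import comfy.model_management as mm

dev = mm.get_torch_device()  # torch_directml.device() under --directml
free_total, free_torch = mm.get_free_memory(dev, torch_free_too=True)
print(free_total)            # 1073741824, i.e. the 1024 * 1024 * 1024 stub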
@@ -293,9 +313,14 @@ def mps_mode():
 
 def should_use_fp16():
     global xpu_available
+    global directml_enabled
+
     if FORCE_FP32:
         return False
 
+    if directml_enabled:
+        return False
+
     if cpu_mode() or mps_mode() or xpu_available:
         return False #TODO ?
 
