diff --git a/modules/cmd_args.py b/modules/cmd_args.py index cd809341c49..437c3f38706 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -75,6 +75,7 @@ parser.add_argument("--use-cpu-torch", action="store_true", help="use torch built with cpu") parser.add_argument("--use-directml", action="store_true", help="use DirectML device as torch device") parser.add_argument("--use-zluda", action="store_true", help="use ZLUDA device as torch device") +parser.add_argument("--use-zluda-dnn", action="store_true", help="enable ZLUDA DNN") parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device") parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") diff --git a/modules/img2img.py b/modules/img2img.py index 847c3df235b..a1d042c2123 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -8,7 +8,7 @@ from modules import images from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters -from modules.processing import Processed, process_images +from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state from modules.sd_models import get_closet_checkpoint_match import modules.shared as shared @@ -186,7 +186,7 @@ def img2img(id_task: str, request: gr.Request, mode: int, prompt: str, negative_ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' - p = processing.StableDiffusionProcessingImg2Img( + p = StableDiffusionProcessingImg2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples, outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids, diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 850eb395b8e..2e134bd0436 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -437,13 +437,13 @@ def prepare_environment(): rocm_found = False hip_found = False backend = "cuda" - torch_command = "pip install torch==2.2.2 torchvision" if args.use_cpu_torch else "pip install torch==2.2.2 torchvision --extra-index-url https://download.pytorch.org/whl/cu121" + torch_command = "pip install torch==2.3.0 torchvision" if args.use_cpu_torch else "pip install torch==2.2.2 torchvision --extra-index-url https://download.pytorch.org/whl/cu121" if args.use_cpu_torch: backend = "cpu" torch_command = os.environ.get( "TORCH_COMMAND", - "pip install torch==2.2.2 torchvision", + "pip install torch==2.3.0 torchvision", ) elif args.use_directml: backend = "directml" @@ -460,32 +460,23 @@ def prepare_environment(): ) torch_command = os.environ.get( "TORCH_COMMAND", - f"pip install torch==2.2.2 torchvision --index-url {torch_index_url}", + f"pip install torch==2.3.0 torchvision --index-url {torch_index_url}", ) - zluda_path = find_zluda() - if zluda_path is None: - is_windows = system == "Windows" - import urllib.request - import zipfile - import tarfile - archive_type = zipfile.ZipFile if is_windows else tarfile.TarFile - try: - urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.9e97c717c3fef536d3116f39a15d95626c1dfe39/ZLUDA-{platform.system().lower()}-amd64.{"zip" if is_windows else "tar.gz"}', '_zluda') - with archive_type('_zluda', 'r') as f: - f.extractall('.zluda') - zluda_path = os.path.abspath('./.zluda') - os.remove('_zluda') - except Exception as e: - print(f'Failed to install ZLUDA: {e}') - if os.path.exists(os.path.join(zluda_path, 'nvcuda.dll')): - print(f'Using ZLUDA in {zluda_path}') - torch_command = os.environ.get( - 'TORCH_COMMAND', - 'pip install torch==2.2.2 torchvision --index-url https://download.pytorch.org/whl/cu118', - ) - paths = os.environ.get('PATH', '.') - if zluda_path not in paths: - os.environ['PATH'] = paths + ';' + zluda_path + try: + from modules import zluda_installer + if args.use_zluda_dnn: + if zluda_installer.check_dnn_dependency(): + zluda_installer.enable_dnn() + else: + print("Couldn't find the required dependency of ZLUDA DNN.") + zluda_installer.install() + zluda_installer.resolve_path() + torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.3.0 torchvision --index-url https://download.pytorch.org/whl/cu118') + print(f'Using ZLUDA in {zluda_installer.ZLUDA_PATH}') + except Exception as e: + print(f'Failed to install ZLUDA: {e}') + print('Using CPU-only torch') + torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch torchvision') elif args.use_ipex: backend = "ipex" if system == "Windows": @@ -519,7 +510,7 @@ def prepare_environment(): ) torch_command = os.environ.get( "TORCH_COMMAND", - f"pip install torch==2.2.0 torchvision==0.17.0 --extra-index-url {torch_index_url}", + f"pip install torch==2.3.0 torchvision --extra-index-url {torch_index_url}", ) elif system == "Windows" and hip_found: # ZLUDA print("ROCm Toolkit was found.") @@ -529,17 +520,17 @@ def prepare_environment(): ) torch_command = os.environ.get( "TORCH_COMMAND", - f"pip install torch==2.2.1 torchvision --index-url {torch_index_url}", + f"pip install torch==2.3.0 torchvision --index-url {torch_index_url}", ) elif rocm_found: print("ROCm Toolkit was found.") backend = "rocm" torch_index_url = os.environ.get( - "TORCH_INDEX_URL", "https://download.pytorch.org/whl/rocm5.4.2" + "TORCH_INDEX_URL", "https://download.pytorch.org/whl/rocm6.0" ) torch_command = os.environ.get( "TORCH_COMMAND", - f"pip install torch==2.0.1 torchvision==0.15.2 --index-url {torch_index_url}", + f"pip install torch==2.3.0 torchvision --index-url {torch_index_url}", ) requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") @@ -585,7 +576,11 @@ def prepare_environment(): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) startup_timer.record("install torch") if args.use_zluda: - patch_zluda() + try: + from modules.zluda_installer import patch as patch_torch + patch_torch() + except Exception as e: + print(f'ZLUDA: failed to automatically patch torch: {e}') if args.use_ipex or args.use_directml or args.use_cpu_torch: args.skip_torch_cuda_test = True diff --git a/modules/zluda.py b/modules/zluda.py index e9031d9b237..d428ceffdec 100644 --- a/modules/zluda.py +++ b/modules/zluda.py @@ -1,7 +1,15 @@ import platform import torch from torch._prims_common import DeviceLikeType -from modules import devices +from modules import shared, devices + + +conv2d = torch.nn.functional.conv2d +def conv2d_cudnn_disabled(*args, **kwargs): + torch.backends.cudnn.enabled = False + R = conv2d(*args, **kwargs) + torch.backends.cudnn.enabled = True + return R def is_zluda(device: DeviceLikeType): @@ -23,10 +31,12 @@ def test(device: DeviceLikeType): def initialize_zluda(): device = devices.get_optimal_device() if platform.system() == "Windows" and torch.cuda.is_available() and is_zluda(device): - torch.backends.cudnn.enabled = False + torch.backends.cudnn.enabled = shared.cmd_opts.use_zluda_dnn torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(False) + if shared.cmd_opts.use_zluda_dnn: + torch.nn.functional.conv2d = conv2d_cudnn_disabled devices.device_codeformer = devices.cpu if not test(device): diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py new file mode 100644 index 00000000000..fc7ddc3b8b9 --- /dev/null +++ b/modules/zluda_installer.py @@ -0,0 +1,94 @@ +import os +import shutil +import zipfile +import tarfile +import platform +import urllib.request + + +RELEASE = 'rel.9e97c717c3fef536d3116f39a15d95626c1dfe39' +TARGETS = { + 'cublas.dll': 'cublas64_11.dll', + 'cusparse.dll': 'cusparse64_11.dll', + 'nvrtc.dll': 'nvrtc64_112_0.dll', +} +ZLUDA_PATH = None +TORCHLIB_PATH = None + + +def find_zluda_path(): + zluda_path = os.environ.get('ZLUDA', None) + if zluda_path is None: + paths = os.environ.get('PATH', '').split(';') + for path in paths: + if os.path.exists(os.path.join(path, 'zluda_redirect.dll')): + zluda_path = path + break + return zluda_path + + +def find_venv_dir(): + python_dir = os.path.dirname(shutil.which('python')) + if shutil.which('conda') is None: + python_dir = os.path.dirname(python_dir) + return os.environ.get('VENV_DIR', python_dir) + + +def reset_torch(): + for dll in TARGETS.values(): + path = os.path.join(TORCHLIB_PATH, dll) + if os.path.exists(path): + os.remove(path) + + +def is_patched(): + for dll in TARGETS.values(): + if not os.path.islink(os.path.join(TORCHLIB_PATH, dll)): + return False + return True + + +def check_dnn_dependency(): + hip_path = os.environ.get("HIP_PATH", None) + if hip_path is None: # unable to check + return True + if os.path.exists(os.path.join(hip_path, 'bin', 'MIOpen.dll')): + return True + return False + + +def enable_dnn(): + global RELEASE # pylint: disable=global-statement + TARGETS['cudnn.dll'] = 'cudnn64_8.dll' + RELEASE = 'v3.8-pre2-dnn' + + +def install(): + global ZLUDA_PATH, TORCHLIB_PATH # pylint: disable=global-statement + ZLUDA_PATH = find_zluda_path() + TORCHLIB_PATH = os.path.join(find_venv_dir(), 'Lib', 'site-packages', 'torch', 'lib') + + if ZLUDA_PATH is not None: + return + + is_windows = platform.system() == 'Windows' + archive_type = zipfile.ZipFile if is_windows else tarfile.TarFile + urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/{RELEASE}/ZLUDA-{platform.system().lower()}-amd64.{"zip" if is_windows else "tar.gz"}', '_zluda') + with archive_type('_zluda', 'r') as f: + f.extractall('.zluda') + ZLUDA_PATH = os.path.abspath('./.zluda') + os.remove('_zluda') + + +def resolve_path(): + paths = os.environ.get('PATH', '.') + if ZLUDA_PATH not in paths: + os.environ['PATH'] = paths + ';' + ZLUDA_PATH + + +def patch(): + if is_patched(): + return + reset_torch() + for k, v in TARGETS.items(): + os.symlink(os.path.join(ZLUDA_PATH, k), os.path.join(TORCHLIB_PATH, v))