forked from lshqqytiger/stable-diffusion-webui-amdgpu
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6130ef9
commit 620b78c
Showing
5 changed files
with
136 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import os | ||
import shutil | ||
import zipfile | ||
import tarfile | ||
import platform | ||
import urllib.request | ||
|
||
|
||
RELEASE = 'rel.9e97c717c3fef536d3116f39a15d95626c1dfe39' | ||
TARGETS = { | ||
'cublas.dll': 'cublas64_11.dll', | ||
'cusparse.dll': 'cusparse64_11.dll', | ||
'nvrtc.dll': 'nvrtc64_112_0.dll', | ||
} | ||
ZLUDA_PATH = None | ||
TORCHLIB_PATH = None | ||
|
||
|
||
def find_zluda_path(): | ||
zluda_path = os.environ.get('ZLUDA', None) | ||
if zluda_path is None: | ||
paths = os.environ.get('PATH', '').split(';') | ||
for path in paths: | ||
if os.path.exists(os.path.join(path, 'zluda_redirect.dll')): | ||
zluda_path = path | ||
break | ||
return zluda_path | ||
|
||
|
||
def find_venv_dir(): | ||
python_dir = os.path.dirname(shutil.which('python')) | ||
if shutil.which('conda') is None: | ||
python_dir = os.path.dirname(python_dir) | ||
return os.environ.get('VENV_DIR', python_dir) | ||
|
||
|
||
def reset_torch(): | ||
for dll in TARGETS.values(): | ||
path = os.path.join(TORCHLIB_PATH, dll) | ||
if os.path.exists(path): | ||
os.remove(path) | ||
|
||
|
||
def is_patched(): | ||
for dll in TARGETS.values(): | ||
if not os.path.islink(os.path.join(TORCHLIB_PATH, dll)): | ||
return False | ||
return True | ||
|
||
|
||
def check_dnn_dependency(): | ||
hip_path = os.environ.get("HIP_PATH", None) | ||
if hip_path is None: # unable to check | ||
return True | ||
if os.path.exists(os.path.join(hip_path, 'bin', 'MIOpen.dll')): | ||
return True | ||
return False | ||
|
||
|
||
def enable_dnn(): | ||
global RELEASE # pylint: disable=global-statement | ||
TARGETS['cudnn.dll'] = 'cudnn64_8.dll' | ||
RELEASE = 'v3.8-pre2-dnn' | ||
|
||
|
||
def install(): | ||
global ZLUDA_PATH, TORCHLIB_PATH # pylint: disable=global-statement | ||
ZLUDA_PATH = find_zluda_path() | ||
TORCHLIB_PATH = os.path.join(find_venv_dir(), 'Lib', 'site-packages', 'torch', 'lib') | ||
|
||
if ZLUDA_PATH is not None: | ||
return | ||
|
||
is_windows = platform.system() == 'Windows' | ||
archive_type = zipfile.ZipFile if is_windows else tarfile.TarFile | ||
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/{RELEASE}/ZLUDA-{platform.system().lower()}-amd64.{"zip" if is_windows else "tar.gz"}', '_zluda') | ||
with archive_type('_zluda', 'r') as f: | ||
f.extractall('.zluda') | ||
ZLUDA_PATH = os.path.abspath('./.zluda') | ||
os.remove('_zluda') | ||
|
||
|
||
def resolve_path(): | ||
paths = os.environ.get('PATH', '.') | ||
if ZLUDA_PATH not in paths: | ||
os.environ['PATH'] = paths + ';' + ZLUDA_PATH | ||
|
||
|
||
def patch(): | ||
if is_patched(): | ||
return | ||
reset_torch() | ||
for k, v in TARGETS.items(): | ||
os.symlink(os.path.join(ZLUDA_PATH, k), os.path.join(TORCHLIB_PATH, v)) |