Skip to content

Commit

Permalink
we have proper metal support
Browse files Browse the repository at this point in the history
  • Loading branch information
iamlemec committed Oct 17, 2024
1 parent 9000953 commit 08a6d71
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 4 deletions.
2 changes: 1 addition & 1 deletion gadget/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def create_backend(self, name):
self.backend_type = 'cuda'
elif name == 'metal':
self.backend = ggml_backend_metal_init()
self.backend_type = 'metal'
self.backend_type = 'cpu'
else:
raise ValueError(f'unknown backend: {name}')

Expand Down
1 change: 1 addition & 0 deletions gadget/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# backend
ggml_backend_cpu_init,
ggml_backend_cuda_init,
ggml_backend_metal_init,
ggml_backend_free,
ggml_backend_alloc_ctx_tensors,
ggml_backend_get_default_buffer_type,
Expand Down
11 changes: 10 additions & 1 deletion gadget/libs/_libggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
## library
##

_ggml = load_shared_lib('libggml.so', 'GADGET_GGML_LIB')
_ggml = load_shared_lib('libggml', 'GADGET_GGML_LIB')

##
## function wrappers
Expand Down Expand Up @@ -489,6 +489,15 @@ def ggml_backend_cuda_init(): ...
except:
ggml_backend_cuda_init = DummyFunction('CUDA backend not found')

try:
@ctypes_function(_ggml,
None,
ggml_backend_p
)
def ggml_backend_metal_init(): ...
except:
ggml_backend_metal_init = DummyFunction('Metal backend not found')

@ctypes_function(_ggml,
None,
ggml_backend_p
Expand Down
14 changes: 13 additions & 1 deletion gadget/libs/general.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
### general stuff

import os
import sys
import ctypes

##
Expand All @@ -18,8 +19,19 @@ def __call__(self, *args, **kwargs):
## ctypes helpers
##

# get library extension for system
plat_ext = {
'linux': 'so',
'darwin': 'dylib',
'win32': 'dll',
}

# load a shared lib with env override
def load_shared_lib(lib_name, env_var=None):
def load_shared_lib(lib_base=None, env_var=None):
# get filename
lib_ext = plat_ext[sys.platform]
lib_name = f'{lib_base}.{lib_ext}'

# get shared library path
if env_var is not None and env_var in os.environ:
lib_path = os.environ[env_var]
Expand Down
2 changes: 1 addition & 1 deletion gadget/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __setitem__(self, key, val):
def __getitem__(self, key):
# handle tuple accessor case (prevent recursion)
if type(key) is tuple:
return [super().__getitem__(k) for k in key]
return [super(AttrDict, self).__getitem__(k) for k in key]

# key type validation
if type(key) is not str:
Expand Down

0 comments on commit 08a6d71

Please sign in to comment.