Add support for PyTorch Metal Performance Shaders #685

Merged · 17 commits · Jun 10, 2022
Changes from 16 commits
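
Taken together, the changes below add an `MPSOps` backend, split the `has_torch_gpu` flag into CUDA and MPS variants, and teach `prefer_gpu`/`require_gpu` to fall back to MPS when no CUDA device is present. A minimal sketch of the resulting behaviour, assuming an MPS-enabled PyTorch build on Apple Silicon (output is illustrative, not taken from the PR):

```python
# Sketch of the behaviour this PR enables (assumes an MPS-enabled
# PyTorch build; printed values are illustrative).
from thinc.api import prefer_gpu, get_current_ops
from thinc.util import get_torch_default_device

if prefer_gpu():
    # On Apple Silicon this now selects MPSOps and an "mps" device;
    # on a CUDA machine it still selects CupyOps and "cuda:0".
    print(get_current_ops().name)      # e.g. "mps"
    print(get_torch_default_device())  # e.g. device(type='mps')
```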
2 changes: 1 addition & 1 deletion thinc/api.py
@@ -18,7 +18,7 @@
 from .util import torch2xp, xp2torch, tensorflow2xp, xp2tensorflow, mxnet2xp, xp2mxnet
 from .compat import has_cupy
 from .backends import get_ops, set_current_ops, get_current_ops, use_ops
-from .backends import Ops, CupyOps, NumpyOps, set_gpu_allocator
+from .backends import Ops, CupyOps, MPSOps, NumpyOps, set_gpu_allocator
 from .backends import use_pytorch_for_gpu_memory, use_tensorflow_for_gpu_memory

 from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear
8 changes: 7 additions & 1 deletion thinc/backends/__init__.py
@@ -7,10 +7,11 @@
 from .ops import Ops
 from .cupy_ops import CupyOps
 from .numpy_ops import NumpyOps
+from .mps_ops import MPSOps
 from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
 from ._param_server import ParamServer
 from ..util import assert_tensorflow_installed, assert_pytorch_installed
-from ..util import is_cupy_array, require_cpu
+from ..util import get_torch_default_device, is_cupy_array, require_cpu
 from .. import registry
 from ..compat import cupy, has_cupy
@@ -48,6 +49,10 @@ def use_pytorch_for_gpu_memory() -> None:  # pragma: no cover
     (or vice versa), but do not currently have an implementation for it.
     """
     assert_pytorch_installed()

+    if get_torch_default_device().type != "cuda":
+        return
+
     pools = context_pools.get()
     if "pytorch" not in pools:
         pools["pytorch"] = cupy.cuda.MemoryPool(allocator=cupy_pytorch_allocator)
@@ -169,6 +174,7 @@ def _create_thread_local(
     "ParamServer",
     "Ops",
     "CupyOps",
+    "MPSOps",
     "NumpyOps",
     "has_cupy",
 ]
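
The early return added above keeps `use_pytorch_for_gpu_memory` from installing the shared CuPy/PyTorch allocator when the default Torch device is not CUDA, since that allocator only exists for CUDA memory pools. A short sketch of the resulting no-op behaviour (assumes PyTorch is installed):

```python
# Sketch: on a machine whose default torch device is "mps" or "cpu",
# the call below now returns without registering a CuPy memory pool.
from thinc.api import use_pytorch_for_gpu_memory
from thinc.util import get_torch_default_device

if get_torch_default_device().type != "cuda":
    use_pytorch_for_gpu_memory()  # no-op on non-CUDA devices
```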
4 changes: 2 additions & 2 deletions thinc/backends/cupy_ops.py
@@ -6,7 +6,7 @@
 from ..types import DeviceTypes
 from ..util import torch2xp, tensorflow2xp, mxnet2xp
 from ..util import is_cupy_array
-from ..util import is_torch_gpu_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
+from ..util import is_torch_cuda_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
 from ..compat import cupy, cupyx

@@ -62,7 +62,7 @@ def asarray(self, data, dtype=None):
         # We'll try to perform a zero-copy conversion if possible.
         if is_cupy_array(data):
             array = data
-        elif is_torch_gpu_array(data):
+        elif is_torch_cuda_array(data):
             array = torch2xp(data)
         elif is_tensorflow_gpu_array(data):
             array = tensorflow2xp(data)
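
With this rename, `CupyOps.asarray` only attempts the DLPack zero-copy path for CUDA tensors; MPS tensors fall through to the generic converters instead. A sketch of the CUDA fast path (assumes CUDA builds of both PyTorch and CuPy):

```python
# Sketch of the zero-copy CUDA path used by CupyOps.asarray above.
import torch
from thinc.api import CupyOps

t = torch.ones(3, device="cuda")
a = CupyOps().asarray(t)  # converted via DLPack without copying device memory
```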
26 changes: 26 additions & 0 deletions thinc/backends/mps_ops.py
@@ -0,0 +1,26 @@
+from typing import TYPE_CHECKING
+import numpy
+
+from .. import registry
+from . import NumpyOps, Ops
+
+if TYPE_CHECKING:
+    # Type checking does not work with dynamic base classes, since MyPy cannot
+    # determine against which base class to check. So, always derive from Ops
+    # during type checking.
+    _Ops = Ops
+else:
+    try:
+        from thinc_apple_ops import AppleOps
+
+        _Ops = AppleOps
+    except ImportError:
+        _Ops = NumpyOps
+
+
+@registry.ops("MPSOps")
+class MPSOps(_Ops):
+    """Ops class for Metal Performance Shaders."""
+
+    name = "mps"
+    xp = numpy
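
The `TYPE_CHECKING` branch is needed because `MPSOps` picks its base class at import time: the AMX-accelerated `AppleOps` from `thinc-apple-ops` when that package is installed, plain `NumpyOps` otherwise. The same pattern in isolation, as a sketch with a hypothetical `fast_backend` package:

```python
from typing import TYPE_CHECKING

class PortableOps:
    """Slow but always-available implementation."""

if TYPE_CHECKING:
    # MyPy cannot reason about a base class chosen at runtime,
    # so the type checker always sees the portable base.
    _Base = PortableOps
else:
    try:
        from fast_backend import FastOps  # hypothetical accelerated package
        _Base = FastOps
    except ImportError:
        _Base = PortableOps

class MyOps(_Base):
    name = "my"
```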
13 changes: 12 additions & 1 deletion thinc/compat.py
@@ -31,7 +31,13 @@
     import torch

     has_torch = True
-    has_torch_gpu = torch.cuda.device_count() != 0
+    has_torch_cuda_gpu = torch.cuda.device_count() != 0
+    has_torch_mps_gpu = (
+        hasattr(torch, "has_mps")
+        and torch.has_mps
+        and torch.backends.mps.is_available()
+    )
+    has_torch_gpu = has_torch_cuda_gpu
     torch_version = Version(str(torch.__version__))
     has_torch_amp = (
         torch_version >= Version("1.9.0")
@@ -40,7 +46,9 @@
 except ImportError:  # pragma: no cover
     torch = None  # type: ignore
     has_torch = False
+    has_torch_cuda_gpu = False
     has_torch_gpu = False
+    has_torch_mps_gpu = False
     has_torch_amp = False
     torch_version = Version("0.0.0")

@@ -68,3 +76,6 @@
     import h5py
 except ImportError:  # pragma: no cover
     h5py = None
+
+
+has_gpu = has_cupy_gpu or has_torch_mps_gpu
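
The new flags separate "any GPU" from "CUDA GPU": `has_torch_mps_gpu` probes MPS defensively (the `hasattr` guard short-circuits on PyTorch builds that predate MPS, so `torch.backends.mps` is never touched there), and `has_gpu` becomes the umbrella flag used by `prefer_gpu`/`require_gpu`. The same detection logic, runnable standalone:

```python
# Standalone version of the detection above; safe on any PyTorch build.
import torch

has_torch_cuda_gpu = torch.cuda.device_count() != 0
has_torch_mps_gpu = (
    hasattr(torch, "has_mps")
    and torch.has_mps
    and torch.backends.mps.is_available()
)
print(f"cuda={has_torch_cuda_gpu} mps={has_torch_mps_gpu}")
```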
38 changes: 28 additions & 10 deletions thinc/tests/layers/test_pytorch_wrapper.py
@@ -1,21 +1,37 @@
 from thinc.api import Linear, SGD, PyTorchWrapper, PyTorchWrapper_v2
 from thinc.api import xp2torch, torch2xp, ArgsKwargs, use_ops
 from thinc.api import chain, get_current_ops, Relu
+from thinc.api import CupyOps, MPSOps, NumpyOps
 from thinc.backends import context_pools
 from thinc.shims.pytorch_grad_scaler import PyTorchGradScaler
-from thinc.compat import has_torch, has_torch_amp, has_torch_gpu
-from thinc.compat import has_cupy
+from thinc.compat import has_torch, has_torch_amp
+from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
 import numpy
 import pytest
+from thinc.util import get_torch_default_device

 from ..util import make_tempdir, check_input_converters


+XP_OPS = [NumpyOps()]
+if has_cupy_gpu:
+    XP_OPS.append(CupyOps())
+if has_torch_mps_gpu:
+    XP_OPS.append(MPSOps())
+
+
 if has_torch_amp:
     TORCH_MIXED_PRECISION = [False, True]
 else:
     TORCH_MIXED_PRECISION = [False]

+XP_OPS_MIXED = [
+    (ops, mixed)
+    for ops in XP_OPS
+    for mixed in TORCH_MIXED_PRECISION
+    if not mixed or isinstance(ops, CupyOps)
+]
+

 def check_learns_zero_output(model, sgd, X, Y):
     """Check we can learn to output a zero vector"""
@@ -64,32 +80,34 @@ def test_pytorch_wrapper(nN, nI, nO):
     assert isinstance(model.predict(X), numpy.ndarray)


-@pytest.mark.skipif(
-    not has_cupy or not has_torch_gpu, reason="needs PyTorch with CUDA-capable GPU"
-)
+@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
+@pytest.mark.parametrize("ops_mixed", XP_OPS_MIXED)
 @pytest.mark.parametrize("nN,nI,nO", [(2, 3, 4)])
-@pytest.mark.parametrize("mixed_precision", TORCH_MIXED_PRECISION)
-def test_pytorch_wrapper_thinc_input(nN, nI, nO, mixed_precision):
+def test_pytorch_wrapper_thinc_input(ops_mixed, nN, nI, nO):
     import torch.nn

-    with use_ops("cupy"):
+    ops, mixed_precision = ops_mixed
+
+    with use_ops(ops.name):
         ops = get_current_ops()
         pytorch_layer = torch.nn.Linear(nO, nO)
         # Initialize with large weights to trigger overflow of FP16 in
         # mixed-precision training.
         torch.nn.init.uniform_(pytorch_layer.weight, 9.0, 11.0)
+        device = get_torch_default_device()
         model = chain(
             Relu(),
             PyTorchWrapper_v2(
-                pytorch_layer.cuda(),
+                pytorch_layer.to(device),
                 mixed_precision=mixed_precision,
                 grad_scaler=PyTorchGradScaler(
                     enabled=mixed_precision, init_scale=2.0**16
                 ),
             ).initialize(),
         )
         # pytorch allocator is set in PyTorchShim
-        assert "pytorch" in context_pools.get()
+        if isinstance(ops, CupyOps):
+            assert "pytorch" in context_pools.get()
         sgd = SGD(0.001)
         X = ops.xp.zeros((nN, nI), dtype="f")
         X += ops.xp.random.uniform(size=X.size).reshape(X.shape)
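
The `XP_OPS_MIXED` product replaces the old `mixed_precision` parametrize: every available backend is paired with every mixed-precision setting, but the filter keeps mixed precision CUDA-only, since PyTorch's gradient scaling does not support MPS or CPU tensors. Illustratively, the list expands to:

```python
# Illustrative expansion of XP_OPS_MIXED (values depend on the machine):
#   CUDA machine with AMP : [(NumpyOps(), False), (CupyOps(), False), (CupyOps(), True)]
#   Apple Silicon (MPS)   : [(NumpyOps(), False), (MPSOps(), False)]
#   CPU only              : [(NumpyOps(), False)]
```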
4 changes: 2 additions & 2 deletions thinc/tests/regression/test_issue564.py
@@ -1,11 +1,11 @@
 import pytest

 from thinc.api import CupyOps
-from thinc.compat import has_torch, has_torch_gpu
+from thinc.compat import has_torch, has_torch_cuda_gpu


 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
-@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
+@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU")
 def test_issue564():
     import torch
6 changes: 3 additions & 3 deletions thinc/tests/shims/test_pytorch_grad_scaler.py
@@ -2,7 +2,7 @@

 from hypothesis import given, settings
 from hypothesis.strategies import lists, one_of, tuples
-from thinc.compat import has_torch, has_torch_amp, has_torch_gpu, torch
+from thinc.compat import has_torch, has_torch_amp, has_torch_cuda_gpu, torch
 from thinc.util import is_torch_array
 from thinc.api import PyTorchGradScaler

@@ -14,7 +14,7 @@ def tensors():


 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
-@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
+@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU")
 @pytest.mark.skipif(
     not has_torch_amp, reason="requires PyTorch with mixed-precision support"
 )
@@ -37,7 +37,7 @@ def test_scale_random_inputs(X):


 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
-@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
+@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU")
 @pytest.mark.skipif(
     not has_torch_amp, reason="requires PyTorch with mixed-precision support"
 )
46 changes: 31 additions & 15 deletions thinc/util.py
@@ -14,7 +14,8 @@
 from contextvars import ContextVar
 from dataclasses import dataclass
 from .compat import has_cupy, has_mxnet, has_torch, has_tensorflow
-from .compat import has_cupy_gpu, has_torch_gpu
+from .compat import has_cupy_gpu, has_torch_cuda_gpu, has_gpu
+from .compat import has_torch_mps_gpu
 from .compat import torch, cupy, tensorflow as tf, mxnet as mx, cupy_from_dlpack

 DATA_VALIDATION: ContextVar[bool] = ContextVar("DATA_VALIDATION", default=False)
@@ -33,11 +34,14 @@ def get_torch_default_device() -> "torch.device":

     from .backends import get_current_ops
     from .backends.cupy_ops import CupyOps
+    from .backends.mps_ops import MPSOps

     ops = get_current_ops()
     if isinstance(ops, CupyOps):
         device_id = torch.cuda.current_device()
         return torch.device(f"cuda:{device_id}")
+    elif isinstance(ops, MPSOps):
+        return torch.device("mps")

     return torch.device("cpu")

@@ -50,7 +54,7 @@ def get_array_module(arr):  # pragma: no cover


 def gpu_is_available():
-    return has_cupy_gpu
+    return has_gpu


def fix_random_seed(seed: int = 0) -> None: # pragma: no cover
Expand All @@ -61,7 +65,7 @@ def fix_random_seed(seed: int = 0) -> None: # pragma: no cover
torch.manual_seed(seed)
if has_cupy_gpu:
cupy.random.seed(seed)
if has_torch and has_torch_gpu:
if has_torch and has_torch_cuda_gpu:
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
@@ -99,10 +103,18 @@ def is_torch_array(obj: Any) -> bool:  # pragma: no cover
     return False


-def is_torch_gpu_array(obj: Any) -> bool:  # pragma: no cover
+def is_torch_cuda_array(obj: Any) -> bool:  # pragma: no cover
     return is_torch_array(obj) and obj.is_cuda


+def is_torch_gpu_array(obj: Any) -> bool:  # pragma: no cover
+    return is_torch_cuda_array(obj) or is_torch_mps_array(obj)
+
+
+def is_torch_mps_array(obj: Any) -> bool:  # pragma: no cover
+    return is_torch_array(obj) and hasattr(obj, "is_mps") and obj.is_mps
+
+
 def is_tensorflow_array(obj: Any) -> bool:  # pragma: no cover
     if not has_tensorflow:
         return False
@@ -146,7 +158,7 @@ def set_active_gpu(gpu_id: int) -> "cupy.cuda.Device":  # pragma: no cover
     device = cupy.cuda.device.Device(gpu_id)
     device.use()

-    if has_torch_gpu:
+    if has_torch_cuda_gpu:
         torch.cuda.set_device(gpu_id)

     return device
@@ -164,21 +176,25 @@ def require_cpu() -> bool:  # pragma: no cover

 def prefer_gpu(gpu_id: int = 0) -> bool:  # pragma: no cover
     """Use GPU if it's available. Returns True if so, False otherwise."""
-    if not has_cupy_gpu:
-        return False
-    else:
+    if has_gpu:
         require_gpu(gpu_id=gpu_id)
         return True
+    else:
+        return False


 def require_gpu(gpu_id: int = 0) -> bool:  # pragma: no cover
-    from .backends import set_current_ops, CupyOps
+    from .backends import set_current_ops, CupyOps, MPSOps

-    if not has_cupy_gpu:
-        raise ValueError("No CUDA GPU devices detected")
+    if not has_gpu:
+        raise ValueError("No GPU devices detected")

-    set_current_ops(CupyOps())
-    set_active_gpu(gpu_id)
+    if has_cupy_gpu:
+        set_current_ops(CupyOps())
+        set_active_gpu(gpu_id)
+    else:
+        set_current_ops(MPSOps())

     return True


@@ -353,14 +369,14 @@ def torch2xp(
     from .api import NumpyOps

     assert_pytorch_installed()
-    if is_torch_gpu_array(torch_tensor):
+    if is_torch_cuda_array(torch_tensor):
         if isinstance(ops, NumpyOps):
             return torch_tensor.detach().cpu().numpy()
         else:
             return cupy_from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))
     else:
         if isinstance(ops, NumpyOps) or ops is None:
-            return torch_tensor.detach().numpy()
+            return torch_tensor.detach().cpu().numpy()
         else:
             return cupy.asarray(torch_tensor)
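
End to end, the new device selection can be exercised as below; note that `torch2xp` now calls `.cpu()` before `.numpy()`, which is required for MPS tensors (they cannot be converted to NumPy in place). A sketch assuming a machine with either a CUDA or an MPS device:

```python
# Sketch: backend and device selection after this PR (assumes a CUDA
# or MPS device is present; require_gpu raises ValueError otherwise).
import torch
from thinc.api import require_gpu, get_current_ops, torch2xp
from thinc.util import get_torch_default_device

require_gpu()  # CupyOps on CUDA, MPSOps on Apple Silicon
device = get_torch_default_device()
t = torch.ones(2, 2, device=device)
arr = torch2xp(t)  # cupy array on CUDA; numpy array (via .cpu()) on MPS
print(get_current_ops().name, type(arr))
```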