Commit 2489581
reply to comments
kprokofi committed Mar 21, 2024
1 parent 8892171 commit 2489581
Showing 4 changed files with 12 additions and 35 deletions.
8 changes: 4 additions & 4 deletions src/otx/algo/accelerators/xpu.py
@@ -18,13 +18,13 @@ class XPUAccelerator(Accelerator):
 
     accelerator_name = "xpu"
 
-    def setup_device(self) -> None:
+    def setup_device(self, device: torch.device) -> None:
         """Sets up the specified device."""
-        if not is_xpu_available():
-            msg = f"XPU is not available. Please, check the environment."
+        if device.type != "xpu":
+            msg = f"Device should be xpu, got {device} instead"
             raise RuntimeError(msg)
 
-        torch.xpu.set_device("xpu:0")
+        torch.xpu.set_device(device)
         patch_packages_xpu()
 
 
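For context, Lightning's Accelerator.setup_device hook receives the torch.device the trainer resolved for the run, which is why the hard-coded "xpu:0" above is replaced with the passed-in device. A minimal sketch (not part of this commit) of driving the accelerator directly, assuming an XPU-enabled PyTorch build and device index 0:

    import torch

    from otx.algo.accelerators.xpu import XPUAccelerator

    accelerator = XPUAccelerator()
    device = torch.device("xpu", 0)   # assumed target device; requires an XPU-enabled build
    accelerator.setup_device(device)  # raises RuntimeError if the device type is not "xpu"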
10 changes: 5 additions & 5 deletions src/otx/algo/strategies/xpu_single.py
@@ -61,8 +61,8 @@ def setup_optimizers(self, trainer: pl.Trainer) -> None:
         self.model = model
 
 
-# StrategyRegistry.register(
-#     SingleXPUStrategy.strategy_name,
-#     SingleXPUStrategy,
-#     description="Strategy that enables training on single XPU",
-# )
+StrategyRegistry.register(
+    SingleXPUStrategy.strategy_name,
+    SingleXPUStrategy,
+    description="Strategy that enables training on single XPU",
+)
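Un-commenting this call registers the strategy with Lightning's StrategyRegistry, so it can be resolved by name. A minimal sketch of the lookup, assuming importing the module is what triggers the registration:

    from lightning.pytorch.strategies import StrategyRegistry

    from otx.algo.strategies.xpu_single import SingleXPUStrategy

    # the registered name should now be listed and resolvable
    assert SingleXPUStrategy.strategy_name in StrategyRegistry.available_strategies()
    strategy = StrategyRegistry.get(SingleXPUStrategy.strategy_name)  # instantiated SingleXPUStrategy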
1 change: 1 addition & 0 deletions src/otx/engine/engine.py
@@ -811,6 +811,7 @@ def _build_trainer(self, **kwargs) -> None:
         self._cache.update(**kwargs)
         # set up xpu device
         if self._device.accelerator == DeviceType.xpu:
+            self._cache.update(strategy="xpu_single")
             # add plugin for Automatic Mixed Precision on XPU
             if self._cache.args["precision"] == 16:
                 self._cache.update(plugins=[MixedPrecisionXPUPlugin()])
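The added line pins the trainer to the newly registered single-XPU strategy whenever the XPU accelerator is selected; the existing branch additionally swaps in the mixed-precision plugin for 16-bit runs. A rough sketch of the Trainer configuration this produces, with the plugin import path and the registered accelerator name assumed rather than taken from the diff:

    import lightning.pytorch as pl

    from otx.algo.plugins import MixedPrecisionXPUPlugin  # assumed import path

    trainer = pl.Trainer(
        accelerator="xpu",                    # assumes XPUAccelerator is registered under "xpu"
        strategy="xpu_single",
        plugins=[MixedPrecisionXPUPlugin()],  # only added when precision == 16
    )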
28 changes: 2 additions & 26 deletions src/otx/utils/utils.py
@@ -144,44 +144,20 @@ def is_xpu_available() -> bool:
 
 def patch_packages_xpu() -> None:
     """Patch packages when xpu is available."""
-    import lightning.pytorch as pl
-    from lightning.pytorch.trainer.states import TrainerFn
-    from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers
-
-    def patched_setup_optimizers(self, trainer: pl.Trainer) -> None:
-        """Sets up optimizers."""
-        if trainer.state.fn != TrainerFn.FITTING:
-            return
-        assert self.lightning_module is not None
-        self.optimizers, self.lr_scheduler_configs = _init_optimizers_and_lr_schedulers(self.lightning_module)
-        if len(self.optimizers) != 1:  # type: ignore[has-type]
-            msg = "XPU strategy doesn't support multiple optimizers"
-            raise RuntimeError(msg)
-        model, optimizer = torch.xpu.optimize(trainer.model, optimizer=self.optimizers[0])  # type: ignore[has-type]
-        self.optimizers = [optimizer]
-        self.model = model
-
     # patch instance_data from mmengie
     long_type_tensor = Union[torch.LongTensor, torch.xpu.LongTensor]
     bool_type_tensor = Union[torch.BoolTensor, torch.xpu.BoolTensor]
     instance_data.IndexType = Union[str, slice, int, list, long_type_tensor, bool_type_tensor, np.ndarray]
 
-    # patch nms, roi_align and setup_optimizers for the lightning strategy
-    global _nms_op_forward, _roi_align_forward, _setup_optimizers
+    # patch nms and roi_align
+    global _nms_op_forward, _roi_align_forward
     _nms_op_forward = NMSop.forward
     _roi_align_forward = RoIAlign.forward
-    _setup_optimizers = SingleDeviceStrategy.setup_optimizers
     NMSop.forward = monkey_patched_nms
     RoIAlign.forward = monkey_patched_roi_align
-    SingleDeviceStrategy.setup_optimizers = patched_setup_optimizers
 
 
 def revert_packages_xpu():
-    from mmcv.ops.nms import NMSop
-    from mmcv.ops.roi_align import RoIAlign
-    from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
     """Revert packages when xpu is available."""
-    global _nms_op_forward, _roi_align_forward, _setup_optimizers
     NMSop.forward = _nms_op_forward
     RoIAlign.forward = _roi_align_forward
-    SingleDeviceStrategy.setup_optimizers = _setup_optimizers
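With the optimizer patching dropped, patch_packages_xpu now only swaps the mmcv NMSop/RoIAlign forwards (plus the mmengine IndexType), and revert_packages_xpu restores them from the saved globals. A minimal usage sketch of the pair, assuming an XPU-enabled environment; the training step is a placeholder:

    from otx.utils.utils import is_xpu_available, patch_packages_xpu, revert_packages_xpu

    if is_xpu_available():
        patch_packages_xpu()       # saves and replaces NMSop.forward / RoIAlign.forward
        try:
            ...                    # placeholder: run training on the XPU device
        finally:
            revert_packages_xpu()  # restores the original mmcv forwards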
