Commit 793b136

Fix a bug where an error is raised when training a model after HPO in a multi-XPU environment (#3081)

fix HPO index error on XPU

Co-authored-by: Emily <emily.chun@intel.com>
eunwoosh and chuneuny-emily authored Mar 13, 2024
1 parent e44e96b commit 793b136
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/otx/hpo/resource_manager.py
@@ -180,6 +180,10 @@ def _make_env_var_for_train(self, device_arr: List[int]) -> Dict[str, str]:
class XPUResourceManager(AcceleratorManager):
"""Resource manager class for XPU."""

def __init__(self, num_devices_per_trial: int = 1, available_devices: Optional[str] = None):
super().__init__(num_devices_per_trial, available_devices)
torch.xpu.init() # Avoid default_generators index error in multi XPU environment

def _set_available_devices(self, available_devices: Optional[str] = None) -> List[int]:
if available_devices is None:
visible_devices = os.getenv("ONEAPI_DEVICE_SELECTOR", "").split(":")
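For context, `_set_available_devices` falls back to parsing the `ONEAPI_DEVICE_SELECTOR` environment variable when no device list is passed explicitly. A minimal sketch of that fallback, assuming the common `backend:index,index` selector form (e.g. `level_zero:0,1`); the helper name `parse_visible_xpu_devices` is hypothetical and not part of the patch:

```python
import os
from typing import List


def parse_visible_xpu_devices(selector: str) -> List[int]:
    """Hypothetical helper: turn a ONEAPI_DEVICE_SELECTOR-style value
    such as 'level_zero:0,1' into a list of device indices."""
    parts = selector.split(":")
    if len(parts) < 2:
        return []  # no explicit index list after the backend name
    return [int(tok) for tok in parts[1].split(",") if tok.strip().isdigit()]


# Mirrors the fallback in _set_available_devices: read the env var,
# defaulting to an empty string when it is unset.
visible = parse_visible_xpu_devices(os.getenv("ONEAPI_DEVICE_SELECTOR", ""))
print(parse_visible_xpu_devices("level_zero:0,1"))  # → [0, 1]
```

This only illustrates the selector parsing; the actual fix in this commit is the eager `torch.xpu.init()` call in the constructor, which populates the per-device generator state before any HPO trial indexes into it.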
