Skip to content

Commit

Permalink
Enforce single XPU training (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT authored May 23, 2024
1 parent 9b97344 commit bea14a4
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 24 deletions.
28 changes: 5 additions & 23 deletions component-tests/training/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from parsl.app.python import PythonApp
from parsl.executors import HighThroughputExecutor
from parsl.providers import PBSProProvider
from parsl.launchers import MpiExecLauncher
from parsl.launchers import SimpleLauncher

from mofa.model import MOFRecord

Expand Down Expand Up @@ -111,29 +111,12 @@ def test_function(model_path: Path, config_path: Path, training_set: list, num_e
)
])
elif args.config.startswith("sunspot"):
if args.config == "sunspot":
accel_ids = [
f"{gid}.{tid}"
for gid in range(6)
for tid in range(2)
]
elif args.config == "sunspot-device":
accel_ids = [
f"{gid}.0,{gid}.1"
for gid in range(6)
]
else:
raise ValueError(f'Not supported: {args.config}')
config = Config(
retries=2,
executors=[
HighThroughputExecutor(
label="sunspot_test",
available_accelerators=accel_ids, # Ensures one worker per accelerator
cpu_affinity="block", # Assigns cpus in sequential order
prefetch_capacity=0,
max_workers=len(accel_ids),
cores_per_worker=208 // len(accel_ids),
max_workers=1,
provider=PBSProProvider(
account="CSC249ADCD08_CNDA",
queue="workq",
Expand All @@ -147,16 +130,15 @@ def test_function(model_path: Path, config_path: Path, training_set: list, num_e
module load gcc/12.2.0
module list
{"" if len(accel_ids) == 12 else "export IPEX_TILE_AS_DEVICE=0"}
python -c "import intel_extension_for_pytorch as ipex; print(ipex.xpu.device_count())"
cd $PBS_O_WORKDIR
pwd
which python
hostname
""",
walltime="1:10:00",
launcher=MpiExecLauncher(
bind_cmd="--cpu-bind", overrides="--depth=208 --ppn 1"
                    ), # Ensures one manager per node and allows it to divide work among all 208 threads
launcher=SimpleLauncher(),
select_options="system=sunspot,place=scatter",
nodes_per_block=1,
min_blocks=0,
Expand Down
2 changes: 1 addition & 1 deletion envs/build-aurora.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ conda activate ./env

# Build torch_ccl locally
# Clone from: https://github.com/intel/torch-ccl
cd libs/torch_ccl
cd libs/torch-ccl
COMPUTE_BACKEND=dpcpp pip install -e .

# Now install Corey's stuff
Expand Down
9 changes: 9 additions & 0 deletions envs/environment-aurora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,18 @@ dependencies:
- pytorch==2.1.0
- intel-extension-for-pytorch==2.1.10

# Tools to build CCL locally
- conda-forge::cmake
- ninja

- pip
- pip:
- git+https://gitlab.com/ase/ase.git
- git+https://github.com/exalearn/colmena.git # Fixes for streaming not yet on PyPI

  # Install CCL manually for now; uncomment once the SSL version of the
  # following wheel no longer disagrees with the one on Sunspot/Aurora
#- --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
# - oneccl_bind_pt==2.1.200+xpu
- -e ..[test]

13 changes: 13 additions & 0 deletions mofa/difflinker_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@

from pytorch_lightning import Trainer, callbacks
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.strategies import SingleDeviceStrategy

try:
import intel_extension_for_pytorch as ipex # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
except ImportError:
pass

try:
import intel_extension_for_pytorch as ipex # noqa: F401
Expand Down Expand Up @@ -150,6 +157,11 @@ def main(
if '.' in args.train_data_prefix:
context_node_nf += 1

# Lock XPU to single device for now
strategy = 'auto'
if args.device == 'xpu':
strategy = SingleDeviceStrategy(device='xpu')

checkpoint_callback = [callbacks.ModelCheckpoint(
dirpath=checkpoints_dir,
filename='difflinker_{epoch:02d}',
Expand All @@ -164,6 +176,7 @@ def main(
accelerator=args.device,
num_sanity_val_steps=0,
enable_progress_bar=args.enable_progress_bar,
strategy=strategy
)

# Add a callback for fit setup
Expand Down

0 comments on commit bea14a4

Please sign in to comment.