-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhabana_instructlab.patch
590 lines (543 loc) · 21.5 KB
/
habana_instructlab.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
diff --git a/README.md b/README.md
index 57ab962..231871b 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,17 @@ You can then install the library for development:
pip install -e ./training
```
+### Additional Gaudi2 packages
+
+```bash
+pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.19.0
+git clone --branch v1.15.0 https://github.com/huggingface/optimum-habana.git
+cd optimum-habana && pip install .
+pip install git+https://github.com/huggingface/transformers.git@482cb28
+pip install -e ./training
+pip install accelerate==0.34.1
+```
+
### Additional NVIDIA packages
This library uses the `flash-attn` package as well as other packages, which rely on NVIDIA-specific CUDA tooling to be installed.
@@ -365,3 +376,9 @@ run_training(
train_args=training_args,
)
```
+## Example fine-tuning on Gaudi using FSDP
+
+A sample script for fine-tuning is in scripts/run_ds.py. Modify the script to specifiy the model and fine-tuning dataset and execute run.sh
+```python
+>run.sh
+```
diff --git a/scripts/run.sh b/scripts/run.sh
new file mode 100755
index 0000000..4a7948a
--- /dev/null
+++ b/scripts/run.sh
@@ -0,0 +1,2 @@
+export PT_HPU_LAZY_MODE=0
+python scripts/run_ds.py
diff --git a/scripts/run_ds.py b/scripts/run_ds.py
new file mode 100644
index 0000000..5d72742
--- /dev/null
+++ b/scripts/run_ds.py
@@ -0,0 +1,47 @@
+from instructlab.training import (
+ run_training,
+ TorchrunArgs,
+ TrainingArgs,
+ DeepSpeedOptions,
+ FSDPOptions,
+ LoraOptions,
+)
+
+training_args = TrainingArgs(
+ # define data-specific arguments
+ # model_path = "/models/granite-7b-base",
+ model_path = "/software/users/jojimon/models/granite-7b-base",
+ # data_path = "open-llm-leaderboard/ibm-granite__granite-7b-base-details",
+ data_path = "sample-data/train_all_pruned_SDG.jsonl",
+ ckpt_output_dir = "data/saved_checkpoints",
+ data_output_dir = "data/outputs",
+
+ # define model-trianing parameters
+ max_seq_len = 512,
+ max_batch_len = 60000,
+ num_epochs = 9,
+ effective_batch_size=8,
+ save_samples = 250000,
+ learning_rate = 2e-6,
+ warmup_steps = 800,
+ is_padding_free = True, # set this to true when using Granite-based models
+ random_seed = 42,
+ disable_flash_attn=True,
+ distributed_backend="fsdp", # change to fsdp
+ distributed_training_framework="fsdp",
+ use_dolomite =False,
+ device='hpu',
+ #accelerate_full_state_at_epoch=False,
+)
+training_args.fsdp_options=FSDPOptions(
+ cpu_offload_params=True,
+ use_fsdp_plugin=True,
+)
+torchrun_args = TorchrunArgs(
+ nnodes = 1, # number of machines
+ nproc_per_node = 2, # num GPUs per machine
+ node_rank = 0, # node rank for this machine
+ rdzv_id = 123,
+ rdzv_endpoint = '127.0.0.1:12345'
+)
+run_training(torch_args=torchrun_args, train_args=training_args)
diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index 43bc455..f0df336 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -135,7 +135,8 @@ class FSDPOptions(BaseModel):
Represents the options for configuring FSDP which are exposed by the Training Library
"""
- cpu_offload_params: Optional[bool] = False
+ cpu_offload_params: Optional[bool] = True
+ use_fsdp_plugin: Optional[bool] = False
sharding_strategy: ShardingStrategies = ShardingStrategies.SHARD_GRAD_OP
@@ -166,6 +167,7 @@ class TrainingArgs(BaseModel):
# this field defines where we should be saving the processed version of the training dataset
# after we have tokenized it
data_output_dir: str
+ device: str
max_seq_len: int
max_batch_len: int
@@ -192,7 +194,7 @@ class TrainingArgs(BaseModel):
)
fsdp_options: FSDPOptions = Field(
default_factory=lambda: FSDPOptions(
- cpu_offload_params=False, sharding_strategy=ShardingStrategies.SHARD_GRAD_OP
+ cpu_offload_params=True, use_fsdp_plugin=False, sharding_strategy=ShardingStrategies.SHARD_GRAD_OP
)
)
distributed_backend: DistributedBackend = DistributedBackend.FSDP
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 1e5ac8b..5f3d545 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -11,6 +11,8 @@ import subprocess
import time
# Third Party
+import habana_frameworks.torch.distributed.hccl
+import torch.distributed as dist
from accelerate import Accelerator
try:
@@ -77,6 +79,7 @@ from instructlab.training.utils import (
save_hf_format_accelerate,
set_random_seed,
setup_logger,
+ Device,
)
import instructlab.training.data_process as dp
@@ -109,7 +112,7 @@ def setup_optimizer(args, model):
return optimizer
-def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
+def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled, device):
bnb_config = None
if args.lora_r > 0 and args.lora_quant_bits == 4:
# Third Party
@@ -228,9 +231,23 @@ def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
- accelerator = setup_accelerator(args, model, grad_accum)
- if args.distributed_training_framework == DistributedBackend.FSDP.value:
+ accelerator = setup_accelerator(args, model, grad_accum, device)
+ def get_dtype(args, device):
+ if args.distributed_training_framework == DistributedBackend.DEEPSPEED.value:
+ dscon= accelerator.state.deepspeed_plugin.hf_ds_config
+ if dscon.config['bf16']['enabled']:
+ return torch.bfloat16
+ if dscon.config['fp16']['enabled']:
+ return torch.float16
+ elif device.is_hpu():
+ return torch.bfloat16
+ return torch.float32
+
+ if device.is_hpu():
+ model = model.to(dtype=get_dtype(args, device), device=device())
+ elif args.distributed_training_framework == DistributedBackend.FSDP.value:
model = accelerator.prepare(model)
+
optimizer = setup_optimizer(args, model)
lr_scheduler = get_scheduler(
@@ -321,6 +338,7 @@ def train(
train_loader: DataLoader,
grad_accum,
metric_logger,
+ device,
):
model.train()
@@ -366,6 +384,8 @@ def train(
inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}")
# blast through the batches in the train loader up to the last step within the epoch.
+ ddp_loss = torch.zeros(2).to(device=device())
+
for batch in train_loader:
if global_step <= args.last_step:
# in the case of resuming, last_step > 0
@@ -380,7 +400,8 @@ def train(
micro_batch_size = float(torch.tensor([batch.pop("num_samples")]))
if not args.use_dolomite:
for k in batch:
- batch[k] = batch[k].to(local_rank)
+ if torch.is_tensor(batch[k]):
+ batch[k] = batch[k].to(device=device())
output = model(
**batch,
use_cache=False,
@@ -409,10 +430,12 @@ def train(
f"Epoch: {epoch}, Step: {global_step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
)
accelerator.backward(loss)
+ device.mark_step()
if global_step % grad_accum == 0:
global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
+ device.mark_step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -420,8 +443,8 @@ def train(
elapsed_time = time.time() - start
overall_throughput = args.samples_per_gpu * world_size / elapsed_time
current_lr = lr_scheduler.get_last_lr()[0]
- cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
- cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+ mem_allocated = device.memory_allocated() / (1024**3)
+ malloc_retries = device.memory_stats()["num_alloc_retries"] if device.is_cuda() else 0
global_grad_norm = (
model.get_global_grad_norm()
if hasattr(model, "get_global_grad_norm")
@@ -443,8 +466,7 @@ def train(
"rank": torch.distributed.get_rank(),
"overall_throughput": overall_throughput,
"lr": current_lr,
- "cuda_mem_allocated": cuda_mem_allocated,
- "cuda_malloc_retries": cuda_malloc_retries,
+ "mem_allocated": mem_allocated,
"num_loss_counted_tokens": int(num_loss_counted_tokens),
"batch_size": int(micro_batch_size),
"total_loss": float(log_loss / num_loss_counted_tokens),
@@ -481,7 +503,13 @@ def train(
global_step += 1
if local_rank == 0:
inner_pb.update(1)
- torch.cuda.empty_cache()
+ device.empty_cache()
+ ddp_loss[0] += loss.item()
+ ddp_loss[1] += micro_batch_size
+
+ del loss, output, batch
+
+ dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if args.checkpoint_at_epoch:
save_checkpoint(
args=args,
@@ -490,7 +518,7 @@ def train(
tokenizer=tokenizer,
samples_seen=samples_seen,
is_lora=bool(args.lora_r),
- full_state=args.accelerate_full_state_at_epoch,
+ #full_state=args.accelerate_full_state_at_epoch,
hf_format=True,
epoch=epoch,
)
@@ -510,6 +538,9 @@ def main(args):
# Third Party
import yaml
+ device = Device(args.device)
+ set_random_seed(args.seed, device)
+
if args.distributed_training_framework == "deepspeed" and not FusedAdam:
raise ImportError(
"DeepSpeed was selected but we cannot import the `FusedAdam` optimizer"
@@ -535,17 +566,16 @@ def main(args):
setup_logger(args.log_level)
CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path)
tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE)
- # device = torch.device("cuda", args.local_rank)
model_conf = AutoConfig.from_pretrained(args.model_name_or_path)
args.model_type = model_conf.model_type
#### distributed init #####
- torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+ device.set_current()
args.local_rank = int(os.environ["LOCAL_RANK"])
- torch.distributed.init_process_group("nccl")
+ torch.distributed.init_process_group(device.dlib())
args.global_rank = torch.distributed.get_rank()
- tensor = torch.ByteTensor([False]).cuda()
+ tensor = torch.ByteTensor([False]).to(device())
torch.distributed.all_reduce(tensor)
torch.distributed.barrier()
@@ -632,7 +662,7 @@ def main(args):
)
model, lr_scheduler, optimizer, accelerator = setup_model(
- args, tokenizer, train_loader, grad_accum, flash_enabled
+ args, tokenizer, train_loader, grad_accum, flash_enabled, device
)
load_latest_full_state(args=args, accelerator=accelerator)
@@ -647,6 +677,7 @@ def main(args):
train_loader,
grad_accum,
metric_logger,
+ device,
)
torch.distributed.barrier()
@@ -705,6 +736,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
f"--max_batch_len={train_args.max_batch_len}",
f"--seed={train_args.random_seed}",
f"--chat-tmpl-path={train_args.chat_tmpl_path}",
+ f"--device={train_args.device}",
]
if train_args.keep_last_checkpoint_only:
@@ -787,6 +819,8 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
command.append(
f"--fsdp_sharding_strategy={train_args.fsdp_options.sharding_strategy.value}"
)
+ if train_args.fsdp_options.use_fsdp_plugin:
+ command.append(f"--use_fsdp_plugin=True")
if train_args.keep_last_checkpoint_only:
command.append("--keep_last_checkpoint_only")
@@ -844,6 +878,7 @@ if __name__ == "__main__":
parser.add_argument("--model_name_or_path", type=str)
parser.add_argument("--data_path", type=str)
parser.add_argument("--output_dir", type=str)
+ parser.add_argument("--device", type=str)
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument(
"--current_epoch",
@@ -914,6 +949,7 @@ if __name__ == "__main__":
],
default=DistributedBackend.DEEPSPEED.value,
)
+ parser.add_argument("--use_fsdp_plugin", type=bool, default=False)
parser.add_argument(
"--fsdp_sharding_strategy",
type=str,
@@ -942,7 +978,7 @@ if __name__ == "__main__":
parser.add_argument(
"--cpu_offload_params_fsdp",
action="store_true",
- default=False,
+ default=True,
help="Offload to CPU when using FSDP.",
)
parser.add_argument(
@@ -975,7 +1011,6 @@ if __name__ == "__main__":
),
)
args = parser.parse_args()
- set_random_seed(args.seed)
main(args)
"""
diff --git a/src/instructlab/training/setup_accelerator.py b/src/instructlab/training/setup_accelerator.py
index c7d079e..77d6ec7 100644
--- a/src/instructlab/training/setup_accelerator.py
+++ b/src/instructlab/training/setup_accelerator.py
@@ -3,6 +3,7 @@ from functools import partial
# Third Party
from accelerate import Accelerator
+from optimum.habana.accelerate import GaudiAccelerator
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
@@ -12,6 +13,9 @@ import torch
# First Party
from instructlab.training.config import DeepSpeedOptions
from instructlab.training.utils import get_module_class_from_name, patch_target_module
+from torch.distributed.fsdp.api import (
+ StateDictType,
+)
def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
@@ -49,9 +53,10 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption
return ds_plugin
-def get_fsdp_config(args, model: PreTrainedModel):
+def get_fsdp_config(args, model: PreTrainedModel, device):
# Third Party
from accelerate.utils import FullyShardedDataParallelPlugin
+ from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
is_lora = args.lora_r > 0
@@ -73,28 +78,34 @@ def get_fsdp_config(args, model: PreTrainedModel):
prefetch_policy = (
BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
)
- fsdp_plugin = FullyShardedDataParallelPlugin(
- auto_wrap_policy=wrap_policy,
- limit_all_gathers=True,
- mixed_precision_policy=MixedPrecision(
- param_dtype=torch.bfloat16,
- reduce_dtype=torch.bfloat16,
- buffer_dtype=torch.bfloat16,
- ),
- backward_prefetch=prefetch_policy,
- sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy],
- cpu_offload=CPUOffload(args.cpu_offload_params_fsdp),
- )
- # `use_orig_params` must be disabled when using LoRA and FSDP together
- # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
- if args.lora_r > 0:
- fsdp_plugin.use_orig_params = False
+ if args.use_fsdp_plugin :
+ fsdp_plugin = GaudiFullyShardedDataParallelPlugin(
+ auto_wrap_policy=wrap_policy,
+ limit_all_gathers=True,
+ mixed_precision_policy=MixedPrecision(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.bfloat16,
+ buffer_dtype=torch.bfloat16,
+ ),
+ backward_prefetch=prefetch_policy,
+ sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy],
+ cpu_offload=CPUOffload(args.cpu_offload_params_fsdp),
+ #state_dict_type = StateDictType.FULL_STATE_DICT if device.is_hpu() else None,
+ )
+ #fsdp_plugin.set_state_dict_type(StateDictType.FULL_STATE_DICT)
+
+ # `use_orig_params` must be disabled when using LoRA and FSDP together
+ # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
+ if args.lora_r > 0:
+ fsdp_plugin.use_orig_params = False
+ else:
+ fsdp_plugin = None
return fsdp_plugin
-def setup_accelerator(args, model: PreTrainedModel, grad_accum):
+def setup_accelerator(args, model: PreTrainedModel, grad_accum, device):
if args.distributed_training_framework == "deepspeed":
try:
# Third Party
@@ -126,14 +137,20 @@ def setup_accelerator(args, model: PreTrainedModel, grad_accum):
}
elif args.distributed_training_framework == "fsdp":
accel_args = {
- "fsdp_plugin": get_fsdp_config(args, model),
+ "fsdp_plugin": get_fsdp_config(args, model, device),
}
else:
raise ValueError(
f"Unknown sharding framework: {args.distributed_training_framework}"
)
- accelerator = Accelerator(
- **accel_args,
- )
+
+ if device.is_hpu():
+ accelerator = GaudiAccelerator(
+ **accel_args,
+ )
+ else:
+ accelerator = Accelerator(
+ **accel_args,
+ )
accelerator.even_batches = False
return accelerator
diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py
index fda9a75..fffc095 100644
--- a/src/instructlab/training/token_dataset.py
+++ b/src/instructlab/training/token_dataset.py
@@ -123,7 +123,7 @@ def setup_dataloader(
from torch.utils.data import DistributedSampler
sampler = (
- DistributedSampler(dataset) if torch.distributed.is_initialized() else None
+ DistributedSampler(dataset,rank=rank, num_replicas=world_size, shuffle=True) if torch.distributed.is_initialized() else None
)
sampler = {
"sampler": sampler,
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 7fdad3c..9ac6121 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -44,6 +44,7 @@ from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokeni
import numpy as np
import torch
import torch.nn.functional as F
+import habana_frameworks.torch.core as htcore
# First Party
from instructlab.training.config import (
@@ -170,13 +171,15 @@ class StreamablePopen(subprocess.Popen):
break
-def supports_flash_attention(device_id=0):
+def supports_flash_attention(device):
"""Check if a GPU supports FlashAttention."""
- major, minor = torch.cuda.get_device_capability(device_id)
+ if device.is_hpu():
+ return False
+ major, minor = device.get_capability().split('.')
# Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
- dev_name = torch.cuda.get_device_properties(device_id).gcnArchName.split(":")[0]
+ dev_name = device.get_properties().gcnArchName.split(":")[0]
is_compat_amd = dev_name in ("gfx90a", "gfx940", "gfx941", "gfx942")
return is_sm8x or is_sm90 or is_compat_amd
@@ -1071,12 +1074,12 @@ def save_model_ds_native(
log_rank_0(f"saving took {time.time() - start} seconds")
-def set_random_seed(seed):
+def set_random_seed(seed, device):
if seed is not None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
+ device.manual_seed_all(seed)
def save_checkpoint(
@@ -1205,3 +1208,56 @@ def get_projection_layer_names(model: PreTrainedModel) -> List[str]:
if name.endswith("_proj")
)
return list(proj_layers)
+
+class Device:
+ """
+ A generic class to handle all device types and
+ related api calls
+ """
+
+ def __init__(self, device_type):
+ self.device_type = device_type
+ self.local_rank = int(os.environ["LOCAL_RANK"])
+ self.device_apis = None
+ if device_type == 'cuda':
+ self.device_apis = torch.cuda
+ elif device_type == 'hpu':
+ self.device_apis = torch.hpu
+
+ def __call__(self):
+ return self.device_type
+
+ def memory_allocated(self):
+ return self.device_apis.memory_allocated()
+
+ def memory_stats(self):
+ return self.device_apis.memory_stats()
+
+ def empty_cache(self):
+ if self.device_type =='cuda':
+ self.device_apis.empty_cache()
+
+ def set_current(self):
+ self.device_apis.set_device(self.local_rank)
+
+ def is_hpu(self):
+ return self.device_type == 'hpu'
+
+ def is_cuda(self):
+ return self.device_type == 'cuda'
+
+ def get_capability(self):
+ return self.device_apis.get_device_capability(self.local_rank)
+
+ def get_properties(self):
+ return self.device_apis.get_device_properties(self.local_rank)
+
+ def manual_seed_all(self, seed):
+ self.device_apis.manual_seed_all(seed)
+
+ def mark_step(self):
+ if self.is_hpu():
+ htcore.mark_step()
+
+ def dlib(self):
+ return 'hccl' if self.device_type == 'hpu' else 'nccl'