@@ -219,7 +219,14 @@ class ExampleTestMeta(type):

     @staticmethod
     def to_test(
-        model_name: str, multi_card: bool, deepspeed: bool, example_name: str, fsdp: bool, fp8: bool, task_name: str
+        model_name: str,
+        multi_card: bool,
+        deepspeed: bool,
+        example_name: str,
+        fsdp: bool,
+        fp8: bool,
+        eager_mode: bool,
+        task_name: str,
     ):
         models_with_specific_rules = [
             "albert-xxlarge-v1",
@@ -247,6 +254,8 @@ def to_test(
             "run_image2text_lora_finetune",
         ]

+        models_measured_on_eager_mode = ["google/gemma-2b-it"]
+
         if (fsdp or fp8) and not IS_GAUDI2:
             return False
         elif (
@@ -271,6 +280,8 @@ def to_test(
             "ln_tuning",
         ):
             return False
+        elif eager_mode and not model_name in models_measured_on_eager_mode:
+            return False
         elif model_name not in models_with_specific_rules and not deepspeed:
             return True
         elif model_name == "gpt2-xl" and deepspeed:
@@ -321,6 +332,7 @@ def __new__(
         fsdp=False,
         torch_compile=False,
         fp8=False,
+        eager_mode=False,
     ):
         distribution = "single_card"
         if multi_card:
@@ -340,7 +352,7 @@ def __new__(
                     )

         for model_name, gaudi_config_name in models_to_test:
-            if cls.to_test(model_name, multi_card, deepspeed, example_name, fsdp, fp8, attrs["TASK_NAME"]):
+            if cls.to_test(model_name, multi_card, deepspeed, example_name, fsdp, fp8, eager_mode, attrs["TASK_NAME"]):
                 attrs[f"test_{example_name}_{model_name.split('/')[-1]}_{distribution}"] = cls._create_test(
                     model_name, gaudi_config_name, multi_card, deepspeed, fsdp, torch_compile, fp8
                 )
@@ -424,9 +436,15 @@ def test(self):
                 create_clip_roberta_model()

             self._install_requirements(example_script.parent / "requirements.txt")
-            path_to_baseline = BASELINE_DIRECTORY / Path(
-                model_name.split("/")[-1].replace("-", "_").replace(".", "_")
-            ).with_suffix(".json")
+
+            # collect baseline from <model_name>_eager.json if eager_mode is True
+            if self.EAGER_MODE:
+                baseline_name = model_name.split("/")[-1].replace("-", "_").replace(".", "_") + "_eager"
+            else:
+                baseline_name = model_name.split("/")[-1].replace("-", "_").replace(".", "_")
+
+            path_to_baseline = BASELINE_DIRECTORY / Path(baseline_name).with_suffix(".json")
+
             with path_to_baseline.open("r") as json_file:
                 device = "gaudi2" if IS_GAUDI2 else "gaudi"
                 baseline = json.load(json_file)[device]
@@ -474,6 +492,10 @@ def test(self):

             extra_command_line_arguments = baseline.get("distribution").get(distribution).get("extra_arguments", [])

+            if self.EAGER_MODE:
+                env_variables["PT_HPU_LAZY_MODE"] = "0"
+                if "--use_hpu_graphs_for_inference" in extra_command_line_arguments:
+                    extra_command_line_arguments.remove("--use_hpu_graphs_for_inference")
             if os.environ.get("DATA_CACHE", None) is not None and self.EXAMPLE_NAME == "run_clip":
                 extra_command_line_arguments[0] = "--data_dir {}".format(os.environ["DATA_CACHE"])
             elif torch_compile and (
@@ -548,6 +570,7 @@ class ExampleTesterBase(TestCase):
         "train_samples_per_second": (TestCase.assertGreaterEqual, 2 - TIME_PERF_FACTOR),
         "eval_samples_per_second": (TestCase.assertGreaterEqual, 2 - TIME_PERF_FACTOR),
     }
+    EAGER_MODE = None

     def _create_command_line(
         self,
@@ -723,6 +746,13 @@ class MultiCardQuestionAnsweringExampleTester(
     TASK_NAME = "squad"


+class EagerModeCausalLanguageModelingExampleTester(
+    ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_clm", eager_mode=True
+):
+    TASK_NAME = "wikitext"
+    EAGER_MODE = True
+
+
 class CausalLanguageModelingExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_clm"):
     TASK_NAME = "wikitext"

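Note: the diff relies on class-definition keyword arguments being consumed by the metaclass. The snippet below is a minimal, self-contained sketch of that flow under simplified assumptions: an eager_mode keyword is swallowed by the metaclass __new__ and forwarded to a to_test filter that decides which per-model test methods get generated, mirroring what the hunks above do in ExampleTestMeta. DummyMeta, DummyEagerTester, and the hard-coded model list are illustrative stand-ins, not part of the optimum-habana test suite.

# Sketch only: shows how an eager_mode class keyword flows through a test-generating metaclass.
class DummyMeta(type):
    @staticmethod
    def to_test(model_name: str, eager_mode: bool) -> bool:
        # Same idea as the new guard in to_test above: in eager mode, keep only
        # models that have dedicated eager-mode baselines.
        models_measured_on_eager_mode = ["google/gemma-2b-it"]
        if eager_mode and model_name not in models_measured_on_eager_mode:
            return False
        return True

    def __new__(cls, name, bases, attrs, example_name=None, eager_mode=False):
        # One test method is generated per model that passes the filter.
        for model_name in ["google/gemma-2b-it", "gpt2"]:
            if cls.to_test(model_name, eager_mode):
                test_name = f"test_{example_name}_{model_name.split('/')[-1]}"
                attrs[test_name] = lambda self, m=model_name: print(f"would run {m}")
        return super().__new__(cls, name, bases, attrs)


class DummyEagerTester(metaclass=DummyMeta, example_name="run_clm", eager_mode=True):
    EAGER_MODE = True  # read later by the generated test body, as in the diff


if __name__ == "__main__":
    # Only the gemma-2b-it test is generated; gpt2 has no eager baseline here.
    print([name for name in dir(DummyEagerTester) if name.startswith("test_")])

In the real suite, the tests generated for the eager-mode tester additionally load a <model_name>_eager.json baseline, set PT_HPU_LAZY_MODE=0 in the environment, and drop --use_hpu_graphs_for_inference from the extra command-line arguments, as shown in the hunks above.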