
Commit 63e9305

added Eager Mode test case

1 parent 962d0c5 commit 63e9305

File tree

tests/baselines/gemma_2b_it_eager.json
tests/test_examples.py

2 files changed: +55 -5
tests/baselines/gemma_2b_it_eager.json

+20
@@ -0,0 +1,20 @@
+{
+    "gaudi2": {
+        "wikitext": {
+            "num_train_epochs": 2,
+            "eval_batch_size": 4,
+            "distribution": {
+                "single_card": {
+                    "learning_rate": 2e-4,
+                    "train_batch_size": 4,
+                    "perplexity": 26.69,
+                    "train_runtime": 560.8188,
+                    "train_samples_per_second": 8.597,
+                    "extra_arguments": [
+                        "--dataset_config_name wikitext-2-raw-v1"
+                    ]
+                }
+            }
+        }
+    }
+}

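For context, the test harness consumes a baseline like this by indexing on device ("gaudi2" here), then on task and distribution. A minimal sketch of that lookup, assuming the file path above and the schema shown:

import json
from pathlib import Path

# Path assumed for illustration; the harness derives it from the model name.
path_to_baseline = Path("tests/baselines/gemma_2b_it_eager.json")

with path_to_baseline.open("r") as json_file:
    baseline = json.load(json_file)["gaudi2"]["wikitext"]

single_card = baseline["distribution"]["single_card"]
print(single_card["perplexity"])       # 26.69
print(single_card["extra_arguments"])  # ['--dataset_config_name wikitext-2-raw-v1']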
tests/test_examples.py

+35 -5
@@ -219,7 +219,14 @@ class ExampleTestMeta(type):
 
     @staticmethod
     def to_test(
-        model_name: str, multi_card: bool, deepspeed: bool, example_name: str, fsdp: bool, fp8: bool, task_name: str
+        model_name: str,
+        multi_card: bool,
+        deepspeed: bool,
+        example_name: str,
+        fsdp: bool,
+        fp8: bool,
+        eager_mode: bool,
+        task_name: str,
     ):
         models_with_specific_rules = [
             "albert-xxlarge-v1",
@@ -247,6 +254,8 @@ def to_test(
             "run_image2text_lora_finetune",
         ]
 
+        models_measured_on_eager_mode = ["google/gemma-2b-it"]
+
         if (fsdp or fp8) and not IS_GAUDI2:
             return False
         elif (
@@ -271,6 +280,8 @@ def to_test(
             "ln_tuning",
         ):
             return False
+        elif eager_mode and model_name not in models_measured_on_eager_mode:
+            return False
         elif model_name not in models_with_specific_rules and not deepspeed:
             return True
         elif model_name == "gpt2-xl" and deepspeed:
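The new gate restricts eager-mode test collection to models with measured eager baselines. A minimal standalone sketch of that check (hypothetical helper; in the diff it is a single elif branch inside to_test):

MODELS_MEASURED_ON_EAGER_MODE = ["google/gemma-2b-it"]

def should_test_in_eager_mode(model_name: str, eager_mode: bool) -> bool:
    # In eager mode, only models with an `_eager` baseline are collected.
    if eager_mode and model_name not in MODELS_MEASURED_ON_EAGER_MODE:
        return False
    return True

assert should_test_in_eager_mode("google/gemma-2b-it", eager_mode=True)
assert not should_test_in_eager_mode("gpt2-xl", eager_mode=True)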
@@ -321,6 +332,7 @@ def __new__(
         fsdp=False,
         torch_compile=False,
         fp8=False,
+        eager_mode=False,
     ):
         distribution = "single_card"
         if multi_card:
@@ -340,7 +352,7 @@ def __new__(
         )
 
         for model_name, gaudi_config_name in models_to_test:
-            if cls.to_test(model_name, multi_card, deepspeed, example_name, fsdp, fp8, attrs["TASK_NAME"]):
+            if cls.to_test(model_name, multi_card, deepspeed, example_name, fsdp, fp8, eager_mode, attrs["TASK_NAME"]):
                 attrs[f"test_{example_name}_{model_name.split('/')[-1]}_{distribution}"] = cls._create_test(
                     model_name, gaudi_config_name, multi_card, deepspeed, fsdp, torch_compile, fp8
                 )
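For the new eager-mode tester, this loop generates one test method per allow-listed model, named from the example, model, and distribution. A small sketch of that naming, assuming the single model measured above:

example_name = "run_clm"
model_name = "google/gemma-2b-it"
distribution = "single_card"

test_name = f"test_{example_name}_{model_name.split('/')[-1]}_{distribution}"
print(test_name)  # test_run_clm_gemma-2b-it_single_card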
@@ -424,9 +436,15 @@ def test(self):
                 create_clip_roberta_model()
 
             self._install_requirements(example_script.parent / "requirements.txt")
-            path_to_baseline = BASELINE_DIRECTORY / Path(
-                model_name.split("/")[-1].replace("-", "_").replace(".", "_")
-            ).with_suffix(".json")
+
+            # collect baseline from <model_name>_eager.json if eager_mode is True
+            if self.EAGER_MODE:
+                baseline_name = model_name.split("/")[-1].replace("-", "_").replace(".", "_") + "_eager"
+            else:
+                baseline_name = model_name.split("/")[-1].replace("-", "_").replace(".", "_")
+
+            path_to_baseline = BASELINE_DIRECTORY / Path(baseline_name).with_suffix(".json")
+
             with path_to_baseline.open("r") as json_file:
                 device = "gaudi2" if IS_GAUDI2 else "gaudi"
                 baseline = json.load(json_file)[device]
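The baseline file name is the sanitized model name, with an `_eager` suffix when EAGER_MODE is set. The same derivation as a hypothetical helper (the diff inlines this logic):

def baseline_file_name(model_name: str, eager_mode: bool) -> str:
    # Sanitize "google/gemma-2b-it" -> "gemma_2b_it", then suffix for eager mode.
    name = model_name.split("/")[-1].replace("-", "_").replace(".", "_")
    if eager_mode:
        name += "_eager"
    return name + ".json"

print(baseline_file_name("google/gemma-2b-it", eager_mode=True))   # gemma_2b_it_eager.json
print(baseline_file_name("google/gemma-2b-it", eager_mode=False))  # gemma_2b_it.json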
@@ -474,6 +492,10 @@ def test(self):
 
             extra_command_line_arguments = baseline.get("distribution").get(distribution).get("extra_arguments", [])
 
+            if self.EAGER_MODE:
+                env_variables["PT_HPU_LAZY_MODE"] = "0"
+                if "--use_hpu_graphs_for_inference" in extra_command_line_arguments:
+                    extra_command_line_arguments.remove("--use_hpu_graphs_for_inference")
             if os.environ.get("DATA_CACHE", None) is not None and self.EXAMPLE_NAME == "run_clip":
                 extra_command_line_arguments[0] = "--data_dir {}".format(os.environ["DATA_CACHE"])
             elif torch_compile and (
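Setting PT_HPU_LAZY_MODE=0 switches the Habana PyTorch bridge from the default lazy execution to eager mode, and HPU graphs depend on lazy mode, so the inference flag is stripped if the baseline requests it. A minimal sketch of that environment preparation (variable names mirror the diff; the flag list is illustrative):

import os

env_variables = os.environ.copy()
extra_command_line_arguments = ["--use_hpu_graphs_for_inference", "--bf16"]

eager_mode = True
if eager_mode:
    # Run the example in eager mode instead of the default lazy mode.
    env_variables["PT_HPU_LAZY_MODE"] = "0"
    # HPU graphs require lazy mode, so drop the flag when present.
    if "--use_hpu_graphs_for_inference" in extra_command_line_arguments:
        extra_command_line_arguments.remove("--use_hpu_graphs_for_inference")

print(extra_command_line_arguments)  # ['--bf16']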
@@ -548,6 +570,7 @@ class ExampleTesterBase(TestCase):
         "train_samples_per_second": (TestCase.assertGreaterEqual, 2 - TIME_PERF_FACTOR),
         "eval_samples_per_second": (TestCase.assertGreaterEqual, 2 - TIME_PERF_FACTOR),
     }
+    EAGER_MODE = None
 
     def _create_command_line(
         self,
@@ -723,6 +746,13 @@ class MultiCardQuestionAnsweringExampleTester(
     TASK_NAME = "squad"
 
 
+class EagerModeCausalLanguageModelingExampleTester(
+    ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_clm", eager_mode=True
+):
+    TASK_NAME = "wikitext"
+    EAGER_MODE = True
+
+
 class CausalLanguageModelingExampleTester(ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_clm"):
     TASK_NAME = "wikitext"
 
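Both run_clm testers now coexist: the eager variant passes eager_mode=True to the metaclass (to filter models at collection time) and sets EAGER_MODE = True on the class (to pick the `_eager` baseline and environment at test time). A reduced sketch of how the class attribute is read, with hypothetical stand-in classes:

class ExampleTesterBase:
    # None by default, so `if self.EAGER_MODE:` stays falsy for lazy-mode testers.
    EAGER_MODE = None

class EagerModeCausalLanguageModelingExampleTester(ExampleTesterBase):
    TASK_NAME = "wikitext"
    EAGER_MODE = True

print(bool(EagerModeCausalLanguageModelingExampleTester.EAGER_MODE))  # True
print(bool(ExampleTesterBase.EAGER_MODE))                             # False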