diff --git a/plugins/framework/src/fms_acceleration/framework_plugin.py b/plugins/framework/src/fms_acceleration/framework_plugin.py index 5f111d5a..cf1764d5 100644 --- a/plugins/framework/src/fms_acceleration/framework_plugin.py +++ b/plugins/framework/src/fms_acceleration/framework_plugin.py @@ -101,7 +101,7 @@ def _update_config_contents(_cfg: Dict, content: Dict, key: str): for key in _or_keys: content = _trace_key_path(configuration, key) if content is not None: - if reject is True: + if reject: # it is an OR key, and if at least one of them specified # then do not reject reject = False diff --git a/plugins/framework/src/fms_acceleration/model_patcher.py b/plugins/framework/src/fms_acceleration/model_patcher.py index 56c0771f..ea5d8fb9 100644 --- a/plugins/framework/src/fms_acceleration/model_patcher.py +++ b/plugins/framework/src/fms_acceleration/model_patcher.py @@ -354,12 +354,19 @@ def _import_and_reload(model: torch.nn.Module): # if target paths in rule s is a prefix of rule l, raise an error _name_s, _obj_s, _path_s = rule_s.import_and_maybe_reload _, _, _path_l = rule_l.import_and_maybe_reload - if _path_l.startswith(_path_s): - # - then since _s appears to the left of _l, we will rememeber - # its location and disable the reload later - # - doing this, actually _with_reload will contain things - # that are not reloaded + + if _path_s == _path_l: + # - in the even the target is exactly the same, we will + # only reload once rule_s.import_and_maybe_reload = (_name_s, _obj_s, None) + continue + + # - otherwise, we do not consider the cases where the target + # is a subpath since this results in unpredictablity. + assert not _path_l.startswith( + _path_s + ), f"Attempting to reload a subpath`{_path_s}` multiple times in \ + {rule_s.rule_id} and {rule_l.rule_id}" # handle those with reload first for rule in _with_reload + _no_reload: diff --git a/plugins/framework/tests/test_model_patcher.py b/plugins/framework/tests/test_model_patcher.py index f9be7447..1bcb78b8 100644 --- a/plugins/framework/tests/test_model_patcher.py +++ b/plugins/framework/tests/test_model_patcher.py @@ -123,7 +123,7 @@ def test_import_and_maybe_reload_rule_with_mp_replaces_old_attribute(): assert isinstance(module4.Module4Class().attribute, PatchedModuleClass) -def test_mp_throws_error_with_multiple_reloads_on_same_target(): +def test_mp_multiple_reloads_on_same_target(): """ Simulate a case where two rules attempt to reload on the same target prefix @@ -196,19 +196,19 @@ def patched_mod_function(): # 1. Initialize a model with module path tests.model_patcher_fixtures.module4 model = module4.Module4Class() - # 2. Simulate patching a function in module4.module5.module5_1 + # 2. Simulate patching a function in module4.module5 ModelPatcher.register( ModelPatcherRule( rule_id=f"{DUMMY_RULE_ID}.2", import_and_maybe_reload=( "tests.model_patcher_fixtures.module4.module5.module5_1.mod_5_function", patched_mod_function, - "tests.model_patcher_fixtures.module4.module5.module5_1", + "tests.model_patcher_fixtures.module4.module5", ), ) ) - # 3. Simulate patching a class in module4.module5.module5_1 + # 3. Simulate patching a class in module4 (an upstream path) ModelPatcher.register( ModelPatcherRule( rule_id=f"{DUMMY_RULE_ID}.1", @@ -221,12 +221,51 @@ def patched_mod_function(): ) # while there are occasions repeated reloads along the same target path prefix work, - # it is risky and not guaranteed to work for all cases. - # To prevent the risk of any of the patches conflicting, - # we throw an exception if a shorter target path is a prefix of another - # longer target path + # the model patch will only call a reload once on the path. + # - this is because reloading on upstream paths may intefere with downstream + # - reload on tests.model_patcher_fixtures.module4 (shorter) will be skipped + # - reload on tests.model_patcher_fixtures.module4.module5 (longer) will be called ModelPatcher.patch(model) + # However the patch_target_module will be surreptiously called to prevent + # the overwrites demonstrated above if targets paths are + # are a prefixes of another longer target path + with isolate_test_module_fixtures(): + with instantiate_model_patcher(): + # 1. Initialize a model with module path tests.model_patcher_fixtures.module4 + model = module4.Module4Class() + + # 2. Simulate patching a function in module4.module5 + ModelPatcher.register( + ModelPatcherRule( + rule_id=f"{DUMMY_RULE_ID}.2", + import_and_maybe_reload=( + "tests.model_patcher_fixtures.module4.module5.module5_1.mod_5_function", + patched_mod_function, + "tests.model_patcher_fixtures.module4.module5", + ), + ) + ) + + # 3. Simulate patching a class in module4 (an upstream path) + ModelPatcher.register( + ModelPatcherRule( + rule_id=f"{DUMMY_RULE_ID}.1", + import_and_maybe_reload=( + "tests.model_patcher_fixtures.module4.module5.module5_1.Module5Class", + PatchedModuleClass, + "tests.model_patcher_fixtures.module4.module5", + ), + ) + ) + + # while there are occasions repeated reloads along the same target path prefix work, + # the model patch will only call a reload once on the path. + ModelPatcher.patch(model) + + # check that patching is applied to both + assert isinstance(module4.module5.Module5Class(), PatchedModuleClass) + assert module4.module5.mod_5_function() == "patched_mod_function" def test_mp_throws_warning_with_multiple_patches(): """