Skip to content

Commit

Permalink
fixes on reload
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
  • Loading branch information
fabianlim committed Sep 16, 2024
1 parent 9991375 commit be39ac9
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 14 deletions.
2 changes: 1 addition & 1 deletion plugins/framework/src/fms_acceleration/framework_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _update_config_contents(_cfg: Dict, content: Dict, key: str):
for key in _or_keys:
content = _trace_key_path(configuration, key)
if content is not None:
if reject is True:
if reject:
# it is an OR key, and if at least one of them specified
# then do not reject
reject = False
Expand Down
17 changes: 12 additions & 5 deletions plugins/framework/src/fms_acceleration/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,12 +354,19 @@ def _import_and_reload(model: torch.nn.Module):
# if target paths in rule s is a prefix of rule l, raise an error
_name_s, _obj_s, _path_s = rule_s.import_and_maybe_reload
_, _, _path_l = rule_l.import_and_maybe_reload
if _path_l.startswith(_path_s):
# - then since _s appears to the left of _l, we will rememeber
# its location and disable the reload later
# - doing this, actually _with_reload will contain things
# that are not reloaded

if _path_s == _path_l:
# - in the even the target is exactly the same, we will
# only reload once
rule_s.import_and_maybe_reload = (_name_s, _obj_s, None)
continue

# - otherwise, we do not consider the cases where the target
# is a subpath since this results in unpredictablity.
assert not _path_l.startswith(
_path_s
), f"Attempting to reload a subpath`{_path_s}` multiple times in \
{rule_s.rule_id} and {rule_l.rule_id}"

# handle those with reload first
for rule in _with_reload + _no_reload:
Expand Down
55 changes: 47 additions & 8 deletions plugins/framework/tests/test_model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_import_and_maybe_reload_rule_with_mp_replaces_old_attribute():
assert isinstance(module4.Module4Class().attribute, PatchedModuleClass)


def test_mp_throws_error_with_multiple_reloads_on_same_target():
def test_mp_multiple_reloads_on_same_target():
"""
Simulate a case where two rules attempt to reload on the same target prefix
Expand Down Expand Up @@ -196,19 +196,19 @@ def patched_mod_function():
# 1. Initialize a model with module path tests.model_patcher_fixtures.module4
model = module4.Module4Class()

# 2. Simulate patching a function in module4.module5.module5_1
# 2. Simulate patching a function in module4.module5
ModelPatcher.register(
ModelPatcherRule(
rule_id=f"{DUMMY_RULE_ID}.2",
import_and_maybe_reload=(
"tests.model_patcher_fixtures.module4.module5.module5_1.mod_5_function",
patched_mod_function,
"tests.model_patcher_fixtures.module4.module5.module5_1",
"tests.model_patcher_fixtures.module4.module5",
),
)
)

# 3. Simulate patching a class in module4.module5.module5_1
# 3. Simulate patching a class in module4 (an upstream path)
ModelPatcher.register(
ModelPatcherRule(
rule_id=f"{DUMMY_RULE_ID}.1",
Expand All @@ -221,12 +221,51 @@ def patched_mod_function():
)

# while there are occasions repeated reloads along the same target path prefix work,
# it is risky and not guaranteed to work for all cases.
# To prevent the risk of any of the patches conflicting,
# we throw an exception if a shorter target path is a prefix of another
# longer target path
# the model patch will only call a reload once on the path.
# - this is because reloading on upstream paths may intefere with downstream
# - reload on tests.model_patcher_fixtures.module4 (shorter) will be skipped
# - reload on tests.model_patcher_fixtures.module4.module5 (longer) will be called
ModelPatcher.patch(model)

# However the patch_target_module will be surreptiously called to prevent
# the overwrites demonstrated above if targets paths are
# are a prefixes of another longer target path
with isolate_test_module_fixtures():
with instantiate_model_patcher():
# 1. Initialize a model with module path tests.model_patcher_fixtures.module4
model = module4.Module4Class()

# 2. Simulate patching a function in module4.module5
ModelPatcher.register(
ModelPatcherRule(
rule_id=f"{DUMMY_RULE_ID}.2",
import_and_maybe_reload=(
"tests.model_patcher_fixtures.module4.module5.module5_1.mod_5_function",
patched_mod_function,
"tests.model_patcher_fixtures.module4.module5",
),
)
)

# 3. Simulate patching a class in module4 (an upstream path)
ModelPatcher.register(
ModelPatcherRule(
rule_id=f"{DUMMY_RULE_ID}.1",
import_and_maybe_reload=(
"tests.model_patcher_fixtures.module4.module5.module5_1.Module5Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module4.module5",
),
)
)

# while there are occasions repeated reloads along the same target path prefix work,
# the model patch will only call a reload once on the path.
ModelPatcher.patch(model)

# check that patching is applied to both
assert isinstance(module4.module5.Module5Class(), PatchedModuleClass)
assert module4.module5.mod_5_function() == "patched_mod_function"

def test_mp_throws_warning_with_multiple_patches():
"""
Expand Down

0 comments on commit be39ac9

Please sign in to comment.