diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index 9a1c0a57..ea9f527e 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -168,21 +168,16 @@ def get_callbacks_and_ready_for_train( accelerator is not None and getattr(accelerator.state, "fsdp_plugin", None) is not None ): - # - use an internal function call to get the no split + # - use an internal function call to get the no split # module names, which are typically layers _layers = model._get_no_split_modules('') accelerator.state.fsdp_plugin.ignored_modules = [ getattr(layer, name) - for name in moe_component_module_names + for name in self._moe_component_module_names for layer in model.modules() if layer.__class__.__name__ in _layers ] -FSDP( - model, - ignored_modules=ignored_modules, -) - return callbacks