
Load dynamic module (remote code) only once if code isn't change #33162

Merged
merged 6 commits into from
Sep 6, 2024

Conversation

XuehaiPan
Contributor

@XuehaiPan XuehaiPan commented Aug 28, 2024

What does this PR do?

Fixes #30370 (comment)

Add an indicator __transformers_module_hash__ to the remote code module.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ydshieh

@ydshieh
Collaborator

ydshieh commented Aug 28, 2024

Hi @XuehaiPan, thank you for taking action on this.

My first question is: what if the target (remote code) has changed since the last time it was loaded?

I don't know very well the context mentioned in #30370 (comment), but IIRC, it happens within a single Python process (i.e. multiple loads of the same module). And if that is the case, my question at the top of this comment has to be considered, right?

Maybe @tmm1 could explain the situation a bit more, ideally with a code snippet to demonstrate the issue.

@XuehaiPan
Contributor Author

My first question is: what if the target (remote code) is changed since the last time it has been loaded.

@ydshieh Thanks for raising this. I changed the indicator to the hash of the source code.

@ydshieh
Collaborator

ydshieh commented Aug 28, 2024

Nice. I would still prefer that @tmm1 elaborate on the issue a bit more, though.

But in the meantime, could you tell me more about the usage of threading.Lock() in this PR 🙏 ?

@XuehaiPan
Contributor Author

But in the meantime, could you tell me more about the usage of threading.Lock() in this PR 🙏 ?

It's just for thread safety. We are modifying the global variable sys.modules.
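As an illustration of why the lock matters when mutating sys.modules, here is a minimal sketch (the names register_module_once and _modules_lock are hypothetical, not the PR's code): without the lock, two threads could both see a cache miss and run the expensive import twice, racing on the sys.modules entry.

```python
import sys
import threading

_modules_lock = threading.Lock()  # hypothetical name; serializes updates to sys.modules

def register_module_once(name, factory):
    # Check-then-insert on sys.modules must be atomic: with the lock,
    # only the first caller runs `factory`; later callers get the cache.
    with _modules_lock:
        if name not in sys.modules:
            sys.modules[name] = factory()
        return sys.modules[name]
```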

@tmm1
Contributor

tmm1 commented Aug 28, 2024

prefer for @tmm1 to elaborate the issue a bit more though

sure, here is a simple example:

import sys
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

def load_modeling_code(model_name):
    with init_empty_weights():
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        )
        return sys.modules[model.__class__.__module__]

model_name = "deepseek-ai/DeepSeek-Coder-V2-Lite-Base"

# 1. patch modeling code
mod = load_modeling_code(model_name)
import liger_kernel
mod.CrossEntropyLoss = liger_kernel.transformers.cross_entropy.LigerCrossEntropyLoss
mod.DeepseekV2RMSNorm = liger_kernel.transformers.rms_norm.LigerRMSNorm
mod.DeepseekV2MLP.forward = liger_kernel.transformers.swiglu.LigerSwiGLUMLP.forward

# 2. check patch
mod = load_modeling_code(model_name)
print(mod.CrossEntropyLoss == liger_kernel.transformers.cross_entropy.LigerCrossEntropyLoss)

# 3. create and use model
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

Currently on the main branch, the patch is missing in steps (2) and (3).

@XuehaiPan
Contributor Author

XuehaiPan commented Aug 29, 2024

It looks reasonable to assume that the import statement returns the cached module object if it has already been imported.

Ideally, we may need to simulate this:

sys.path.insert(0, HF_MODULES_CACHE)
module = importlib.import_module(name)  # returns the same module object if it has already been imported (i.e. return sys.modules[name])
assert sys.path[0] == HF_MODULES_CACHE
del sys.path[0]
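That simulation could be wrapped in a context manager so sys.path is restored even if the import raises; a hedged sketch (prepend_sys_path is a hypothetical helper, not part of transformers):

```python
import contextlib
import sys

@contextlib.contextmanager
def prepend_sys_path(path):
    # Temporarily put `path` first on sys.path so importlib resolves the
    # dynamic module from the cache directory, then restore sys.path.
    sys.path.insert(0, path)
    try:
        yield
    finally:
        try:
            sys.path.remove(path)
        except ValueError:
            pass  # someone else already removed it

# Usage (names as in the snippet above):
# with prepend_sys_path(HF_MODULES_CACHE):
#     module = importlib.import_module(name)
```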

@ydshieh ydshieh self-assigned this Aug 29, 2024
@ydshieh
Collaborator

ydshieh commented Aug 29, 2024

Hi @tmm1 Thank you for providing the detailed information.

I am wondering if it would make (more) sense for the patching code to have a factory method like

def load_modeling_code(model_name):
    with init_empty_weights():
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        )
        return sys.modules[model.__class__.__module__]

def patch_modeling_code(model_name):

    # 1. patch modeling code
    mod = load_modeling_code(model_name)
    import liger_kernel
    mod.CrossEntropyLoss = liger_kernel.transformers.cross_entropy.LigerCrossEntropyLoss
    mod.DeepseekV2RMSNorm = liger_kernel.transformers.rms_norm.LigerRMSNorm
    mod.DeepseekV2MLP.forward = liger_kernel.transformers.swiglu.LigerSwiGLUMLP.forward

    return mod

model_name = "deepseek-ai/DeepSeek-Coder-V2-Lite-Base"
mod = patch_modeling_code(model_name)

But I am open to what @XuehaiPan has done so far.

@LysandreJik
Member

cc @Rocketknight1 as well :)

@tmm1
Contributor

tmm1 commented Aug 30, 2024

I am wondering if it would make (more) sense for the patching code to have a factory method like

It could work, but often the user wants to apply some patch and then invoke a well-known trainer framework. Such a framework would not accept a factory callback.

More importantly, the factory would still need to call AutoModelForCausalLM.from_pretrained twice:

  1. to download and load the remote code, so it is available inside sys.modules for patching
  2. without init_empty_weights(), and to actually use the patched code

Without this PR, the second invocation will remove the patch, so it would be impossible to get a model object that is actually patched.

@ydshieh
Collaborator

ydshieh commented Aug 30, 2024

OK, I didn't realize that

# 3. create and use model
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

in your previous comment is actually needed. It makes sense now.

I will let @Rocketknight1 have a look at the changes in this PR too.

@Rocketknight1
Member

It seems clean to me! I can think of two failure cases:

  1. This will not detect changes in other files (for example, if the model also has utils/functions in a separate file from the main modeling file, the hash of the main modeling file will not change and so the module will not be reimported).
  2. I think this could cause a regression if users load a remote_code model, monkey-patch methods, and then reload the model to clear their changes, since now their changes will persist. However, I suspect there are no users depending on this weird behaviour, so it's not a serious problem.

Both of these are very minor issues that cannot really be fixed in this PR, and I don't think they should block it, so I'm happy with it!
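For what it's worth, failure case (1) could in principle be addressed later by hashing every file in the dynamic-module directory together, so a change in any file invalidates the cache. A rough sketch (combined_hash is hypothetical, not part of this PR):

```python
import hashlib
from pathlib import Path

def combined_hash(module_dir):
    # Hash every .py file in a dynamic-module directory in a stable order,
    # including filenames, so that a change in any file (not just the main
    # modeling file) produces a different digest.
    h = hashlib.sha256()
    for path in sorted(Path(module_dir).glob("*.py")):
        h.update(path.name.encode())
        h.update(path.read_bytes())
    return h.hexdigest()
```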

@ydshieh
Collaborator

ydshieh commented Aug 30, 2024

Hi @XuehaiPan It would be great if you could add a corresponding test, e.g. checking that a (simple version of a) patched module (like what @tmm1 has) remains the same across loads.

@XuehaiPan XuehaiPan force-pushed the remote-code-once branch 10 times, most recently from b7c4b72 to 742fda2 Compare August 31, 2024 19:09
@XuehaiPan
Contributor Author

It would be great if you can add a corresponding test

Done.

@XuehaiPan XuehaiPan changed the title from "Load remote code only once" to "Load dynamic module (remote code) only once if code isn't change" Sep 1, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@ydshieh ydshieh left a comment


Very nice. A few nit comments.

@ydshieh
Collaborator

ydshieh commented Sep 3, 2024

In my previous comment

        # The configuration file is cached in the snapshot directory. So the module file is not changed after dumping

in other files (other than tests/models/auto/test_configuration_auto.py), the comment is not about configuration.

I can go with the current version, just let me know if you want to add it back (with the correct comments).

What I mean is that we have to change the comment "The configuration file" to the actual file name (like the image processor file). Otherwise, that part of the test is actually OK.

@XuehaiPan
Contributor Author

I can go with the current version, just let me know if you want to add it back (with the correct comments).

@ydshieh I added a commit to address this.

@ydshieh
Collaborator

ydshieh commented Sep 3, 2024

@Rocketknight1 in case you want to take a final look. (And/or feel free to ping a core maintainer 🙏)

@Rocketknight1
Member

I'm happy with it, I think! The test failure seems unrelated, and I like the core goal of allowing the same model to be imported twice without wasting time, and without getting two different output classes.

cc @XuehaiPan maybe rebase onto main and see if that fixes the test?

cc @LysandreJik for core maintainer approval!

Member

@LysandreJik LysandreJik left a comment


Thanks! This looks good to me if approved by the esteemed @Rocketknight1

@Rocketknight1
Member

@XuehaiPan let me know if you're happy to merge, or if there's anything else you want to tweak first!

@XuehaiPan
Contributor Author

@XuehaiPan let me know if you're happy to merge, or if there's anything else you want to tweak first!

@Rocketknight1 I think this is the final version of the PR.

@Rocketknight1
Member

Merging, in that case. Thank you for the PR!

@Rocketknight1 Rocketknight1 merged commit e1c2b69 into huggingface:main Sep 6, 2024
23 checks passed
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…gingface#33162)

* Load remote code only once

* Use hash as load indicator

* Add a new option `force_reload` for old behavior (i.e. always reload)

* Add test for dynamic module is cached

* Add more type annotations to improve code readability

* Address comments from code review