Load dynamic module (remote code) only once if code isn't changed #33162
Conversation
Hi @XuehaiPan, thank you for taking the action. My first question is: what happens if the target (remote code) has changed since the last time it was loaded? I don't know very well what the context mentioned in #30370 (comment) is, but IIRC, that happens within a single Python process (i.e. multiple loadings of the same module). And if that is the case, my question at the top of this comment has to be considered, right? Maybe @tmm1 could explain the situation a bit more, ideally with a code snippet to demonstrate the issue.
@ydshieh Thanks for raising this. I changed this indicator to the hash of the source code.
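The hash-based indicator can be illustrated with a minimal sketch. This is a hypothetical simplification, not the actual transformers implementation: the function names and the cache dictionary here are made up for illustration; only the `__transformers_module_hash__` attribute name comes from the PR description.

```python
# Hypothetical sketch: reuse a cached dynamic module only when a hash of its
# source code matches, so edits to the remote code trigger a reload.
import hashlib
import types

_module_cache: dict[str, types.ModuleType] = {}  # illustrative global cache

def source_hash(source: str) -> str:
    # Hash the raw module source; any edit changes the digest.
    return hashlib.sha256(source.encode("utf-8")).hexdigest()

def load_module(name: str, source: str) -> types.ModuleType:
    digest = source_hash(source)
    cached = _module_cache.get(name)
    # Reuse the cached module only if the source is byte-for-byte identical.
    if cached is not None and getattr(cached, "__transformers_module_hash__", None) == digest:
        return cached
    module = types.ModuleType(name)
    module.__transformers_module_hash__ = digest
    exec(source, module.__dict__)
    _module_cache[name] = module
    return module
```

With this scheme, loading the same unchanged source twice returns the same module object, while a changed source produces a fresh one.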
Nice. I still prefer for @tmm1 to elaborate on the issue a bit more, though. But in the meantime, could you tell me more about the usage of
It's just for thread safety. We are modifying the global variable
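The thread-safety concern can be sketched in isolation. This is an illustrative example only; the actual lock and variable names used in the PR are not shown in this thread, so the names below (`_lock`, `_loaded_hashes`, `record_hash`) are invented for the sketch.

```python
# Illustrative sketch: guard mutation of a module-level dict with a lock so
# that concurrent dynamic-module loads from multiple threads stay consistent.
import threading

_lock = threading.Lock()
_loaded_hashes: dict[str, str] = {}  # hypothetical global being modified

def record_hash(module_name: str, digest: str) -> bool:
    """Record the hash for a module; return True if it was newly recorded."""
    with _lock:
        if _loaded_hashes.get(module_name) == digest:
            return False
        _loaded_hashes[module_name] = digest
        return True
```

Without the lock, two threads loading the same module at once could interleave the check and the write, each concluding it was the first to load the module.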
sure, here is a simple example:

```python
import sys

from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

def load_modeling_code(model_name):
    with init_empty_weights():
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        )
    return sys.modules[model.__class__.__module__]

model_name = "deepseek-ai/DeepSeek-Coder-V2-Lite-Base"

# 1. patch modeling code
mod = load_modeling_code(model_name)

import liger_kernel

mod.CrossEntropyLoss = liger_kernel.transformers.cross_entropy.LigerCrossEntropyLoss
mod.DeepseekV2RMSNorm = liger_kernel.transformers.rms_norm.LigerRMSNorm
mod.DeepseekV2MLP.forward = liger_kernel.transformers.swiglu.LigerSwiGLUMLP.forward

# 2. check patch
mod = load_modeling_code(model_name)
print(mod.CrossEntropyLoss == liger_kernel.transformers.cross_entropy.LigerCrossEntropyLoss)

# 3. create and use model
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
```

Currently on the main branch, the patch is missing in steps (2) and (3).
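The caching behavior this example depends on can be demonstrated without transformers at all. The following self-contained sketch (the `load` helper and module name are invented for illustration; the `force_reload` flag mirrors the option this PR adds) shows why monkey-patches survive only if a later load returns the module object already in `sys.modules` instead of re-executing the source:

```python
# Minimal illustration: a patch on a module survives repeated loads only when
# the loader reuses the cached sys.modules entry rather than re-executing it.
import sys
import types

SOURCE = "def f():\n    return 'original'\n"

def load(name, source, force_reload=False):
    # Mirrors the PR's semantics: reuse the cached module unless forced.
    if not force_reload and name in sys.modules:
        return sys.modules[name]
    module = types.ModuleType(name)
    exec(source, module.__dict__)
    sys.modules[name] = module
    return module

mod = load("demo_mod", SOURCE)
mod.f = lambda: "patched"  # monkey-patch, as liger_kernel does above

assert load("demo_mod", SOURCE).f() == "patched"                      # patch preserved
assert load("demo_mod", SOURCE, force_reload=True).f() == "original"  # patch lost
```

Re-executing the source (the old behavior, kept behind `force_reload=True`) creates a fresh module object and silently drops the patch.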
It looks reasonable to assume the import statement returns the cached module object if it has already been imported. Ideally, we may need to simulate this:

```python
sys.path.insert(0, HF_MODULES_CACHE)
module = importlib.import_module(name)  # returns the same module object if it has already been imported (i.e. returns sys.modules[name])
assert sys.path[0] == HF_MODULES_CACHE
del sys.path[0]
```
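That assumption about `importlib.import_module` is easy to verify with a stdlib module (standing in for a module under `HF_MODULES_CACHE`):

```python
# importlib.import_module consults sys.modules first, so a second import of
# the same name returns the identical module object without re-execution.
import importlib
import sys

first = importlib.import_module("json")
second = importlib.import_module("json")

assert first is second               # same object, not a fresh copy
assert first is sys.modules["json"]  # both are the cached sys.modules entry
```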
Hi @tmm1 Thank you for providing the detailed information. I am wondering if it would make (more) sense for the patching code to have a factory method like:

```python
def load_modeling_code(model_name):
    with init_empty_weights():
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        )
    return sys.modules[model.__class__.__module__]

def patch_modeling_code(model_name):
    # 1. patch modeling code
    mod = load_modeling_code(model_name)

    import liger_kernel

    mod.CrossEntropyLoss = liger_kernel.transformers.cross_entropy.LigerCrossEntropyLoss
    mod.DeepseekV2RMSNorm = liger_kernel.transformers.rms_norm.LigerRMSNorm
    mod.DeepseekV2MLP.forward = liger_kernel.transformers.swiglu.LigerSwiGLUMLP.forward
    return mod

model_name = "deepseek-ai/DeepSeek-Coder-V2-Lite-Base"
mod = patch_modeling_code(model_name)
```

But I am open to what @XuehaiPan has done so far.
cc @Rocketknight1 as well :)
It could work, but often the user wants to apply some patch and then invoke a well-known trainer framework. Such a framework would not accept a factory callback. More importantly, the factory would still need to call
Without this PR, the second invocation will remove the patch, so it would be impossible to get a
OK, I didn't know
in your previous comment is actually needed. Makes sense now. I will let @Rocketknight1 have a look at the changes in this PR too.
It seems clean to me! I can think of two failure cases:
Both of these are very minor issues that can't really be fixed in this PR, and I don't think they should block it, so I'm happy with it!
Hi @XuehaiPan It would be great if you could add a corresponding test, e.g. checking that a (simple version of a) patched module (like what @tmm1 has) remains the same after reloading.
Done.
Very nice. A few nit comments.
In my previous comment
in other files (other than
I can go with the current version; just let me know if you want to add it back (with the correct comments). What I mean is that we have to change the comment
@ydshieh I added a commit to address this.
@Rocketknight1 in case you want to take a final look. (and/or feel free to ping a core maintainer 🙏)
I'm happy with it, I think! The test failure seems unrelated, and I like the core goal of allowing the same model to be imported twice without wasting time and without getting two different output classes. cc @XuehaiPan maybe rebase onto
cc @LysandreJik for core maintainer approval!
Thanks! This looks good to me if approved by the esteemed @Rocketknight1
@XuehaiPan let me know if you're happy to merge, or if there's anything else you want to tweak first!
@Rocketknight1 I think this is the final version of the PR.
Merging, in that case. Thank you for the PR!
…gingface#33162)

* Load remote code only once
* Use hash as load indicator
* Add a new option `force_reload` for old behavior (i.e. always reload)
* Add test for dynamic module is cached
* Add more type annotations to improve code readability
* Address comments from code review
What does this PR do?
Fixes #30370 (comment)
Adds an indicator `__transformers_module_hash__` to the remote code module.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ydshieh