
auto-detect device when no device is passed to pipeline #31398

Merged: 5 commits into huggingface:main on Jun 19, 2024

Conversation

@faaany (Contributor) commented on Jun 13, 2024

What does this PR do?

Currently, if no device is passed to pipeline, the model will stay on CPU. This PR makes it possible to auto-detect the underlying hardware environment and move the model to the corresponding accelerator.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    do_sample=False,
    max_new_tokens=300,
    num_beams=1,
)
print(pipe.model.device)
# before fix: cpu
# after fix: cuda
```
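For context, device auto-detection along these lines can be sketched as follows. This is a hypothetical helper, not the PR's actual implementation; the set of checks and their order are assumptions:

```python
# Hypothetical sketch (not the PR's actual code): pick the best available
# accelerator, falling back to CPU. hasattr() guards keep it working on
# older torch builds that lack the mps/xpu backends.
import torch

def detect_default_device() -> str:
    """Return the name of the best available torch device."""
    if torch.cuda.is_available():
        return "cuda"  # NVIDIA (or ROCm) GPU
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"   # Apple Silicon
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"   # Intel GPU
    return "cpu"

print(detect_default_device())
```

On a machine with no accelerator this prints `cpu`, which matches the pre-fix behaviour described above.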

@amyeroberts and @muellerzr

@amyeroberts (Collaborator) left a comment

Thanks for adding this @faaany

I'm not sure this is something we want to do - it can both lead to unexpected behaviour for the user (the default is changing), and doesn't match with the rest of the library (we don't automatically put a model onto a GPU if it's available in the environment when instantiating it). WDYT @ArthurZucker?

faaany and others added 2 commits June 13, 2024 17:25
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@faaany (Contributor, Author) commented on Jun 13, 2024

@yao-matrix

@faaany (Contributor, Author) commented on Jun 18, 2024

Hi @ArthurZucker , any thoughts on this PR?

@muellerzr (Contributor) left a comment

My $0.02: I also don't think we should do this automatically for the user. It feels a bit too "magical", even if it's useful. If we change anything, I'd rather expose a good default at the top level that uses the device when possible.

@HuggingFaceDocBuilderDev commented

The docs for this PR are built automatically; all documentation changes will be reflected on that endpoint. The docs remain available until 30 days after the last update.

@amyeroberts (Collaborator)

How about we emit a warning if we detect there's a device but none is set for the pipeline?
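A minimal sketch of what such a warning could look like. The function name `maybe_warn_about_device` and the `accelerator_available` flag are hypothetical stand-ins (in practice the flag would come from a real check such as `torch.cuda.is_available()`), not the code that was merged:

```python
# Hypothetical sketch of the warning approach: if an accelerator is
# detected but the user did not pass `device=`, warn instead of silently
# moving (or not moving) the model.
import warnings

def maybe_warn_about_device(device, accelerator_available: bool) -> None:
    if device is None and accelerator_available:
        warnings.warn(
            "A hardware accelerator was detected but `device` was not set; "
            "the model will stay on CPU. Pass `device=...` to pipeline() "
            "to use the accelerator.",
            UserWarning,
        )

# Demo: exactly one warning fires when a device exists but none was set.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    maybe_warn_about_device(device=None, accelerator_available=True)
print(len(caught))  # 1
```

Keeping the check at pipeline-construction time means the warning fires once per pipeline rather than on every call, which fits the "users would pay attention to it" point below.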

@muellerzr (Contributor)

A warning sounds good, and at this level I feel users would pay attention to it (rather than it just being bloat in the warning logs).

@faaany (Contributor, Author) commented on Jun 18, 2024

Thanks for the suggestion! I will update the PR.

@faaany (Contributor, Author) commented on Jun 18, 2024

I think we need to re-trigger the CI

@amyeroberts (Collaborator)

@faaany Could you try rebasing to include upstream changes from main?

@faaany (Contributor, Author) commented on Jun 19, 2024

> @faaany Could you try rebasing to include upstream changes from main?

sure, done.

@amyeroberts (Collaborator) left a comment

Thanks for adding!

@amyeroberts amyeroberts merged commit 4144c35 into huggingface:main Jun 19, 2024
19 checks passed