
xpu: support xpu backend from stock pytorch (>=2.4) #31238

Merged: 2 commits merged into huggingface:main, Jun 14, 2024

Conversation

@dvrogozh (Contributor) commented Jun 4, 2024

Fixes: #31237

The XPU backend is available in stock PyTorch starting from version 2.4, see [1]. This commit extends Hugging Face Transformers to support XPU from both IPEX and stock PyTorch; IPEX is tried first.
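
A minimal sketch of that detection order (IPEX probed first, then stock PyTorch's torch.xpu), assuming PyTorch >= 2.4; the function name below is illustrative rather than the actual transformers code:

import importlib.util

import torch


def xpu_is_available() -> bool:
    # Prefer IPEX when it is installed; importing it extends/patches torch.xpu.
    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
        import intel_extension_for_pytorch  # noqa: F401
    # Stock PyTorch >= 2.4 ships a native torch.xpu module.
    return hasattr(torch, "xpu") and torch.xpu.is_available()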

I am raising this PR as WIP and Draft to facilitate further discussion around enabling the XPU backend in Hugging Face and to be able to communicate observed XPU issues back to PyTorch.

See: pytorch/pytorch#114842
Requires: huggingface/accelerate#2825

cc: @muellerzr, @EikanWang, @jgong5, @kding1, @sywangyi

@muellerzr (Contributor) left a comment

Great! This PR should be merged in tandem with the accelerate one here: huggingface/accelerate#2825

@ydshieh (Collaborator) commented Jun 6, 2024

Happy to see this addition 🚀! Just wondering if we should now be careful with names like PyTorch XPU vs. IPEX XPU. (And sorry if this doesn't make sense 😅.)

@dvrogozh (Contributor, Author) commented Jun 7, 2024

I exercised this PR (plus huggingface/accelerate#2825, on which it depends) as much as I could across the IPEX-CPU, IPEX-XPU, PyTorch-XPU, and PyTorch-CPU scenarios, running some tests from accelerate and transformers as well as some transformers examples. All of them engage XPU when expected. I am promoting these PRs out of draft for full review. Let me know if there are any concerns or feedback that needs to be addressed.

@dvrogozh changed the title from "[WIP] xpu: support xpu backend from stock pytorch (>=2.4)" to "xpu: support xpu backend from stock pytorch (>=2.4)" on Jun 7, 2024
@dvrogozh marked this pull request as ready for review on June 7, 2024 17:04
@dvrogozh (Contributor, Author) commented

I added one more commit to enable some tests for the xpu backend.

@dvrogozh (Contributor, Author) commented

Applied python utils/check_copies.py --fix_and_overwrite to propagate the change in gpt2 to decision_transformer. This fixes the failure noted by CI; the test for the latter passes for the xpu backend.

@muellerzr (Contributor) left a comment

Thanks! Overall this looks fine to me, and it makes sense why we need to adjust models/decision_transformer/... (to get the ipex patches in).

The PR has been merged on the accelerate side, and overall this seems good to me. However: should we limit the accelerate version required for xpu support to the new accelerate version? (to come out next month)

cc @amyeroberts


@dvrogozh (Contributor, Author) commented

however: should we limit the accelerate version required for the xpu support to the new accelerate version? (to come out next month)

@muellerzr: By bumping the accelerate version to 0.32.0?

"accelerate>=0.21.0",

@muellerzr (Contributor) commented Jun 13, 2024

By bumping accelerate version to 0.32.0?

We most certainly shouldn't do that :)

Actually, I think we're fairly okay, as accelerate will do a passthrough and IIUC this PR doesn't break old behavior, correct? (Basically, per my understanding if users run an old accelerate version nothing will break, right?)

The question is if we should have a flag for a minimum accelerate version if they are on the xpu branch/logic

@dvrogozh (Contributor, Author) commented

Actually, I think we're fairly okay, as accelerate will do a passthrough and IIUC this PR doesn't break old behavior, correct?

Yes, I think so. As long as users stay within the previous usage (with IPEX), nothing should change for them and it stays compatible.

The question is if we should have a flag for a minimum accelerate version if they are on the xpu branch/logic

New accelerate is indeed required on the xpu branch, otherwise there will be a runtime error. So a check will be useful; I will add one.

@amyeroberts (Collaborator) left a comment

Thanks for adding this!

Just one question on the availability of torch.xpu, which we might have to take care of in the testing utils.

import torch

if is_ipex_available():
    import intel_extension_for_pytorch  # noqa: F401
elif not is_accelerate_available("0.33.0.dev"):
@dvrogozh (Contributor, Author) commented:

@muellerzr: I added this check. Since Python version comparison evaluates 0.32.0.dev0 >= 0.32.0 as False, I compared against 0.32.0.dev, and I am not sure this is the correct way. Please advise.
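
For reference, this behaviour can be reproduced with the packaging library, which such version helpers typically rely on; the snippet below is purely illustrative:

from packaging import version

print(version.parse("0.32.0.dev0") >= version.parse("0.32.0"))      # False: dev releases sort before the final release
print(version.parse("0.32.0.dev0") >= version.parse("0.32.0.dev"))  # True: ".dev" normalizes to ".dev0"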

Also, I can't raise an exception here since is_torch_xpu_available is on a generic path and it would fail non-xpu cases. And without an exception, the error the end user gets looks quite similar to what they would get when running with the wrong accelerate version. Do you have an idea where to raise an exception notifying the user that the accelerate version is wrong?

@dvrogozh (Contributor, Author) commented:

OK, so I removed the accelerate check from is_torch_xpu_available(), since I thought this function already does its job and there is no need to add the check there. And I added a check that raises an exception here:

elif is_torch_xpu_available():
    device = torch.device("xpu:0")
    torch.xpu.set_device(device)
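
A plausible, self-contained shape for that guard is sketched below; the helper name _require_xpu_with_accelerate is made up, and the exact version string and error messages in the merged code may differ:

import torch
from transformers.utils import is_accelerate_available, is_torch_xpu_available


def _require_xpu_with_accelerate() -> torch.device:
    # Fail early with a clear message instead of an obscure runtime error later.
    if not is_accelerate_available("0.32.0.dev"):
        raise ImportError("XPU support via stock PyTorch requires accelerate >= 0.32.0")
    if not is_torch_xpu_available():
        raise RuntimeError("No XPU device is available")
    device = torch.device("xpu:0")
    torch.xpu.set_device(device)
    return device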

@amyeroberts (Collaborator) left a comment

LGTM - thanks for enabling this!

@muellerzr (Contributor) left a comment

Thanks! LG2M as well :) (We can revisit the 0.32.0.dev after 0.32.0 is out, I'll keep it in my notes)

@ydshieh (Collaborator) commented Jun 14, 2024

Hi @dvrogozh

Thank you for this support. As mentioned by @faaany and my comment, it would be better not to include things like

BACKEND_MANUAL_SEED["xpu"]
BACKEND_EMPTY_CACHE["xpu"]

in this PR (so far), and let the user define them when they need to use another device.

Happy to revise that design in a separate PR.
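
For context, a rough sketch of what those dispatch tables look like; the real definitions live in transformers.testing_utils and the contents here are abridged and partly assumed:

import torch

# Map a device name to the backend-specific hook used by the test helpers.
BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None}

The suggestion above is to leave the "xpu" entries out for now and let users register them through a device spec file, as shown in the next comment.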

@dvrogozh (Contributor, Author) commented

As mentioned by @faaany and #31402 (comment), it would be better not to include things like

@ydshieh: removed. I used TRANSFORMERS_TEST_DEVICE_SPEC=spec.py on my side:

import torch

DEVICE_NAME = 'xpu'

MANUAL_SEED_FN = torch.xpu.manual_seed
EMPTY_CACHE_FN = torch.xpu.empty_cache
DEVICE_COUNT_FN = torch.xpu.device_count
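
(For reference, such a spec file is picked up by pointing the test runner at it, e.g. TRANSFORMERS_TEST_DEVICE_SPEC=spec.py python -m pytest tests/<some test>; the test selection here is illustrative.)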

@dvrogozh (Contributor, Author) commented

FYI, I think the CI failure is unrelated to this PR. Does it need re-triggering?

FAILED examples/tensorflow/test_tensorflow_examples.py::ExamplesTests::test_run_image_classification - ValueError: The repository for hf-internal-testing/cats_vs_dogs_sample contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/hf-internal-testing/cats_vs_dogs_sample.

@ydshieh (Collaborator) commented Jun 14, 2024

The failure about

contains custom code which must be executed to correctly load the dataset

is not related to this PR.

You can rebase on main to include #31407, which will make that failure disappear.

@ydshieh (Collaborator) left a comment

Thanks again. I think everything runs smoothly on your side, right? Will merge once I get a confirmation from you 💯!

@dvrogozh (Contributor, Author) commented

And you can rebase on main to include #31407 that will make that failure disappear

Hm. The code is already on top of the latest master and includes #31407.

I think everything runs smooth on your side, right?

For the xpu backend with spec.py? Yes. And I ran the non-xpu stuff as much as I could locally.

@ydshieh (Collaborator) commented Jun 14, 2024

OK, I will check CI. Thank you again for contributing 💯 !

dvrogozh added 2 commits June 14, 2024 21:00
Fixes: huggingface#31237

XPU backend is available in the stock PyTorch starting from
version 2.4, see [1]. This commit extends huggingface transformers
to support XPU from both IPEX and the stock pytorch. IPEX is being
tried first.

See: pytorch/pytorch#114842
Requires: huggingface/accelerate#2825
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Note that running xpu tests requires TRANSFORMERS_TEST_DEVICE_SPEC=spec.py
passed to the test runner:

  import torch
  DEVICE_NAME = 'xpu'
  MANUAL_SEED_FN = torch.xpu.manual_seed
  EMPTY_CACHE_FN = torch.xpu.empty_cache
  DEVICE_COUNT_FN = torch.xpu.device_count

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
@ydshieh merged commit eed9ed6 into huggingface:main on Jun 14, 2024
23 checks passed
@dvrogozh (Contributor, Author) commented

Thank you for the merge. If all goes well, it should land in transformers==4.42.0, as far as I can tell.
