
BitsandBytes Enablement on ROCm #1207

Conversation

pnunna93
Contributor

Overview

This PR introduces bitsandbytes enablement on ROCm for AMD GPUs. It adds hipified versions of the CUDA kernels and ops, allowing bitsandbytes API function calls to be routed to optimized HIP kernels on AMD GPUs.

In the multi-backend-refactor branch, there is a proposal to separate the various backends to support multiple GPUs/accelerators. The core of bitsandbytes is built on top of PyTorch and selects the API implementation for a given GPU/accelerator based on the device_type of the tensor, as highlighted here. PyTorch on ROCm reports AMD GPUs under the cuda device type, so applications run seamlessly without any code changes. Hence, this PR updates the cuda backend in bitsandbytes to enable its functionality on ROCm for AMD GPUs. It also adds support for ROCm in the CMake build and enables key bitsandbytes functionality on AMD GPUs.
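To illustrate the dispatch described above, here is a minimal, hypothetical sketch (the function name is illustrative, not the actual bitsandbytes API) of routing on the tensor's device type; because ROCm builds of PyTorch expose AMD GPUs as "cuda" devices, a single branch covers both vendors:

```python
# Hypothetical sketch of device-type dispatch; names are illustrative,
# not the real bitsandbytes internals.
def select_backend(device_type: str) -> str:
    # ROCm builds of PyTorch report AMD GPUs as device type "cuda",
    # so this one branch serves both NVIDIA (CUDA kernels) and
    # AMD (hipified kernels) without application-code changes.
    if device_type == "cuda":
        return "cuda"
    return "cpu"

print(select_backend("cuda"))  # cuda (NVIDIA or AMD GPU tensor)
print(select_backend("cpu"))   # cpu  (fallback path)
```

This is why the PR can reuse the existing cuda backend rather than adding a parallel "rocm" backend: the dispatch key is identical on both platforms.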

Summary of Changes

  • Updated CUDA backend to work seamlessly on ROCm

  • Integrated HIP environment into bitsandbytes through hipified versions of CUDA kernels and ops

  • CMake build updates for ROCm

  • Enabled key features in the bitsandbytes functional and autograd APIs

Impact

This PR enables building and supporting bitsandbytes on ROCm for AMD GPUs. Bitsandbytes users can port applications smoothly to AMD GPUs, as it requires minimal changes on their end. In addition, it ensures that the ROCm changes do not affect the CUDA environment, leaving existing CUDA users unaffected.

CC: @Titus-von-Koeller @matthewdouglas @arlo-phoenix

@Titus-von-Koeller Titus-von-Koeller self-assigned this May 13, 2024
Comment on lines 447 to 448
if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
Member

Is there a short explanation we can add here to explain why this is the default, and likewise below why 64 is not supported?

Contributor Author

It's because of the warp-size difference between AMD and NVIDIA GPUs. I have added comments - 410f499
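For context, the default-blocksize logic under discussion can be sketched as follows. The doubling reflects that NVIDIA warps are 32 threads wide while AMD GPUs typically execute 64-wide wavefronts, so the smallest quantization blocksize that maps cleanly onto a wavefront doubles on ROCm (this sketch mirrors the snippet quoted above; `HIP_ENVIRONMENT` is detected at import time in bitsandbytes):

```python
# Sketch of the blocksize default from the review comment above.
# NVIDIA: 32-thread warps -> minimum blocksize 64.
# AMD:    64-wide wavefronts -> minimum blocksize 128.
HIP_ENVIRONMENT = False  # in bitsandbytes this is set when built for ROCm

def default_blocksize(blocksize=None, hip=HIP_ENVIRONMENT):
    if blocksize is None:
        blocksize = 64 if not hip else 128
    return blocksize

print(default_blocksize())          # 64  (CUDA default)
print(default_blocksize(hip=True))  # 128 (ROCm default)
```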

@tpimh

tpimh commented May 15, 2024

Issue #149: can Intel Arc GPUs be supported in a similar manner?

@matthewdouglas
Member

> Issue #149: can Intel Arc GPUs be supported in a similar manner?

@tpimh There's separate work in progress for Intel. So far there's been work on CPU with IPEX (#1178, #1206) and separately a SYCL port: #747.

@tpimh

tpimh commented May 16, 2024

Thanks! This looks promising.

I will try on both AMD and Intel Arc.

@Titus-von-Koeller
Collaborator

Dear @pnunna93,

thanks to you and your team for the amazing work. We're super excited about this and I'm very happy with what I'm seeing at an initial superficial review.

It would be great to have the AMD runner available relatively soon; otherwise it remains quite messy and work-intensive to keep track of the correctness of the various backend implementations. Please let me know what I can do to help and I'll make sure to pull the right strings.

Regarding the review, as communicated in Slack, I first have to focus on wrapping up my deep dive evaluating tensor-driven dispatch via integration with the PyTorch dispatcher through the torch.library APIs. I don't see any reason not to merge your PR, but I need to take another thorough look. I also think it would be helpful for everyone to have clarity on the backend abstraction / dispatch mechanism as soon as possible, so I am prioritizing that; everyone can then refactor their code to account for it.

In that context, one important question came up:

Our paged optimizers use CUDA unified memory, as described in detail here.

Is that feature available on ROCm devices in one way or another? This would be quite important to understand for my analysis, as the handling of unified memory in relation to PyTorch is one of my last open questions. It's quite a special case: it's a cornerstone of preventing OOMs in low-resource environments -- a key feature for our user group -- and is not implemented or accounted for in PyTorch. We therefore use that feature directly through the CUDA APIs; the underlying CUDA function is cudaMemPrefetchAsync, AFAICT.
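For reference, HIP does document direct analogues of these entry points (hipMallocManaged and hipMemPrefetchAsync in the HIP runtime API), though whether they fully cover the paged-optimizer use case on a given AMD GPU would need confirmation from the ROCm side. A minimal, hedged sketch that merely probes for the HIP runtime and the prefetch symbol from Python (library name and availability are assumptions about a ROCm install):

```python
import ctypes

# Hedged sketch: HIP mirrors CUDA's unified-memory API, so the
# cudaMemPrefetchAsync call used by the paged optimizers would map to
# hipMemPrefetchAsync (and cudaMallocManaged to hipMallocManaged).
def load_hip_runtime():
    """Return the HIP runtime library if ROCm is installed, else None."""
    for name in ("libamdhip64.so", "libamdhip64.so.6"):
        try:
            return ctypes.CDLL(name)
        except OSError:
            continue
    return None

hip = load_hip_runtime()
if hip is not None:
    # hipMemPrefetchAsync(ptr, count, device, stream) -> hipError_t
    prefetch = getattr(hip, "hipMemPrefetchAsync", None)
    print("hipMemPrefetchAsync available:", prefetch is not None)
else:
    print("ROCm runtime not found; running on a non-ROCm system")
```

On a system without ROCm this simply reports that the runtime is absent; it performs no allocation or prefetch.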

Thanks 🤗 and congrats on the great work in this PR, we're super excited about this ❤️

@Titus-von-Koeller
Collaborator

Dear @pnunna93 et al,

Unfortunately we're (mostly me alone) quite resource-constrained and humbled by the workload associated with the multi-backend-refactor. I just talked with my colleague @younesbelkada about how best to handle the next steps.

We both took a look at this PR and the one from Intel, and at first glance everything looks really good. At this time, neither Younes nor I is in a position to give detailed feedback, and I need to focus on concretizing the path forward for integrating with the PyTorch dispatcher (tensor-driven dispatch, as requested) through the torch.library Python-level APIs. After extensive research and yesterday's consultation with three PyTorch devs at Meta who are experts on the topic, I need to focus on making this new input concrete.

However, for the purpose of iterative progress (as agreed in our prior conversations), we've decided to already go ahead and merge both the open Intel and AMD branches into multi-backend-refactor, where interested parties can then compile from source and give the new functionality (we're so excited and grateful about this!) a thorough testing.

Once we've made some progress on the torch.library-based refactor, I'll next focus on enabling nightly releases for that branch as well. We're also looking forward to your feedback on this torch.library / tensor-driven dispatch topic once the code is there as a basis for discussion (and for refactoring the backend-specific code towards that new target, once we all agree this is the right path).

Among other things, there's also been extensive ongoing work in the background on things like moving BNB to a new independent/non-profit GitHub org, under the umbrella of Hugging Face and with the support of their infra team for managing the complexities of the CI/CD backend and runners. Also, we're working to make GitHub runners for the different hardware platforms a reality (thanks for your help on that!).

Thanks again for the good work and active collaboration! ❤️ 🚀

@Titus-von-Koeller Titus-von-Koeller merged commit eb3b816 into bitsandbytes-foundation:multi-backend-refactor May 24, 2024
1 of 2 checks passed
@Titus-von-Koeller
Collaborator

Titus-von-Koeller commented May 24, 2024

P.S. Also see this: README: asking for help from volunteer alpha testers

Let us know if you have further thoughts on this and how you think it's best to communicate about this.

@pnunna93 pnunna93 mentioned this pull request Jul 8, 2024