Skip to content

Commit

Permalink
MUSA: support ARM64 and enable dp4a .etc (#11843)
Browse files Browse the repository at this point in the history
* MUSA:  support ARM64 and enable __dp4a .etc

* fix cross entropy loss op for musa

* update

* add cc info log for musa

* add comment for the MUSA .cc calculation block

---------

Co-authored-by: Bodhi Hu <huaishun.hu@mthreads.com>
  • Loading branch information
BodhiHu and Bodhi Hu authored Feb 21, 2025
1 parent ee02ad0 commit 0b3863f
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ ifdef GGML_MUSA
CXX := $(MUSA_PATH)/bin/clang++
MCC := $(CCACHE) $(MUSA_PATH)/bin/mcc

MUSAFLAGS = -x musa -mtgpu
MUSAFLAGS = -fsigned-char -x musa -mtgpu
MUSAFLAGS += $(foreach arch,$(subst ;, ,$(MUSA_ARCHITECTURES)),--cuda-gpu-arch=mp_$(arch))

ifdef GGML_CUDA_FORCE_MMQ
Expand Down
8 changes: 8 additions & 0 deletions docs/build.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,14 @@ This provides GPU acceleration using the MUSA cores of your Moore Threads MTT GP
cmake --build build --config Release
```

For static build:

```bash
cmake -B build -DGGML_MUSA=ON \
-DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
cmake --build build --config Release
```

The environment variable [`MUSA_VISIBLE_DEVICES`](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) can be used to specify which GPU(s) will be used.

The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted.
Expand Down
6 changes: 3 additions & 3 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -411,13 +411,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i

#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)

#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
const int8_t * a8 = (const int8_t *) &a;
const int8_t * b8 = (const int8_t *) &b;
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)

#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}
Expand Down
10 changes: 5 additions & 5 deletions ggml/src/ggml-cuda/cross-entropy-loss.cu
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);

if (nbytes_shared <= smpbo) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
} else {
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
Expand Down Expand Up @@ -175,13 +175,13 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

if (nbytes_shared <= smpbo) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
} else {
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
Expand Down
10 changes: 6 additions & 4 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,12 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
device_vmm ? "yes" : "no", prop.warpSize);
#elif defined(GGML_USE_MUSA)
// TODO: refine the .cc to reflect MUSA's actual CC capabilities
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
#else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
Expand Down Expand Up @@ -1782,9 +1788,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
}
}
#else
#ifdef GGML_USE_MUSA
GGML_ASSERT(false);
#else // !GGML_USE_MUSA
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
Expand Down Expand Up @@ -1827,7 +1830,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
#endif // GGML_USE_MUSA
#endif

if (dst->op_params[0] == GGML_PREC_DEFAULT) {
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include <arm_sve.h>
#endif // __ARM_FEATURE_SVE

#if defined(__ARM_NEON) && !defined(__CUDACC__)
#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
//
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-musa/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ if (MUSAToolkit_FOUND)

set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
foreach(SOURCE ${GGML_SOURCES_MUSA})
set(COMPILE_FLAGS "-x musa -mtgpu")
set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu")
foreach(ARCH ${MUSA_ARCHITECTURES})
set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
endforeach()
Expand Down

0 comments on commit 0b3863f

Please sign in to comment.