Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MUSA: support ARM64 and enable dp4a .etc #11843

Merged
merged 5 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ ifdef GGML_MUSA
CXX := $(MUSA_PATH)/bin/clang++
MCC := $(CCACHE) $(MUSA_PATH)/bin/mcc

MUSAFLAGS = -x musa -mtgpu
MUSAFLAGS = -fsigned-char -x musa -mtgpu
MUSAFLAGS += $(foreach arch,$(subst ;, ,$(MUSA_ARCHITECTURES)),--cuda-gpu-arch=mp_$(arch))

ifdef GGML_CUDA_FORCE_MMQ
Expand Down
7 changes: 7 additions & 0 deletions docs/build.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ This provides GPU acceleration using the MUSA cores of your Moore Threads MTT GP
cmake -B build -DGGML_MUSA=ON
cmake --build build --config Release
```
- For static build:

```bash
cmake -B build -DGGML_MUSA=ON \
-DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON
cmake --build build --config Release
```

The environment variable [`MUSA_VISIBLE_DEVICES`](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) can be used to specify which GPU(s) will be used.

Expand Down
6 changes: 3 additions & 3 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -411,13 +411,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i

#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)

#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
const int8_t * a8 = (const int8_t *) &a;
const int8_t * b8 = (const int8_t *) &b;
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)

#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/cross-entropy-loss.cu
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
device_vmm ? "yes" : "no", prop.warpSize);
#elif defined(GGML_USE_MUSA)
// NOTE: MUSA will reserve some shared mem, and 24B should be enough
info.devices[id].smpbo = prop.sharedMemPerBlockOptin - 24;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
#else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
Expand Down Expand Up @@ -1782,9 +1786,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
}
}
#else
#ifdef GGML_USE_MUSA
GGML_ASSERT(false);
#else // !GGML_USE_MUSA
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
Expand Down Expand Up @@ -1827,7 +1828,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
#endif // GGML_USE_MUSA
#endif

if (dst->op_params[0] == GGML_PREC_DEFAULT) {
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include <arm_sve.h>
#endif // __ARM_FEATURE_SVE

#if defined(__ARM_NEON) && !defined(__CUDACC__)
#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
//
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-musa/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ if (MUSAToolkit_FOUND)

set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
foreach(SOURCE ${GGML_SOURCES_MUSA})
set(COMPILE_FLAGS "-x musa -mtgpu")
set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu")
foreach(ARCH ${MUSA_ARCHITECTURES})
set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
endforeach()
Expand Down