Skip to content

Commit

Permalink
Initial CMake support
Browse files Browse the repository at this point in the history
- For initial win32 support
- Change pythonInterface file extension to cpp because it is C++
- Requires external dependency pthread-win32 for Windows builds
  • Loading branch information
niclimcy committed Mar 25, 2023
1 parent 49a0425 commit 4c1d308
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 6 deletions.
78 changes: 78 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
cmake_minimum_required(VERSION 3.26.0)

# Default to the architectures supported by CUDA 12.x for now.
# This has to be set before project(). It is only a default: architectures
# chosen by the user (e.g. -DCMAKE_CUDA_ARCHITECTURES=86) are left untouched.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90")
endif()

project(bitsandbytes LANGUAGES CXX CUDA)

# Provides the CUDA::cudart / CUDA::cublas / ... imported targets.
find_package(CUDAToolkit REQUIRED)

# Global language settings. *_STANDARD_REQUIRED turns a missing C++14 mode into
# a configure error instead of a silent fallback to an older standard.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

# Defined for every source file compiled from this directory.
add_compile_definitions(BUILD_CUDA)

if(WIN32)
  # Silence nvcc diagnostic number 177 to keep the build output quiet.
  string(APPEND CMAKE_CUDA_FLAGS " -diag-suppress=177")

  # Export every symbol from the DLL so no explicit export annotations
  # are needed in the sources.
  set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()

# MSVC-specific workarounds
if(MSVC)
  # Keep the listed default runtime libraries out of the link.
  # NOTE(review): CMAKE_EXE_LINKER_FLAGS only applies when linking
  # executables; the only target in this file is a SHARED library, so
  # CMAKE_SHARED_LINKER_FLAGS may have been intended -- confirm.
  string(APPEND CMAKE_EXE_LINKER_FLAGS
    " /NODEFAULTLIB:msvcprtd /NODEFAULTLIB:MSVCRTD /NODEFAULTLIB:LIBCMT")

  # Enable AVX2 code generation for C++ sources, both in the flags shared by
  # all configurations and in the Release-specific flags.
  string(APPEND CMAKE_CXX_FLAGS " /arch:AVX2")
  string(APPEND CMAKE_CXX_FLAGS_RELEASE " /arch:AVX2")
endif()

# pthread does not exist on Windows: fetch and build pthread-win32 as an
# external project and make its headers and libraries visible to the build.
if(WIN32)
  include(ExternalProject)

  # Install prefix for externally built dependencies. Defaults to a directory
  # inside the build tree; without a value the prefix would be empty and the
  # include/lib paths below would collapse to "/include" and "/lib".
  if(NOT EXTERNAL_INSTALL_LOCATION)
    set(EXTERNAL_INSTALL_LOCATION "${CMAKE_BINARY_DIR}/external")
  endif()

  ExternalProject_Add(pthread-win32
    GIT_REPOSITORY https://github.com/GerHobbelt/pthread-win32
    CMAKE_ARGS "-DCMAKE_INSTALL_PREFIX=${EXTERNAL_INSTALL_LOCATION}"
  )

  # Directory-scoped because they must be in effect before the library target
  # is created further down.
  include_directories("${EXTERNAL_INSTALL_LOCATION}/include")
  link_directories("${EXTERNAL_INSTALL_LOCATION}/lib")
endif()

# The shared library target, built from the CUDA and C++ sources in csrc/
add_library(bitsandbytes SHARED)
target_sources(bitsandbytes PRIVATE
  csrc/ops.cu
  csrc/kernels.cu
  csrc/common.cpp
  csrc/cpu_ops.cpp
  csrc/pythonInterface.cpp
)

# Header search paths: the CUDA toolkit headers plus the project's own
# csrc/ and include/ directories.
# CUDAToolkit_INCLUDE_DIRS is provided by find_package(CUDAToolkit); the legacy
# FindCUDA variable CUDA_TOOLKIT_ROOT_DIR is not set by that module and would
# expand to a bogus "/include" entry.
target_include_directories(bitsandbytes PUBLIC
  ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
  ${CMAKE_CURRENT_SOURCE_DIR}/csrc
  ${CMAKE_CURRENT_SOURCE_DIR}/include
  ${CUDAToolkit_INCLUDE_DIRS})

# Options passed to NVCC. The generator expression restricts them to CUDA
# sources; it is quoted so the ';'-separated option list stays inside it.
target_compile_options(bitsandbytes PRIVATE
  "$<$<COMPILE_LANGUAGE:CUDA>:--use_fast_math;-Xptxas=-v>"
)

# Relocatable device code is requested through the target property rather than
# a hand-written -dc flag, so CMake adds the matching compile option and the
# device-link step on its own.
set_target_properties(bitsandbytes
  PROPERTIES
    CUDA_SEPARABLE_COMPILATION ON
)

# CUDA libraries needed on every platform. PRIVATE: they are implementation
# details of the shared library, not part of a link interface for consumers.
target_link_libraries(bitsandbytes PRIVATE
  CUDA::cudart
  CUDA::cublas
  CUDA::cublasLt
  CUDA::cusparse
)

if(WIN32)
  # pthread does not exist on Windows: make sure the external pthread-win32
  # project is built first, then link against the library it provides.
  add_dependencies(bitsandbytes pthread-win32)
  target_link_libraries(bitsandbytes PRIVATE pthreadVC3)
endif()

# Name the produced library file after the CUDA version it targets
set_property(TARGET bitsandbytes PROPERTY OUTPUT_NAME "bitsandbytes_cuda120")
12 changes: 6 additions & 6 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2664,12 +2664,12 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

// Explicit instantiations of kspmm_coo_very_sparse_naive with dequant_stats
// declared as a plain `float *`.
// NOTE(review): the diff markers are lost in this view; these six lines are
// presumably the ones removed by the commit (6 deletions) -- confirm.
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
// The same six instantiations with dequant_stats declared as
// `float * __restrict__ const`.
// NOTE(review): presumably the lines added by the commit (6 additions) -- confirm.
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);

template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
Expand Down
2 changes: 2 additions & 0 deletions csrc/ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

#include <stdio.h>
#include <iostream>
#ifndef _MSC_VER
#include <unistd.h>
#endif
#include <assert.h>

#include <cuda_runtime_api.h>
Expand Down
File renamed without changes.
6 changes: 6 additions & 0 deletions include/SIMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ struct IVec;
template <InstrSet I, class T>
struct FVec1;

// Generic InstrFloatTraits: for any instruction set and element type T,
// the associated vector type vec_t is T itself.
template <InstrSet Instr, class T>
struct InstrFloatTraits
{
    using vec_t = T;
};

template <> struct InstrIntTraits<SSE>
{
typedef __m128i vec_t;
Expand Down

0 comments on commit 4c1d308

Please sign in to comment.