diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6fd2570b8..ec25be2a5 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -30,7 +30,7 @@ def prod(iterable): if COMPILED_WITH_CUDA: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {} - str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16) #, lib.cadam32bit_grad_bf16) + str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16) str2optimizer32bit["momentum"] = ( lib.cmomentum32bit_grad_32, lib.cmomentum32bit_grad_16, @@ -39,7 +39,7 @@ def prod(iterable): lib.crmsprop32bit_grad_32, lib.crmsprop32bit_grad_16, ) - str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16) #, lib.clion32bit_grad_bf16) + str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16) str2optimizer32bit["adagrad"] = ( lib.cadagrad32bit_grad_32, lib.cadagrad32bit_grad_16, @@ -75,7 +75,7 @@ def prod(iterable): str2optimizer8bit_blockwise["adam"] = ( lib.cadam_8bit_blockwise_grad_fp32, lib.cadam_8bit_blockwise_grad_fp16, - #lib.cadam_8bit_blockwise_grad_bf16, + lib.cadam_8bit_blockwise_grad_bf16, ) str2optimizer8bit_blockwise["momentum"] = ( lib.cmomentum_8bit_blockwise_grad_fp32, @@ -88,7 +88,7 @@ def prod(iterable): str2optimizer8bit_blockwise["lion"] = ( lib.clion_8bit_blockwise_grad_fp32, lib.clion_8bit_blockwise_grad_fp16, - #lib.clion_8bit_blockwise_grad_bf16, + lib.clion_8bit_blockwise_grad_bf16, ) str2optimizer8bit_blockwise["adagrad"] = ( lib.cadagrad_8bit_blockwise_grad_fp32, diff --git a/csrc/kernels.hip b/csrc/kernels.hip index edcde6306..458f7f1c0 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -3836,7 +3836,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) MAKE_PreconditionOptimizer32bit1State(LION, half) MAKE_PreconditionOptimizer32bit1State(LION, float) -//MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) @@ -3850,7 +3850,7 @@ MAKE_Optimizer32bit1State(RMSPROP, half) MAKE_Optimizer32bit1State(RMSPROP, float) MAKE_Optimizer32bit1State(LION, half) MAKE_Optimizer32bit1State(LION, float) -//MAKE_Optimizer32bit1State(LION, hip_bfloat16) +MAKE_Optimizer32bit1State(LION, hip_bfloat16) MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, float) @@ -3862,16 +3862,15 @@ template __global__ void kPreconditionOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -/* template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -*/ + #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ @@ -4040,7 +4039,7 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* g, gtype* p, \ MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) -//MAKE_optimizer32bit(ADAM, hip_bfloat16) +MAKE_optimizer32bit(ADAM, hip_bfloat16) MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(LION, half) MAKE_optimizer32bit(LION, float) -//MAKE_optimizer32bit(LION, hip_bfloat16) +MAKE_optimizer32bit(LION, hip_bfloat16) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) @@ -1009,11 +1009,11 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); MAKE_optimizerStatic8bitBlockwise(half, LION); MAKE_optimizerStatic8bitBlockwise(float, LION); -//MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); -//MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index ba551dcc3..c06da38b6 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -60,12 +60,12 @@ MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, half, 16) MAKE_FUNC32(adam, ADAM, float, fp32) MAKE_FUNC32(adam, ADAM, half, fp16) -//MAKE_FUNC32(adam, ADAM, hip_bfloat16, bf16) +MAKE_FUNC32(adam, ADAM, hip_bfloat16, bf16) MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) MAKE_FUNC32(lion, LION, float, fp32) MAKE_FUNC32(lion, LION, half, fp16) -//MAKE_FUNC32(lion, LION, hip_bfloat16, bf16) +MAKE_FUNC32(lion, LION, hip_bfloat16, bf16) MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, half, 16) @@ -105,10 +105,10 @@ MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16) MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32) MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) -//MAKE_BLOCKWISE8(adam, ADAM, hip_bfloat16, bf16) +MAKE_BLOCKWISE8(adam, ADAM, hip_bfloat16, bf16) MAKE_BLOCKWISE8(lion, LION, half, fp16) MAKE_BLOCKWISE8(lion, LION, float, fp32) -//MAKE_BLOCKWISE8(lion, LION, hip_bfloat16, bf16) +MAKE_BLOCKWISE8(lion, LION, hip_bfloat16, bf16) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } @@ -272,14 +272,14 @@ extern "C" MAKE_CFUNC32(adam, float, fp32) MAKE_CFUNC32(adam, half, fp16) - //MAKE_CFUNC32(adam, hip_bfloat16, bf16) + MAKE_CFUNC32(adam, hip_bfloat16, bf16) MAKE_CFUNC32(momentum, float, 32) MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, half, 16) MAKE_CFUNC32(lion, float, fp32) MAKE_CFUNC32(lion, half, fp16) - //MAKE_CFUNC32(lion, hip_bfloat16, bf16) + MAKE_CFUNC32(lion, hip_bfloat16, bf16) MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, half, 16) @@ -319,10 +319,10 @@ extern "C" MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) - //MAKE_CBLOCKWISE8(adam, ADAM, hip_bfloat16, bf16) + MAKE_CBLOCKWISE8(adam, ADAM, hip_bfloat16, bf16) MAKE_CBLOCKWISE8(lion, LION, half, fp16) MAKE_CBLOCKWISE8(lion, LION, float, fp32) - //MAKE_CBLOCKWISE8(lion, LION, hip_bfloat16, bf16) + MAKE_CBLOCKWISE8(lion, LION, hip_bfloat16, bf16) void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } diff --git a/tests/test_optim.py b/tests/test_optim.py index 2724436e5..c373a4f14 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -110,7 +110,6 @@ def rm_path(path): optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion'] values = list(product(dim1, dim2, gtype, optimizer_names)) names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer32bit(dim1, dim2, gtype, optim_name): if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip() @@ -253,7 +252,6 @@ def test_global_config(dim1, dim2, gtype): ] -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer8bit(dim1, dim2, gtype, optim_name): if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()