update changes #9

Merged (2 commits) on Feb 23, 2024
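This PR enables bfloat16 (bf16) support for the Adam and Lion optimizers in the ROCm/HIP port of bitsandbytes. It uncomments the previously disabled hip_bfloat16 kernel and wrapper instantiations (32-bit and 8-bit blockwise), registers the bf16 entry points in the Python dispatch tables, and re-enables the two optimizer tests that were skipped on ROCm.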
8 changes: 4 additions & 4 deletions bitsandbytes/functional.py
@@ -30,7 +30,7 @@ def prod(iterable):
 if COMPILED_WITH_CUDA:
     """C FUNCTIONS FOR OPTIMIZERS"""
     str2optimizer32bit = {}
-    str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16) #, lib.cadam32bit_grad_bf16)
+    str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16)
     str2optimizer32bit["momentum"] = (
         lib.cmomentum32bit_grad_32,
         lib.cmomentum32bit_grad_16,
@@ -39,7 +39,7 @@ def prod(iterable):
         lib.crmsprop32bit_grad_32,
         lib.crmsprop32bit_grad_16,
     )
-    str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16) #, lib.clion32bit_grad_bf16)
+    str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16)
     str2optimizer32bit["adagrad"] = (
         lib.cadagrad32bit_grad_32,
         lib.cadagrad32bit_grad_16,
@@ -75,7 +75,7 @@ def prod(iterable):
     str2optimizer8bit_blockwise["adam"] = (
         lib.cadam_8bit_blockwise_grad_fp32,
         lib.cadam_8bit_blockwise_grad_fp16,
-        #lib.cadam_8bit_blockwise_grad_bf16,
+        lib.cadam_8bit_blockwise_grad_bf16,
     )
     str2optimizer8bit_blockwise["momentum"] = (
         lib.cmomentum_8bit_blockwise_grad_fp32,
@@ -88,7 +88,7 @@ def prod(iterable):
     str2optimizer8bit_blockwise["lion"] = (
         lib.clion_8bit_blockwise_grad_fp32,
         lib.clion_8bit_blockwise_grad_fp16,
-        #lib.clion_8bit_blockwise_grad_bf16,
+        lib.clion_8bit_blockwise_grad_bf16,
     )
     str2optimizer8bit_blockwise["adagrad"] = (
         lib.cadagrad_8bit_blockwise_grad_fp32,
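For context, here is a minimal sketch (not code from this PR) of how a bf16 entry in these dispatch tables gets picked at runtime. The helper name is hypothetical; the tuple layout (fp32, fp16, optional bf16) follows the assignments above.

```python
import torch

# Hypothetical helper, not part of this PR: select the C kernel for a given
# optimizer name and gradient dtype from the str2optimizer32bit table above.
def get_optimizer32bit_kernel(name: str, grad_dtype: torch.dtype):
    fns = str2optimizer32bit[name]   # (fp32, fp16) or (fp32, fp16, bf16)
    if grad_dtype == torch.float32:
        return fns[0]
    if grad_dtype == torch.float16:
        return fns[1]
    if grad_dtype == torch.bfloat16 and len(fns) == 3:
        return fns[2]                # bf16 entry point enabled by this PR
    raise ValueError(f"no 32-bit {name} kernel for dtype {grad_dtype}")
```

Before this change, a bf16 lookup for adam or lion would fall through, since those tuples only had two entries.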
13 changes: 6 additions & 7 deletions csrc/kernels.hip
@@ -3836,7 +3836,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
 MAKE_PreconditionOptimizer32bit1State(LION, half)
 MAKE_PreconditionOptimizer32bit1State(LION, float)
-//MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16)
+MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)

@@ -3850,7 +3850,7 @@ MAKE_Optimizer32bit1State(RMSPROP, half)
 MAKE_Optimizer32bit1State(RMSPROP, float)
 MAKE_Optimizer32bit1State(LION, half)
 MAKE_Optimizer32bit1State(LION, float)
-//MAKE_Optimizer32bit1State(LION, hip_bfloat16)
+MAKE_Optimizer32bit1State(LION, hip_bfloat16)
 MAKE_Optimizer32bit1State(ADAGRAD, half)
 MAKE_Optimizer32bit1State(ADAGRAD, float)

@@ -3862,16 +3862,15 @@ template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8

 MAKE_PreconditionOptimizer32bit2State(ADAM, float)
 MAKE_PreconditionOptimizer32bit2State(ADAM, half)
-//MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16)
+MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16)

 template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-/*
 template __global__ void kOptimizer32bit2State<hip_bfloat16, ADAM>(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-*/
+

 #define MAKE_PreconditionStatic8bit1State(oname, gtype) \
 template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
@@ -4040,7 +4039,7 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block

 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
-//MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, hip_bfloat16, 2048, 8)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, hip_bfloat16, 2048, 8)


 #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
@@ -4059,6 +4058,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
-//MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
8 changes: 4 additions & 4 deletions csrc/ops.hip
@@ -966,14 +966,14 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \

 MAKE_optimizer32bit(ADAM, half)
 MAKE_optimizer32bit(ADAM, float)
-//MAKE_optimizer32bit(ADAM, hip_bfloat16)
+MAKE_optimizer32bit(ADAM, hip_bfloat16)
 MAKE_optimizer32bit(MOMENTUM, half)
 MAKE_optimizer32bit(MOMENTUM, float)
 MAKE_optimizer32bit(RMSPROP, half)
 MAKE_optimizer32bit(RMSPROP, float)
 MAKE_optimizer32bit(LION, half)
 MAKE_optimizer32bit(LION, float)
-//MAKE_optimizer32bit(LION, hip_bfloat16)
+MAKE_optimizer32bit(LION, hip_bfloat16)
 MAKE_optimizer32bit(ADAGRAD, half)
 MAKE_optimizer32bit(ADAGRAD, float)

@@ -1009,11 +1009,11 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(half, LION);
 MAKE_optimizerStatic8bitBlockwise(float, LION);
-//MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION);
+MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION);
 MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);

 template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
 template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);

-//MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM);
+MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM);
16 changes: 8 additions & 8 deletions csrc/pythonInterface.c
@@ -60,12 +60,12 @@ MAKE_FUNC32(momentum, MOMENTUM, float, 32)
 MAKE_FUNC32(momentum, MOMENTUM, half, 16)
 MAKE_FUNC32(adam, ADAM, float, fp32)
 MAKE_FUNC32(adam, ADAM, half, fp16)
-//MAKE_FUNC32(adam, ADAM, hip_bfloat16, bf16)
+MAKE_FUNC32(adam, ADAM, hip_bfloat16, bf16)
 MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
 MAKE_FUNC32(lion, LION, float, fp32)
 MAKE_FUNC32(lion, LION, half, fp16)
-//MAKE_FUNC32(lion, LION, hip_bfloat16, bf16)
+MAKE_FUNC32(lion, LION, hip_bfloat16, bf16)
 MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
 MAKE_FUNC32(adagrad, ADAGRAD, half, 16)

@@ -105,10 +105,10 @@ MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
 MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
 MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
-//MAKE_BLOCKWISE8(adam, ADAM, hip_bfloat16, bf16)
+MAKE_BLOCKWISE8(adam, ADAM, hip_bfloat16, bf16)
 MAKE_BLOCKWISE8(lion, LION, half, fp16)
 MAKE_BLOCKWISE8(lion, LION, float, fp32)
-//MAKE_BLOCKWISE8(lion, LION, hip_bfloat16, bf16)
+MAKE_BLOCKWISE8(lion, LION, hip_bfloat16, bf16)


 void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
@@ -272,14 +272,14 @@ extern "C"

 MAKE_CFUNC32(adam, float, fp32)
 MAKE_CFUNC32(adam, half, fp16)
-//MAKE_CFUNC32(adam, hip_bfloat16, bf16)
+MAKE_CFUNC32(adam, hip_bfloat16, bf16)
 MAKE_CFUNC32(momentum, float, 32)
 MAKE_CFUNC32(momentum, half, 16)
 MAKE_CFUNC32(rmsprop, float, 32)
 MAKE_CFUNC32(rmsprop, half, 16)
 MAKE_CFUNC32(lion, float, fp32)
 MAKE_CFUNC32(lion, half, fp16)
-//MAKE_CFUNC32(lion, hip_bfloat16, bf16)
+MAKE_CFUNC32(lion, hip_bfloat16, bf16)
 MAKE_CFUNC32(adagrad, float, 32)
 MAKE_CFUNC32(adagrad, half, 16)

@@ -319,10 +319,10 @@ extern "C"
 MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
 MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
 MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
-//MAKE_CBLOCKWISE8(adam, ADAM, hip_bfloat16, bf16)
+MAKE_CBLOCKWISE8(adam, ADAM, hip_bfloat16, bf16)
 MAKE_CBLOCKWISE8(lion, LION, half, fp16)
 MAKE_CBLOCKWISE8(lion, LION, float, fp32)
-//MAKE_CBLOCKWISE8(lion, LION, hip_bfloat16, bf16)
+MAKE_CBLOCKWISE8(lion, LION, hip_bfloat16, bf16)

 void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
 void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
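The MAKE_CFUNC32 / MAKE_CBLOCKWISE8 macros emit the extern "C" symbols (e.g. cadam32bit_grad_bf16) that the dispatch tables in bitsandbytes/functional.py bind to. A rough, hypothetical illustration of that linkage follows; the library filename is an assumption, and the real loader in bitsandbytes resolves the path itself.

```python
import ctypes

# Assumed shared-library name for a ROCm build (hypothetical path).
lib = ctypes.cdll.LoadLibrary("libbitsandbytes_hip.so")

# These attribute lookups only succeed once the bf16 wrappers above are
# compiled in, which is exactly what uncommenting the macros enables.
adam32_bf16 = lib.cadam32bit_grad_bf16
lion32_bf16 = lib.clion32bit_grad_bf16
```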
2 changes: 0 additions & 2 deletions tests/test_optim.py
@@ -110,7 +110,6 @@ def rm_path(path):
 optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer32bit(dim1, dim2, gtype, optim_name):
     if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
@@ -253,7 +252,6 @@ def test_global_config(dim1, dim2, gtype):
 ]


-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
     if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
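With the skips removed, test_optimizer32bit and test_optimizer8bit once again cover the bf16 cases on ROCm. A hedged end-to-end sketch of the path these changes enable, using the public bitsandbytes API (shapes and hyperparameters are arbitrary):

```python
import torch
import bitsandbytes as bnb

# On ROCm builds of PyTorch, the "cuda" device string maps to HIP.
model = torch.nn.Linear(64, 64, device="cuda", dtype=torch.bfloat16)
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)  # blockwise 8-bit Adam

x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
opt.step()        # with this PR, dispatches to cadam_8bit_blockwise_grad_bf16
opt.zero_grad()
```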