Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize RWKV6 Operator Naming and Implement Multi-core CPU/ SYCL Acceleration #10133

Merged
merged 33 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f66c75a
rwkv6: rename to wkv6
uniartisan Nov 1, 2024
b4254c5
rwkv6: support avx2 avx512 armv8 armv9
uniartisan Nov 1, 2024
e198f7b
rwkv6: update cuda file name
uniartisan Nov 1, 2024
3f75f12
rwkv6: rename params
uniartisan Nov 1, 2024
2fc42b6
wkv on sycl
uniartisan Nov 2, 2024
bee1cec
sycl: add some ops
uniartisan Nov 2, 2024
1c58096
sycl: Enhance OP support judgment
uniartisan Nov 3, 2024
042c3e0
Merge branch 'ggerganov:master' into master
uniartisan Nov 3, 2024
811aa87
wkv6: drop armv9 and tranfer to GGML style
uniartisan Nov 3, 2024
4d26631
flake.lock: Update (#10146)
ggerganov Nov 3, 2024
b189630
metal : minor fixup in FA kernel (#10143)
ggerganov Nov 3, 2024
89812b1
ggml : move CPU backend to a separate file (#10144)
slaren Nov 3, 2024
8050d02
metal : fix minor string leaks (ggml/1004)
pminev Nov 1, 2024
eb5711c
cmake : make it possible linking ggml as external lib (ggml/1003)
ykhrustalev Nov 2, 2024
153251f
sync : ggml
ggerganov Nov 4, 2024
5f79214
Merge branch 'ggerganov:master' into master
uniartisan Nov 4, 2024
61c665b
fix: update changes to upstream
uniartisan Nov 4, 2024
9ea34a7
fix: add defualt
uniartisan Nov 4, 2024
8c7b4ec
Update ggml/src/ggml-sycl/outprod.cpp
uniartisan Nov 4, 2024
bb0685f
Update ggml/src/ggml-sycl/wkv6.cpp
uniartisan Nov 4, 2024
81cb301
update the function to use appropriate types
uniartisan Nov 4, 2024
a878502
fix define error
uniartisan Nov 4, 2024
b816024
Update ggml/src/ggml-cpu.c
uniartisan Nov 4, 2024
72e4432
add appropriate asserts
uniartisan Nov 4, 2024
35a1a2d
move element-wise functions outside
uniartisan Nov 4, 2024
6a1e977
Update ggml/src/ggml-sycl/concat.cpp
uniartisan Nov 4, 2024
a749ba7
put the declaration outside the loop
uniartisan Nov 4, 2024
4693b46
rewrite to be more inline with the common pattern for distributing th…
uniartisan Nov 4, 2024
4574795
use recommended way GGML_TENSOR_LOCALS
uniartisan Nov 4, 2024
acb1b9d
Merge branch 'ggerganov:master' into master
uniartisan Nov 4, 2024
e264c35
remove some codes
uniartisan Nov 4, 2024
623db3b
update lint
uniartisan Nov 5, 2024
98e070c
Merge branch 'ggerganov:master' into master
uniartisan Nov 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/backend/SYCL.md
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ found 2 SYCL devices:

|Chosen Device ID|Setting|
|-|-|
|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"` or no action|
|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:0"` or no action|
|1|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"`|
|0 & 1|`export ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"`|

Expand Down
4 changes: 2 additions & 2 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ extern "C" {
GGML_OP_WIN_UNPART,
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV,
GGML_OP_RWKV_WKV6,

GGML_OP_UNARY,

Expand Down Expand Up @@ -1819,7 +1819,7 @@ extern "C" {
struct ggml_tensor * pw,
struct ggml_tensor * ph);

GGML_API struct ggml_tensor * ggml_rwkv_wkv(
GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
struct ggml_context * ctx,
struct ggml_tensor * k,
struct ggml_tensor * v,
Expand Down
208 changes: 160 additions & 48 deletions ggml/src/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -11642,79 +11642,191 @@ static void ggml_compute_forward_add_rel_pos(
}
}

// ggml_compute_forward_rwkv_wkv
// ggml_compute_forward_rwkv_wkv6

static void ggml_compute_forward_rwkv_wkv_f32(
static void ggml_compute_forward_rwkv_wkv6_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const size_t T = dst->src[1]->ne[3];
const size_t C = dst->ne[0];
const size_t H = dst->src[1]->ne[2];
const size_t n_seqs = dst->src[5]->ne[1];
const int64_t T = dst->src[1]->ne[3];
uniartisan marked this conversation as resolved.
Show resolved Hide resolved
const int64_t C = dst->ne[0];
const int64_t HEADS = dst->src[1]->ne[2];
const int64_t n_seqs = dst->src[5]->ne[1];
const int64_t head_size = C / HEADS;

float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T;

if (params->ith != 0) {
const int ith = params->ith;
const int nth = params->nth;

if (ith >= HEADS) {
return;
}

memset(dst_data, 0, T * C * sizeof(float));
const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;

float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data;
float * r = (float *) dst->src[2]->data;
float * time_faaaa = (float *) dst->src[3]->data;
float * time_decay = (float *) dst->src[4]->data;

size_t t_stride = H * (C / H);
size_t t_stride = HEADS * head_size; // Same to C

size_t h_stride = C / H;
size_t h_stride_2d = (C / H) * (C / H);
size_t h_stride = C / HEADS;
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
size_t h_stride_2d = head_size * head_size;

// basically fused operations:
// dst = r @ (time_faaaa * (k @ v) + state),
// state = time_decay * state + (k @ v),
// recursive through each token
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = (C / H) * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
if (ith == 0) {
memset(dst_data, 0, T * C * sizeof(float));
}
ggml_barrier(params->threadpool);

for (size_t h = 0; h < H; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;

for (size_t i = 0; i < C / H; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
#if defined(__AVX__) && !defined(__AVX512F__)
#define GGML_F32X GGML_F32x8
#define GGML_F32X_SET1 GGML_F32x8_SET1
#define GGML_F32X_LOAD GGML_F32x8_LOAD
#define GGML_F32X_STORE GGML_F32x8_STORE
#define GGML_F32X_MUL GGML_F32x8_MUL
#define GGML_F32X_FMA GGML_F32x8_FMA
#define WKV_VECTOR_SIZE 8
#elif defined(__AVX512F__)
#define GGML_F32X GGML_F32x16
#define GGML_F32X_SET1 GGML_F32x16_SET1
#define GGML_F32X_LOAD GGML_F32x16_LOAD
#define GGML_F32X_STORE GGML_F32x16_STORE
#define GGML_F32X_MUL GGML_F32x16_MUL
#define GGML_F32X_FMA GGML_F32x16_FMA
#define WKV_VECTOR_SIZE 16
#elif defined(__ARM_NEON) && defined(__aarch64__)
#define GGML_F32X GGML_F32x4
#define GGML_F32X_SET1 GGML_F32x4_SET1
#define GGML_F32X_LOAD GGML_F32x4_LOAD
#define GGML_F32X_STORE GGML_F32x4_STORE
#define GGML_F32X_MUL GGML_F32x4_MUL
#define GGML_F32X_FMA GGML_F32x4_FMA
#define WKV_VECTOR_SIZE 4
#endif

float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
// RWKV v6: different time_decay for each token.
float time_decay_val = time_decay[t_h_i_offset];
#ifdef WKV_VECTOR_SIZE
const int64_t vec_count = head_size / WKV_VECTOR_SIZE;

for (int64_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;

for (int64_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;

for (int64_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;

float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
float time_decay_val = time_decay[t_h_i_offset];

// Broadcast scalar values to vectors
GGML_F32X k_vec = GGML_F32X_SET1(k_val);
GGML_F32X r_vec = GGML_F32X_SET1(r_val);
GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);

for (int64_t j = 0; j < vec_count; j++) {
size_t base_j = j * WKV_VECTOR_SIZE;
size_t t_h_j_offset = t_h_offset + base_j;
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

// Load x elements at once
GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);

// Compute kv = v * k
GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);

// Compute temp = kv * time_faaaa + prev_state
GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);

// Update dst: dst += temp * r
dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);

// Update state: state = prev_state * time_decay + kv
GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
}

for (size_t j = 0; j < C / H; j ++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
// Handle remaining elements, this will not be used.
for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
}
}
}
}

float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
#else
// basically fused operations:
// dst = r @ (time_faaaa * (k @ v) + state),
// state = time_decay * state + (k @ v),
// recursive through each token
for (int64_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;

for (int64_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;

for (int64_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;

float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
// RWKV v6: different time_decay for each token.
float time_decay_val = time_decay[t_h_i_offset];

for (int64_t j = 0; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;

float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
}
}
}
}
}
#endif
}

static void ggml_compute_forward_rwkv_wkv(

static void ggml_compute_forward_rwkv_wkv6(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

Expand All @@ -11723,7 +11835,7 @@ static void ggml_compute_forward_rwkv_wkv(
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_rwkv_wkv_f32(params, dst);
ggml_compute_forward_rwkv_wkv6_f32(params, dst);
} break;
default:
{
Expand Down Expand Up @@ -12475,9 +12587,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_add_rel_pos(params, tensor);
} break;
case GGML_OP_RWKV_WKV:
case GGML_OP_RWKV_WKV6:
{
ggml_compute_forward_rwkv_wkv(params, tensor);
ggml_compute_forward_rwkv_wkv6(params, tensor);
} break;
case GGML_OP_MAP_UNARY:
{
Expand Down Expand Up @@ -12775,7 +12887,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
case GGML_OP_RWKV_WKV:
case GGML_OP_RWKV_WKV6:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/rwkv-wkv.cuh"
#include "ggml-cuda/wkv6.cuh"

#include <algorithm>
#include <array>
Expand Down Expand Up @@ -2319,8 +2319,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_RWKV_WKV:
ggml_cuda_op_rwkv_wkv(ctx, dst);
case GGML_OP_RWKV_WKV6:
ggml_cuda_op_rwkv_wkv6(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
Expand Down Expand Up @@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV:
case GGML_OP_RWKV_WKV6:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
Expand Down
5 changes: 0 additions & 5 deletions ggml/src/ggml-cuda/rwkv-wkv.cuh

This file was deleted.

6 changes: 3 additions & 3 deletions ggml/src/ggml-cuda/rwkv-wkv.cu → ggml/src/ggml-cuda/wkv6.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "common.cuh"
#include "rwkv-wkv.cuh"
#include "wkv6.cuh"

static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
Expand Down Expand Up @@ -64,7 +64,7 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const
}
}

void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
Expand All @@ -83,7 +83,7 @@ void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64

rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/wkv6.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_WKV_BLOCK_SIZE 64

void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
Loading
Loading