FlexLLM (part 3) (#106)
* update

* longer prompts

* Inference test fixes (#105)

* FlexLLM (part 2) (#104)

* init

* update

* hip fixes

* fixes

* Update run.sh

* save outputs in json format, fix test script

* update

* update checker script

* fix

* update

* update

* update

* update

* update

* shellcheck

* rocm

* update
goliaro authored Feb 26, 2025
1 parent 2488463 commit 2e764a9
Showing 18 changed files with 526 additions and 450 deletions.
1 change: 1 addition & 0 deletions conda/flexflow.yml
@@ -27,3 +27,4 @@ dependencies:
- loralib
- triton
- peft
- pytest
2 changes: 1 addition & 1 deletion docker/flexflow-environment/Dockerfile
@@ -113,7 +113,7 @@ RUN rm /usr/local/bin/install_pytorch.sh
RUN pip3 install transformers>=4.47.1 sentencepiece einops
RUN pip3 install tensorflow notebook
# PEFT-related
RUN pip3 install scipy bitsandbytes datasets accelerate loralib triton peft
RUN pip3 install scipy bitsandbytes datasets accelerate loralib triton peft pytest
RUN pip3 install streamlit

# Install Rust
1 change: 1 addition & 0 deletions include/flexflow/request_manager.h
@@ -265,6 +265,7 @@ class RequestManager {
void record_decoding_req_profiling_info(BatchConfig const &old_fwd_bc,
int req_idx);
void record_step_profile_info(BatchConfig const &old_bc);
void save_output_to_json();
void save_profiling_info_to_csv(std::string output_folder,
std::string dataset_name,
std::string llm_model_name,
2 changes: 1 addition & 1 deletion inference/python/incr_decoding.py
@@ -101,7 +101,7 @@ def main():
)
llm.compile(
generation_config,
max_requests_per_batch=1,
max_requests_per_batch=4,
max_seq_length=256,
max_tokens_per_batch=64,
)
4 changes: 2 additions & 2 deletions inference/python/spec_infer.py
@@ -130,15 +130,15 @@ def main():
for ssm in ssms:
ssm.compile(
generation_config,
max_requests_per_batch=1,
max_requests_per_batch=4,
max_seq_length=256,
max_tokens_per_batch=64,
)

# Compile the LLM for inference and load the weights into memory
llm.compile(
generation_config,
max_requests_per_batch=1,
max_requests_per_batch=4,
max_seq_length=256,
max_tokens_per_batch=64,
ssms=ssms,
2 changes: 1 addition & 1 deletion python/flexflow/core/flexflow_cffi.py
@@ -4718,7 +4718,7 @@ def generate(self, requests_list: List[Request]):
] # entry will be None for finetuning requests
c_output_texts = [
(
ffi.new("char[]", max_sequence_length * 5)
ffi.new("char[]", max_sequence_length * 10)
if request.req_type == RequestType.REQ_INFERENCE
else ffi.NULL
)
3 changes: 2 additions & 1 deletion python/flexflow/serve/serve.py
@@ -301,8 +301,9 @@ def download_hf_weights_if_needed(self) -> None:
If not, or if the refresh_cache parameter is set to True, download new weights and convert them.
"""

# TODO: edit this to download the weights using snapshot_download and convert them to FlexFlow format without loading them to GPU
def download_and_convert_llm_weights(model_name):
num_cores = os.cpu_count() -1 if os.cpu_count() > 1 else 1
snapshot_download(repo_id=model_name, allow_patterns="*.safetensors", max_workers=min(30, num_cores))
hf_model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True,
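
For context, the new code above prefetches just the *.safetensors shards (capping download workers by the available CPU cores) so that AutoModelForCausalLM.from_pretrained can read them from the local cache instead of downloading them while the model is being instantiated. A minimal standalone sketch of the same pattern, assuming the huggingface_hub and transformers packages; the model id in the usage comment is illustrative, not taken from this commit:

import os

from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM


def prefetch_and_load(model_name: str):
    # Leave one core free for the main process, but never drop below one worker.
    num_cores = os.cpu_count() - 1 if os.cpu_count() > 1 else 1
    # Fetch only the safetensors weight shards into the local Hugging Face cache.
    snapshot_download(
        repo_id=model_name,
        allow_patterns="*.safetensors",
        max_workers=min(30, num_cores),
    )
    # from_pretrained now reuses the cached shards instead of re-downloading them.
    return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)


# Usage (illustrative model id):
# model = prefetch_and_load("meta-llama/Llama-3.1-8B")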
7 changes: 4 additions & 3 deletions src/c/flexflow_c.cc
@@ -1780,10 +1780,11 @@ void flexflow_model_generate(flexflow_model_t handle_,
if (max_lengths[i] >= 0) {
assert(total_tokens <= max_lengths[i] || num_output_tokens == 0);
}
// assert(results[i].output_tokens.size() <= max_seq_lengths[i] ||
// results[i].output_tokens.size() ==
// results[i].input_tokens.size());
output_length_and_tokens[i][0] = results[i].output_tokens.size();
assert(results[i].output_tokens.size() <= max_lengths[i] + 100 &&
"Exceeding python buffer size for token ids");
assert(results[i].output_text.length() <= max_lengths[i] * 10 &&
"Exceeding python buffer size for output text");
std::copy(results[i].output_tokens.begin(),
results[i].output_tokens.end(),
output_length_and_tokens[i] + 1);
175 changes: 88 additions & 87 deletions src/ops/inc_multihead_self_attention.cpp
@@ -989,67 +989,57 @@ __global__ void
int num_tokens,
int num_q_heads,
int num_kv_heads) {
CUDA_KERNEL_LOOP(i, num_tokens * num_q_heads * proj_size) {
size_t q_array_size = proj_size * num_q_heads * num_tokens;
int hidden_size = num_q_heads * proj_size;
int total_num_heads = num_q_heads + 2 * num_kv_heads;
// create complex number
bool q_tensor = i < (q_array_size / 2);
int real_i = q_tensor ? i : i - q_array_size / 2;
int token_idx = real_i / (hidden_size / 2);
int idx = real_i % (proj_size / 2);
int head_idx = (real_i - (token_idx * (hidden_size / 2))) / (proj_size / 2);
if (!q_tensor) {
head_idx /= (num_q_heads / num_kv_heads);
}

int real_part_index = idx + head_idx * proj_size +
token_idx * proj_size * total_num_heads +
hidden_size * (q_tensor ? 0 : 1);
int complex_part_index = real_part_index + (proj_size / 2);

complex_input[i] = {input_ptr[real_part_index],
input_ptr[complex_part_index]};

// get the freq_cis: shape 1 * (qProjSize/2) = 1 * 64
// apply a Cartesian coordinate transformation
// multiple with input & /copy back to q/k

// get position of token

// size_t pos = id_map[token_idx].token_position;
size_t pos = tokenInfos[token_idx].abs_depth_in_request;

// float before_real = complex_input[i].x, before_complex =
int pos_i = real_i % (proj_size / 2);

float freq =
pos * (1.0 / pow(rope_theta, (float)2 * pos_i / proj_size)); // θ_i
int half_proj = proj_size / 2;
int q_proj_work = num_tokens * num_q_heads * half_proj;
int kv_proj_work = num_tokens * num_kv_heads * half_proj;
int tot_num_heads = num_q_heads + 2 * num_kv_heads;
CUDA_KERNEL_LOOP(i, q_proj_work + kv_proj_work) {
bool q_tensor = i < q_proj_work;
int num_heads = q_tensor ? num_q_heads : num_kv_heads;

int real_i = q_tensor ? i : i - q_proj_work;
int token_idx = real_i / (half_proj * num_heads);
int pair_idx = real_i % half_proj;
int head_idx = (real_i / half_proj) % num_heads;

// input_ptr: [proj_size, tot_num_heads, num_tokens]
int real_part_index = token_idx * proj_size * tot_num_heads +
(q_tensor ? 0 : proj_size * num_q_heads) +
head_idx * proj_size + pair_idx;
int complex_part_index = real_part_index + half_proj;
complex_input[i] = {(float)input_ptr[real_part_index],
(float)input_ptr[complex_part_index]};

float inv_freq =
1.0 / pow(rope_theta, (float)2.0 * pair_idx / proj_size); // θ_i

if (llama3_rope) {
float pi = HIP_PI_F;
float wavelen = 2 * pi / freq;
float wavelen = 2 * pi / inv_freq;
float low_freq_wavelen =
original_max_position_embeddings / low_freq_factor;
float high_freq_wavelen =
original_max_position_embeddings / high_freq_factor;
if (wavelen < high_freq_wavelen) {
} else if (wavelen > low_freq_wavelen) {
freq = freq / factor;
} else {
assert(low_freq_wavelen != high_freq_wavelen);
float smooth =
(original_max_position_embeddings / wavelen - low_freq_factor) /
(high_freq_factor - low_freq_factor);
freq = ((1 - smooth) * freq / factor + smooth * freq);
if (wavelen > low_freq_wavelen) {
inv_freq = inv_freq / factor;
}
float smooth_factor =
(original_max_position_embeddings / wavelen - low_freq_factor) /
(high_freq_factor - low_freq_factor);
if (!(wavelen < high_freq_wavelen) && !(wavelen > low_freq_wavelen)) {
inv_freq = ((1 - smooth_factor) * inv_freq / factor +
smooth_factor * inv_freq);
}
}

hipFloatComplex complex_pos = {cos(freq), sin(freq)};
int pos = tokenInfos[token_idx].abs_depth_in_request;
inv_freq = inv_freq * pos;

hipFloatComplex complex_pos = {cos(inv_freq), sin(inv_freq)};

complex_input[i] = hipCmulf(complex_input[i], complex_pos);
input_ptr[real_part_index] = complex_input[i].x;
input_ptr[complex_part_index] = complex_input[i].y;
input_ptr[real_part_index] = (DT)complex_input[i].x;
input_ptr[complex_part_index] = (DT)complex_input[i].y;
}
}
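
For reference, the branch above is the Llama-3 wavelength-based RoPE rescaling; summarizing the kernel code in one formula (with d = proj_size, s = factor, l = low_freq_factor, h = high_freq_factor, L_0 = original_max_position_embeddings):

    \theta_i = \texttt{rope\_theta}^{-2i/d}, \qquad
    \lambda_i = \frac{2\pi}{\theta_i}, \qquad
    \gamma_i = \frac{L_0/\lambda_i - l}{h - l},

    \theta_i' =
    \begin{cases}
      \theta_i & \text{if } \lambda_i < L_0/h \\
      \theta_i / s & \text{if } \lambda_i > L_0/l \\
      \bigl((1-\gamma_i)/s + \gamma_i\bigr)\,\theta_i & \text{otherwise,}
    \end{cases}

and each (q, k) pair of a token at depth p is then rotated by the angle p \cdot \theta_i'.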

@@ -1066,57 +1056,62 @@ __global__ void
int original_max_position_embeddings,
int proj_size,
int num_tokens,
int hidden_size) {
CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) {
int num_q_heads,
int num_kv_heads) {
int half_proj = proj_size / 2;
int q_proj_work = num_tokens * num_q_heads * half_proj;
int kv_proj_work = num_tokens * num_kv_heads * half_proj;
// int tot_num_heads = num_q_heads + 2 * num_kv_heads;
CUDA_KERNEL_LOOP(i, q_proj_work + kv_proj_work) {
// compute indexes to visit first half proj_size of each of q/k tensor.
// devQKVProj has shape [num_tokens, qProjSize, num_heads, 3] in peft_bwd
bool q_tensor = i < (num_tokens * hidden_size / 2);
int real_i = q_tensor ? i : i - num_tokens * hidden_size / 2;
assert(hidden_size % proj_size == 0);
int num_heads = hidden_size / proj_size;
// devQKVProj has shape [num_tokens, proj_size, tot_num_heads] in peft_bwd
bool q_tensor = i < q_proj_work;
int num_heads = q_tensor ? num_q_heads : num_kv_heads;
int real_i = q_tensor ? i : i - q_proj_work;

int token_idx = real_i % num_tokens;
int idx = (real_i / num_tokens) % (proj_size / 2);
int head_idx = real_i / (num_tokens * proj_size / 2);
int pair_idx = (real_i / num_tokens) % half_proj;
int head_idx = real_i / (num_tokens * half_proj);
assert(head_idx < num_heads);

int complex_part_index = (q_tensor ? 0 : 1) * num_tokens * hidden_size +
head_idx * num_tokens * proj_size +
idx * num_tokens + token_idx;
int real_part_index = complex_part_index + (proj_size / 2) * num_tokens;
int complex_part_index =
(q_tensor ? 0 : num_tokens * proj_size * num_q_heads) +
head_idx * proj_size * num_tokens + pair_idx * num_tokens + token_idx;
int real_part_index = complex_part_index + num_tokens * half_proj;

complex_input[i] = {input_ptr[real_part_index],
input_ptr[complex_part_index]};
complex_input[i] = {(float)input_ptr[real_part_index],
(float)input_ptr[complex_part_index]};

size_t pos = tokenInfos[token_idx].abs_depth_in_request;

float freq =
pos * (1.0 / pow(rope_theta, (float)2 * idx / proj_size)); // θ_i
float inv_freq =
1.0 / pow(rope_theta, (float)2.0 * pair_idx / proj_size); // θ_i

if (llama3_rope) {
float pi = HIP_PI_F;
float wavelen = 2 * pi / freq;
float wavelen = 2 * pi / inv_freq;
float low_freq_wavelen =
original_max_position_embeddings / low_freq_factor;
float high_freq_wavelen =
original_max_position_embeddings / high_freq_factor;
if (wavelen < high_freq_wavelen) {
} else if (wavelen > low_freq_wavelen) {
freq = freq / factor;
} else {
assert(low_freq_wavelen != high_freq_wavelen);
float smooth =
(original_max_position_embeddings / wavelen - low_freq_factor) /
(high_freq_factor - low_freq_factor);
freq = ((1 - smooth) * freq / factor + smooth * freq);
if (wavelen > low_freq_wavelen) {
inv_freq = inv_freq / factor;
}
float smooth_factor =
(original_max_position_embeddings / wavelen - low_freq_factor) /
(high_freq_factor - low_freq_factor);
if (!(wavelen < high_freq_wavelen) && !(wavelen > low_freq_wavelen)) {
inv_freq = ((1 - smooth_factor) * inv_freq / factor +
smooth_factor * inv_freq);
}
}

hipFloatComplex complex_pos = {cos(freq), sin(freq)};
int pos = tokenInfos[token_idx].abs_depth_in_request;
inv_freq = inv_freq * pos;

hipFloatComplex complex_pos = {cos(inv_freq), sin(inv_freq)};

complex_input[i] = hipCmulf(complex_input[i], complex_pos);
input_ptr[real_part_index] = complex_input[i].x;
input_ptr[complex_part_index] = complex_input[i].y;
input_ptr[real_part_index] = (DT)complex_input[i].x;
input_ptr[complex_part_index] = (DT)complex_input[i].y;
}
}

@@ -1151,7 +1146,10 @@ void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m,
// Step 3: apply rotary embedding if needed
if (m->rotary_embedding_meta->apply_rotary_embedding) {
/*q&k*/
parallelism = num_tokens * m->hidden_size;
int half_proj = m->qProjSize / 2;
int q_proj_work = num_tokens * m->num_q_heads * half_proj;
int kv_proj_work = num_tokens * m->num_kv_heads * half_proj;
parallelism = q_proj_work + kv_proj_work;
hipLaunchKernelGGL(
HIP_KERNEL_NAME(apply_rotary_embedding_fwd),
GET_BLOCKS(parallelism),
@@ -1781,15 +1779,17 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
assert(m->hidden_size == m->qProjSize * m->num_q_heads);
assert(m->qProjSize == m->kProjSize);
/*q&k*/
int parallelism = num_tokens * m->hidden_size;
DT *A = static_cast<DT *>(m->devQKVProjArray);
int half_proj = m->qProjSize / 2;
int q_proj_work = num_tokens * m->num_q_heads * half_proj;
int kv_proj_work = num_tokens * m->num_kv_heads * half_proj;
int parallelism = q_proj_work + kv_proj_work;
hipLaunchKernelGGL(
HIP_KERNEL_NAME(apply_rotary_embedding_bwd),
GET_BLOCKS(parallelism),
min(CUDA_NUM_THREADS, parallelism),
0,
stream,
A,
static_cast<DT *>(m->devQKVProjArray),
m->complex_input,
m->peft_token_infos_device,
m->rotary_embedding_meta->rope_theta,
Expand All @@ -1800,7 +1800,8 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
m->rotary_embedding_meta->original_max_position_embeddings,
m->qProjSize,
num_tokens,
m->hidden_size);
m->num_q_heads,
m->num_kv_heads);
DT *C = static_cast<DT *>(m->devQKVProjArray);
if (m->inference_debugging) {
std::string filename =
@@ -2092,7 +2093,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
max_tokens_per_batch * BatchConfig::max_sequence_length() * num_q_heads;
size_t attn_heads_size = max_tokens_per_batch * num_q_heads * vProjSize;
size_t complex_size = (max_tokens_per_batch * (qProjSize * num_q_heads +
kProjSize * num_q_heads)) /
kProjSize * num_kv_heads)) /
2;
if (enable_peft_finetuning) {
allocated_peft_buffer_size1 = BatchConfig::max_sequence_length() *
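
As a side note on the sizing change above: the rotary kernels and the complex_input buffer now agree on one complex value per rotated (token, head, pair) element, counting query and key/value heads separately, so grouped-query models stop allocating key buffers sized for num_q_heads. A small illustrative sketch of that count, assuming qProjSize == kProjSize (as the kernels already assert):

def rope_work_items(num_tokens, proj_size, num_q_heads, num_kv_heads):
    # One complex value per (token, head, rotated pair); the v projection is never rotated.
    half_proj = proj_size // 2
    return num_tokens * (num_q_heads + num_kv_heads) * half_proj


# e.g. a GQA model with 32 query heads, 8 kv heads, head dim 128, and 64 tokens:
# rope_work_items(64, 128, 32, 8) == 64 * (32 + 8) * 64 == 163840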