perplexity : faster Winogrande via batching #5024

Merged: 3 commits, Jan 19, 2024
283 changes: 158 additions & 125 deletions in examples/perplexity/perplexity.cpp
@@ -423,26 +423,31 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
return {tokens, ppl, logit_history, prob_history};
}

static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int> & tokens,
int n_past, int n_batch, int n_vocab) {
std::vector<float> result;
result.reserve(tokens.size() * n_vocab);
size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
size_t n_tokens = tokens.size() - i_chunk * n_batch;
n_tokens = std::min(n_tokens, size_t(n_batch));
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {};
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

Contributor comment: The llama_kv_cache_seq_rm call is no longer needed here?

Owner (author) reply: It's not needed because we clear the entire KV cache before each batch:

llama_kv_cache_clear(ctx);

In the old implementation, tokens from a previous batch were reused, so llama_kv_cache_seq_rm was needed to evict the unused ones (i.e. the second sentence).
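
For reference, a minimal sketch of the per-batch pattern this PR relies on (decode_helper and batch_logits are the helpers introduced in this diff; the batch is assumed to already hold the common prefix plus both endings under separate sequence ids):

// wipe the whole KV cache: every batch is self-contained, so no
// per-sequence llama_kv_cache_seq_rm eviction is needed anymore
llama_kv_cache_clear(ctx);

// decode the batch in n_batch-sized views and collect the logits
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
    fprintf(stderr, "%s: llama_decode() failed\n", __func__);
    return;
}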

const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}

const auto logits = llama_get_logits(ctx);
result.insert(result.end(), logits, logits + n_tokens * n_vocab);

n_past += n_tokens;
memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
}
return result;

return true;
}

static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
@@ -576,7 +581,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

// determine the common prefix of the endings
hs_cur.common_prefix = 0;
hs_cur.required_tokens = 0;
for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
@@ -609,45 +613,18 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;

const int max_tasks_per_batch = params.n_parallel;
const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch;

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(n_ctx*n_vocab);
std::vector<float> batch_logits(n_vocab*n_ctx);

std::vector<std::pair<size_t, llama_token>> eval_pairs;
std::vector<float> eval_results;
std::vector<std::thread> workers(std::thread::hardware_concurrency());

auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}

memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
}

return true;
};

for (size_t i0 = 0; i0 < hs_task_count; i0++) {
int n_cur = 0;

@@ -696,7 +673,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
llama_kv_cache_clear(ctx);

// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, n_batch)) {
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
return;
}
@@ -772,6 +749,13 @@ struct winogrande_entry {
std::string second;
std::array<std::string, 2> choices;
int answer;

size_t i_batch;
size_t common_prefix;
size_t required_tokens;
size_t n_base1; // number of tokens for context + choice 1
size_t n_base2; // number of tokens for context + choice 2
std::vector<llama_token> seq_tokens[2];
};

static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
@@ -875,115 +859,164 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
data = std::move(selected);
}

fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);

// This is needed as usual for LLaMA models
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));

for (auto & task : data) {
task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos);
task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos);

task.common_prefix = 0;
for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
break;
}
task.common_prefix++;
}

task.required_tokens = task.common_prefix +
task.seq_tokens[0].size() - task.common_prefix +
task.seq_tokens[1].size() - task.common_prefix;

task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
}

fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);

const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 128;
const int max_seq = 2*max_tasks_per_batch;

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(n_vocab*n_ctx);

int n_correct = 0;
int n_done = 0;

for (size_t task_idx = 0; task_idx < data.size(); task_idx++) {
const auto& task = data[task_idx];
for (size_t i0 = 0; i0 < data.size(); i0++) {
int n_cur = 0;

auto base_context = ::llama_tokenize(ctx, task.first, add_bos);
auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos);
auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos);
size_t i1 = i0;
size_t i_batch = 0;

auto sentence_1st = task.first + task.choices[0] + task.second;
auto sentence_2nd = task.first + task.choices[1] + task.second;
auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos);
auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos);
llama_batch_clear(batch);

if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size());
return;
}
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
const int s0 = 2*(i1 - i0);
if (s0 + 2 > max_seq) {
break;
}

for (size_t i = 0; i < data[i1].common_prefix; ++i) {
llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
}
batch.logits[batch.n_tokens - 1] = true;

auto query_1st_size = query_1st.size();
auto query_2nd_size = query_2nd.size();
for (int s = 0; s < 2; ++s) {
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
}
}

// Speedup small evaluations by evaluating atleast 32 tokens
// For Winogrande this seems to slow it down rather than speed it up.
//if (query_1st.size() < 32) query_1st.resize(32);
//if (query_2nd.size() < 32) query_2nd.resize(32);
data[i1].i_batch = i_batch;
i_batch += data[i1].required_tokens;

llama_kv_cache_clear(ctx);
auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab);
n_cur += data[i1].required_tokens;
if (++i1 == data.size()) {
break;
}
}

if (i0 == i1) {
fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
return;
}

llama_kv_cache_clear(ctx);
auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab);

if (logits_1st.empty() || logits_2nd.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
return;
}

bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx &&
query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx;

float score_1st = 0;
bool is_nan_1st = false;
const auto& base_1 = skip_choice ? base_ctx_1st : base_context;
const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0;
for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) {
std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[query_1st[j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, sentence_1st.c_str(), base_context.size());
is_nan_1st = true;
break;
for (size_t i = i0; i < i1; ++i) {
auto & task = data[i];

const bool skip_choice =
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;

float score_1st = 0;
bool is_nan_1st = false;
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
size_t li = n_base1 - 1;
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1);
is_nan_1st = true;
break;
}
score_1st += std::log(prob);
}
score_1st += std::log(prob);
}
score_1st /= (query_1st_size - base_1.size() - last_1st);

float score_2nd = 0;
bool is_nan_2nd = false;
const auto& base_2 = skip_choice ? base_ctx_2nd : base_context;
const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0;
for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) {
std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[query_2nd[j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, sentence_2nd.c_str(), base_context.size());
is_nan_2nd = true;
break;
score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);

float score_2nd = 0;
bool is_nan_2nd = false;
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2);
is_nan_2nd = true;
break;
}
score_2nd += std::log(prob);
}
score_2nd += std::log(prob);
}
score_2nd /= (query_2nd_size - base_2.size() - last_2nd);
score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);

if (is_nan_1st || is_nan_2nd) {
continue;
}
if (is_nan_1st || is_nan_2nd) {
continue;
}

if (std::isnan(score_1st) || std::isnan(score_2nd)) {
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size);
printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size);
printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size());
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice);
continue;
}
if (std::isnan(score_1st) || std::isnan(score_2nd)) {
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size());
printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size());
printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix);
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice);
continue;
}

int result = score_1st > score_2nd ? 1 : 2;
int result = score_1st > score_2nd ? 1 : 2;

if (result == task.answer) {
++n_correct;
}
++n_done;

if (result == task.answer) {
++n_correct;
// Print the accumulated accuracy mean x 100
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
fflush(stdout);
}
++n_done;

// Print the accumulated accuracy mean x 100
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer);
fflush(stdout);
i0 = i1 - 1;
}

printf("\n");