From 1b70cde62790c84e9e9b456ea31932909c638798 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Mon, 14 Oct 2024 16:12:39 -0600
Subject: [PATCH] fix(llama): Overhaul use of sampling module for llama.cpp
 changes

The changes here reflect the big llama.cpp sampling refactor PR
https://github.com/ggerganov/llama.cpp/pull/9294

Sampling functionality is now split into a base interface (llama_sampler)
and a generation implementation (gpt_sampler), and the wrapper code here is
updated accordingly. Since the sampling.h/sampling.cpp code uses C++ STL
headers, the sampling_ext.[h|cpp] wrapper is maintained to allow Go to
access a pure-C interface.

Branch: IBMGraniteArchitectureSupport

Signed-off-by: Gabe Goodhart
---
 llama/llama.go         | 20 ++++++++------------
 llama/runner/runner.go |  8 ++++----
 llama/sampling_ext.cpp | 28 ++++++++++++++--------------
 llama/sampling_ext.h   | 14 +++++++-------
 4 files changed, 33 insertions(+), 37 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index 620b54a10c3..b45a4962aae 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -445,7 +445,7 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []
 // sampling
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 type SamplingContext struct {
-	c *C.struct_llama_sampling_context
+	c *C.struct_llama_sampler
 }
 
 type SamplingParams struct {
@@ -467,7 +467,8 @@ type SamplingParams struct {
 	Grammar string
 }
 
-func NewSamplingContext(params SamplingParams) *SamplingContext {
+func NewSamplingContext(model *Model, params SamplingParams) *SamplingContext {
+	var cparams C.struct_llama_sampling_cparams
 	cparams.top_k = C.int32_t(params.TopK)
 	cparams.top_p = C.float(params.TopP)
@@ -489,7 +490,7 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
 	defer C.free(unsafe.Pointer(grammar))
 	cparams.grammar = grammar
 
-	context := &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
+	context := &SamplingContext{c: C.llama_sampling_cinit(model.c, &cparams)}
 	runtime.SetFinalizer(context, func(s *SamplingContext) { C.llama_sampling_cfree(s.c) })
 
 	return context
@@ -499,15 +500,10 @@ func (s *SamplingContext) Reset() {
 	C.llama_sampling_creset(s.c)
 }
 
-func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int {
-	// TODO (jmorganca): handle nil for all args
-	if ctxConfig == nil {
-		return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx)))
-	}
-
-	return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx)))
+func (s *SamplingContext) Sample(ctxMain *Context, idx int) int {
+	return int(C.llama_sampling_csample(s.c, ctxMain.c, C.int(idx)))
 }
 
-func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) {
-	C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar))
+func (s *SamplingContext) Accept(id int, applyGrammar bool) {
+	C.llama_sampling_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
 }

diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index bf799d37cc4..f4c45e0f22a 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -126,10 +126,10 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 
 	var sc *llama.SamplingContext
 	if params.samplingParams != nil {
-		sc = llama.NewSamplingContext(*params.samplingParams)
+		sc = llama.NewSamplingContext(s.model, *params.samplingParams)
 		for _, input := range inputs {
 			if input.embed == nil {
-				sc.Accept(s.lc, input.token, false)
+				sc.Accept(input.token, false)
 			}
 		}
 	}
@@ -429,8 +429,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		}
 
 		// sample a token
-		token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
-		seq.samplingCtx.Accept(s.lc, token, true)
+		token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
+		seq.samplingCtx.Accept(token, true)
 
 		piece := s.model.TokenToPiece(token)
 		seq.numPredicted++

diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp
index da92cedf0f2..9494d13886f 100644
--- a/llama/sampling_ext.cpp
+++ b/llama/sampling_ext.cpp
@@ -2,14 +2,15 @@
 #include "sampling.h"
 #include "sampling_ext.h"
 
-struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparams *params)
+struct llama_sampler *llama_sampling_cinit(
+    const struct llama_model *model, struct llama_sampling_cparams *params)
 {
-    llama_sampling_params sparams;
+    gpt_sampler_params sparams;
     sparams.top_k = params->top_k;
     sparams.top_p = params->top_p;
     sparams.min_p = params->min_p;
     sparams.tfs_z = params->tfs_z;
-    sparams.typical_p = params->typical_p;
+    sparams.typ_p = params->typical_p;
     sparams.temp = params->temp;
     sparams.penalty_last_n = params->penalty_last_n;
     sparams.penalty_repeat = params->penalty_repeat;
@@ -21,33 +22,32 @@ struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparam
     sparams.penalize_nl = params->penalize_nl;
     sparams.seed = params->seed;
     sparams.grammar = params->grammar;
-    return llama_sampling_init(sparams);
+    return (llama_sampler*)gpt_sampler_init(model, sparams);
 }
 
-void llama_sampling_cfree(struct llama_sampling_context *ctx)
+void llama_sampling_cfree(struct llama_sampler *sampler)
 {
-    llama_sampling_free(ctx);
+    gpt_sampler_free((gpt_sampler*)sampler);
 }
 
-void llama_sampling_creset(struct llama_sampling_context *ctx)
+void llama_sampling_creset(struct llama_sampler *sampler)
 {
-    llama_sampling_reset(ctx);
+    gpt_sampler_reset((gpt_sampler*)sampler);
 }
 
 llama_token llama_sampling_csample(
-    struct llama_sampling_context *ctx_sampling,
+    struct llama_sampler *sampler,
     struct llama_context *ctx_main,
-    struct llama_context *ctx_cfg,
     int idx)
 {
-    return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx);
+    // TODO (ggoodhart): Do we need to support grammar_first?
+    return gpt_sampler_sample((gpt_sampler*)sampler, ctx_main, idx);
 }
 
 void llama_sampling_caccept(
-    struct llama_sampling_context *ctx_sampling,
-    struct llama_context *ctx_main,
+    struct llama_sampler *sampler,
     llama_token id,
     bool apply_grammar)
 {
-    llama_sampling_accept(ctx_sampling, ctx_main, id, apply_grammar);
+    gpt_sampler_accept((gpt_sampler*)sampler, id, apply_grammar);
 }

diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h
index 588ed5c1e46..c29d601da55 100644
--- a/llama/sampling_ext.h
+++ b/llama/sampling_ext.h
@@ -29,19 +29,19 @@ extern "C"
         char *grammar;
     };
 
-    struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparams *params);
-    void llama_sampling_cfree(struct llama_sampling_context *ctx);
-    void llama_sampling_creset(struct llama_sampling_context *ctx);
+    struct llama_sampler *llama_sampling_cinit(
+        const struct llama_model *model,
+        struct llama_sampling_cparams *params);
+    void llama_sampling_cfree(struct llama_sampler *sampler);
+    void llama_sampling_creset(struct llama_sampler *sampler);
 
     llama_token llama_sampling_csample(
-        struct llama_sampling_context *ctx_sampling,
+        struct llama_sampler *sampler,
         struct llama_context *ctx_main,
-        struct llama_context *ctx_cfg,
         int idx);
 
     void llama_sampling_caccept(
-        struct llama_sampling_context *ctx_sampling,
-        struct llama_context *ctx_main,
+        struct llama_sampler *sampler,
         llama_token id,
         bool apply_grammar);
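
Reviewer note (not part of the patch): as a sanity check of the reworked API
surface, here is a minimal Go sketch of how a caller drives the wrapper after
this change, mirroring the runner.go hunks above. The import path, the
sampleNext helper, and the model/context/token plumbing are assumptions from
the surrounding ollama tree; only the SamplingParams fields visible in this
diff are set.

    package main

    import "github.com/ollama/ollama/llama"

    // sampleNext primes the sampler with the prompt, then samples one token.
    // model and lc are assumed to come from the existing runner setup; iBatch
    // is the sequence's index into the current batch, as in processBatch.
    func sampleNext(model *llama.Model, lc *llama.Context, promptTokens []int, iBatch int) int {
        sc := llama.NewSamplingContext(model, llama.SamplingParams{
            TopK: 40, // remaining fields (elided in the hunks above) keep their zero values
            TopP: 0.9,
        })

        // Feed the prompt tokens through without applying the grammar,
        // matching the NewSequence loop above.
        for _, tok := range promptTokens {
            sc.Accept(tok, false)
        }

        // Sample from the logits at iBatch, then feed the chosen token back
        // with the grammar applied, matching processBatch above.
        token := sc.Sample(lc, iBatch)
        sc.Accept(token, true)
        return token
    }

Keeping the Go-side handle as an opaque *C.struct_llama_sampler means CGo
never has to see the STL-backed gpt_sampler type; sampling_ext.cpp casts the
handle back to gpt_sampler* at each call site.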