feat: add experimental refactorings. #683
@@ -0,0 +1,20 @@

```csharp
using LLama.Experimental.Common;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Abstractions
{
    /// <summary>
    /// Defines how to execute the model.
    /// </summary>
    public interface IModelRunner : IDisposable
    {
        /// <summary>
        /// Process the scheduled sequences to get the output.
        /// </summary>
        /// <param name="seqGroupMetadataList"></param>
        /// <returns></returns>
        SamplerOutput ExecuteModel(IEnumerable<SequenceGroupMetadata> seqGroupMetadataList);
    }
}
```
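To make the contract concrete, here is a minimal sketch of what an implementation might look like. Everything below is illustrative: `StubModelRunner` is not part of this PR, and a real runner would forward the batch to a llama.cpp-backed executor.

```csharp
// Illustrative only: a skeleton IModelRunner showing the expected responsibilities.
public sealed class StubModelRunner : IModelRunner
{
    public SamplerOutput ExecuteModel(IEnumerable<SequenceGroupMetadata> seqGroupMetadataList)
    {
        // 1. Pack the scheduled sequence groups into a model input batch.
        // 2. Run the forward pass on the backend to obtain per-sequence logits.
        // 3. Sample each sequence (e.g. via ISamplingMethod, defined later in this diff).
        throw new NotImplementedException("Sketch only; wire this to a real backend.");
    }

    public void Dispose()
    {
        // Release native resources (model weights, KV cache, ...) here.
    }
}
```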
@@ -0,0 +1,44 @@

```csharp
using LLama.Experimental.Common;
using LLama.Experimental.Runner.LLamaCpp;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Abstractions
{
    /// <summary>
    /// Defines how to sample the model output.
    /// </summary>
    public interface ISamplingMethod
```
**Review discussion on `ISamplingMethod`:**

> I like the flexibility of:
>
> ```csharp
> public interface ISampler<T>
> {
>     T Sample(in Span<float> logits, SamplingMetadata metadata = null);
> }
>
> public class BaseTokenSampler : ISampler<LLamaToken> { ... }
>
> public interface ISamplerOption
> {
>     bool ShouldApply(in Span<float> logits, SamplingMetadata data) => true;
>     void Apply(ref Span<float> logits, SamplingMetadata data);
> }
> ```
>
> And for the highest level something like …

> Thank you a lot for your suggestions! I like the way of … This looks a bit confusing to me; could you please tell me more about your idea? When and who should call …? From my point of view they are entirely two different things.

> I'm glad you like it! I actually made a POC for the models with some draft code; I'll share it tomorrow, but one idea is to call them from the high-level generator that has access to … Regarding Batch and Sequence: from an inference perspective it just makes more sense to me for a "Batch" to be a batch for the user to process/consume, rather than a batch for the model's internals, but I get that it's an established term for that right now :p Anyway, whenever I wrote "batch" above I meant what's established as "sequence"!

> It's certainly okay to expose …

> I like the idea of …
>
> I don't like this idea though. One of the motivations of creating … In general I don't think designs that try to have a single all-powerful config object are very good. I think it's almost always simpler and more powerful to expose primitives that can be chained together in a nice simple chain of calls like this.

> We agree ^^ I now regret the naming I used, lol. I meant to suggest we wrap it all up inside a …
>
> You're right! But I understand how that might be considered boilerplate/bloat, and leaving it up to a pipeline might be best here.

> Before I pushed the current sampler API I experimented with a few different APIs; one encapsulated every sampler operation into an object (with all the relevant parameters), so a pipeline would just be a list of sampler stage objects. I didn't find it very good to work with: in reality sampler stages are not always completely isolated, so it was a very leaky abstraction.
```csharp
    // TODO: We should reconsider this design. Maybe it's better to let the user set `SamplingParams`
    // and choose the actual sampler internally according to the params.
    {
        /// <summary>
        /// The maximum number of sequences running in parallel.
        ///
        /// If you don't know what to return, you can return the default value.
        ///
        /// Generally, if you want to select several results from n results, you need
        /// to set the maximum number of sequences to run.
        /// </summary>
        /// <param name="defaultValue"></param>
        /// <param name="currentNumSeqs"></param>
        /// <returns></returns>
        int GetMaxNumRunningSeqs(int defaultValue, int currentNumSeqs);

        /// <summary>
        /// Whether to skip special tokens.
        /// </summary>
        bool SkipSpecialTokens { get; }

        /// <summary>
        /// Sample the sequence logits to get the token.
        /// </summary>
        /// <param name="logits"></param>
        /// <param name="seqId"></param>
        /// <param name="samplingMetadata"></param>
        /// <returns></returns>
        SequenceOutput SampleSequence(Span<float> logits, int seqId, SamplingMetadata samplingMetadata);
        // TODO: maybe we shouldn't expose all the samplingMetadata to users here.
    }
}
```
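The `ISampler<T>`/`ISamplerOption` split proposed in the thread above is only sketched there, so here is a hedged, self-contained version showing how such stages could chain. All of the concrete types (`TemperatureOption`, `GreedySampler`) are hypothetical, the parameter modifiers are simplified relative to the reviewer's snippet, and `SamplingMetadata` is stubbed in place of the PR's type.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for the PR's SamplingMetadata type.
public class SamplingMetadata { }

// A sampler stage that can inspect and mutate the logits in place.
public interface ISamplerOption
{
    bool ShouldApply(ReadOnlySpan<float> logits, SamplingMetadata data);
    void Apply(Span<float> logits, SamplingMetadata data);
}

// Example stage: classic temperature scaling of the logits.
public sealed class TemperatureOption : ISamplerOption
{
    private readonly float _temperature;
    public TemperatureOption(float temperature) => _temperature = temperature;

    public bool ShouldApply(ReadOnlySpan<float> logits, SamplingMetadata data) => _temperature != 1f;

    public void Apply(Span<float> logits, SamplingMetadata data)
    {
        for (var i = 0; i < logits.Length; i++)
            logits[i] /= _temperature;
    }
}

// The generic sampler the reviewer proposed, returning some token type T.
public interface ISampler<T>
{
    T Sample(Span<float> logits, SamplingMetadata data);
}

// In the spirit of BaseTokenSampler: run every applicable stage, then pick greedily.
public sealed class GreedySampler : ISampler<int>
{
    private readonly List<ISamplerOption> _options = new();

    public GreedySampler With(ISamplerOption option)
    {
        _options.Add(option);
        return this;
    }

    public int Sample(Span<float> logits, SamplingMetadata data)
    {
        foreach (var option in _options)
            if (option.ShouldApply(logits, data))
                option.Apply(logits, data);

        // Greedy selection: the index of the largest logit is the sampled token id.
        var best = 0;
        for (var i = 1; i < logits.Length; i++)
            if (logits[i] > logits[best])
                best = i;
        return best;
    }
}
```

Usage would then read like `new GreedySampler().With(new TemperatureOption(0.8f)).Sample(logits, metadata)`, which is the kind of chained primitive composition argued for at the end of the thread.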
@@ -0,0 +1,21 @@

```csharp
using LLama.Experimental.Common;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Abstractions
{
    /// <summary>
    /// Defines the scheduling policy, which decides the priority order of sequences.
    /// </summary>
    public interface ISchedulingPolicy
    {
        /// <summary>
        /// Get the priority of a sequence group.
        /// </summary>
        /// <param name="now"></param>
        /// <param name="seqGroup"></param>
        /// <returns></returns>
        int GetPriority(DateTime now, SequenceGroup seqGroup);
    }
}
```
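For instance, a first-come-first-served policy can be expressed by ranking groups on how long they have waited. This sketch is hypothetical (the PR only ships the interface) and assumes the arrival time is reachable through `SequenceGroup.Metrics`, the `RequestMetrics` object shown later in this diff.

```csharp
// Hypothetical first-come-first-served policy: the longer a group has waited,
// the higher the priority value returned to the scheduler.
public sealed class FcfsSchedulingPolicy : ISchedulingPolicy
{
    public int GetPriority(DateTime now, SequenceGroup seqGroup)
    {
        var waited = now - seqGroup.Metrics.ArrivalTime;
        return (int)Math.Min(int.MaxValue, waited.TotalSeconds);
    }
}
```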
@@ -0,0 +1,30 @@

```csharp
using LLama.Experimental.Common;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Abstractions
{
    /// <summary>
    /// Stopping criteria that can be applied during generation.
    /// </summary>
    public interface IStoppingCriteria
    {
        /// <summary>
        /// Check whether the sequence should be stopped and return the status.
        ///
        /// If it's not supposed to be stopped, be sure to return its current status.
        /// </summary>
        /// <param name="seq"></param>
        /// <returns></returns>
        StoppingCriteriaOutput CheckStop(Sequence seq); // TODO: include other params?
    }

    /// <summary>
    /// The output of <see cref="IStoppingCriteria.CheckStop(Sequence)"/>
    /// </summary>
    /// <param name="Status">The sequence status.</param>
    /// <param name="StoppingString">If the sequence stops because of the appearance of a string, set it here.</param>
    /// <param name="StoppingTokenId">If the sequence stops because of the appearance of a token, set it here.</param>
    public record class StoppingCriteriaOutput(SequenceStatus Status, string? StoppingString = null, int? StoppingTokenId = null);
}
```
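As an example of an implementation, the sketch below stops a sequence after a fixed number of generated tokens. It is hypothetical: `seq.OutputTokens` and `seq.Status` mirror members used elsewhere in this PR, while the `SequenceStatus` value used for the finished state is an assumption.

```csharp
// Hypothetical criteria: stop once a sequence has produced maxTokens output tokens.
public sealed class MaxTokensStoppingCriteria : IStoppingCriteria
{
    private readonly int _maxTokens;
    public MaxTokensStoppingCriteria(int maxTokens) => _maxTokens = maxTokens;

    public StoppingCriteriaOutput CheckStop(Sequence seq)
    {
        // SequenceStatus.FinishedStopped is assumed; use whatever finished state the PR defines.
        if (seq.OutputTokens.Count >= _maxTokens)
            return new StoppingCriteriaOutput(SequenceStatus.FinishedStopped);

        // Not stopped: return the sequence's current status, as the docs require.
        return new StoppingCriteriaOutput(seq.Status);
    }
}
```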
@@ -0,0 +1,34 @@

```csharp
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Abstractions
{
    /// <summary>
    /// The interface for tokenizers in LLamaSharp. It's responsible for converting text to token ids, and vice versa.
    /// </summary>
    public interface ITokenizer
    {
        // TODO: `ApplyChatTemplate` API

        // TODO: Batched Encode?
```
**Review discussion:**

> As far as I know there are no batched tokenization APIs in llama.cpp we could use.

> I'd be all in to allow bypassing llama.cpp's built-in tokenizer -- no reason to lock the architecture with that as a dependency.
```csharp
        /// <summary>
        /// Get the token ids from the text.
        /// </summary>
        /// <param name="input"></param>
        /// <returns></returns>
        IList<int> Tokenize(string input);

        /// <summary>
        /// Convert the token ids to text.
        /// </summary>
        /// <param name="tokenIds"></param>
        /// <param name="result"></param>
        /// <param name="skipSpecialTokens"></param>
        /// <returns>The number of tokens consumed for decoding.</returns>
        int ConvertIdsToText(IEnumerable<int> tokenIds, out string result, bool skipSpecialTokens = false);
```
**Review discussion:**

> We can't have an API like this. It's the same problem as the old … A single character may be several tokens. So for example say the tokens …

> There is an …
```csharp
        // TODO: decode from Logprobs
    }
}
```
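The `<returns>` contract on `ConvertIdsToText` implies an incremental decoding loop on the caller's side: feed the pending ids, print what decoded cleanly, and carry the unconsumed remainder forward so that multi-token characters (the concern raised above) are never split. A hedged usage sketch, where `tokenizer` and `generatedTokens` are assumed locals:

```csharp
// Assumed: `tokenizer` is an ITokenizer and `generatedTokens` is the model's output stream.
var pending = new List<int>();
foreach (var tokenId in generatedTokens)
{
    pending.Add(tokenId);

    // The tokenizer reports how many of the pending ids it actually decoded;
    // ids forming an incomplete character are left for the next iteration.
    var consumed = tokenizer.ConvertIdsToText(pending, out var text, skipSpecialTokens: true);
    Console.Write(text);
    pending.RemoveRange(0, consumed);
}
```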
@@ -0,0 +1,24 @@

```csharp
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Common
{
    /// <summary>
    /// The input prepared for the model runner.
    /// </summary>
    /// <param name="TokenIds">The tokens to feed to the model.</param>
    /// <param name="Positions">The positions of these tokens.</param>
    /// <param name="SeqIds">The sequence ids of these tokens.</param>
    /// <param name="WithLogits">Whether the logits need to be computed for each token.</param>
    /// <param name="PromptLengths">The lengths of the prompts if the input is at the prefill stage, otherwise empty.</param>
    /// <param name="SubqueryLengths">The lengths of the subqueries if the input is at the prefill stage, otherwise empty.</param>
    public record class ModelRunnerInput(
        int[] TokenIds,
        int[] Positions,
        int[] SeqIds,
        bool[] WithLogits,
        int[] PromptLengths,
        int[] SubqueryLengths
    );
}
```
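To show how the parallel arrays line up, here is a hypothetical prefill batch packing two prompts (lengths 3 and 2) into one input; the token ids are made up and would in practice come from `ITokenizer.Tokenize`.

```csharp
// Two prompts packed into one prefill batch; each array has one entry per token.
var input = new ModelRunnerInput(
    TokenIds:        new[] { 11, 12, 13, 21, 22 },
    Positions:       new[] { 0, 1, 2, 0, 1 },                   // positions restart per sequence
    SeqIds:          new[] { 0, 0, 0, 1, 1 },                   // which sequence owns each token
    WithLogits:      new[] { false, false, true, false, true }, // only each prompt's last token needs logits
    PromptLengths:   new[] { 3, 2 },
    SubqueryLengths: new[] { 3, 2 }
);
```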
@@ -0,0 +1,42 @@

```csharp
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Common
{
    /// <summary>
    /// Metrics associated with a request.
    /// </summary>
    public class RequestMetrics
    {
        /// <summary>
        /// The time when the request arrived.
        /// </summary>
        public DateTime ArrivalTime { get; set; }

        /// <summary>
        /// The time when the request was first scheduled.
        /// </summary>
        public DateTime? FirstScheduledTime { get; set; }

        /// <summary>
        /// The time when the first token was generated.
        /// </summary>
        public DateTime? FirstTokenTime { get; set; }

        /// <summary>
        /// The time when the last token was generated.
        /// </summary>
        public DateTime? LastTokenTime { get; set; }

        /// <summary>
        /// The time the request spent in the queue.
        /// </summary>
        public TimeSpan? TimeInQueue { get; set; }

        /// <summary>
        /// The time when the request was finished.
        /// </summary>
        public DateTime? FinishedTime { get; set; }
    }
}
```
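These timestamps compose into the usual serving metrics. A hedged sketch of two derived figures (the helper names are illustrative, not part of the PR):

```csharp
// Time-to-first-token: how long the caller waited before any output appeared.
static TimeSpan? TimeToFirstToken(RequestMetrics m) =>
    m.FirstTokenTime - m.ArrivalTime;   // null if no token has been generated yet

// End-to-end latency of the whole request.
static TimeSpan? TotalLatency(RequestMetrics m) =>
    m.FinishedTime - m.ArrivalTime;     // null while the request is still running
```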
@@ -0,0 +1,101 @@

```csharp
using LLama.Experimental.Utils;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace LLama.Experimental.Common
{
    /// <summary>
    /// The output data of a request to the LLM.
    /// </summary>
    /// <param name="RequestId">The unique ID of the request.</param>
    /// <param name="Prompt">The prompt string of the request.</param>
    /// <param name="PromptTokenIds">The token IDs of the prompt.</param>
    /// <param name="Outputs">The output sequences of the request.</param>
    /// <param name="Finished">Whether the whole request is finished.</param>
    /// <param name="Metrics">Metrics associated with the request.</param>
    public record class RequestOutput(
        string RequestId,
        string? Prompt,
        IList<int> PromptTokenIds,
        IList<CompletionOutput> Outputs,
        bool Finished,
        RequestMetrics Metrics
    )
    {
        /// <summary>
        /// Create an instance from <see cref="SequenceGroup"/>.
        /// </summary>
        /// <param name="seqGroup"></param>
        /// <returns></returns>
        /// <exception cref="NotImplementedException"></exception>
        public static RequestOutput FromSeqGroup(SequenceGroup seqGroup)
        {
            var seqs = seqGroup.GetAllSeqs();
            if (seqs.Count() != 1)
            {
                // TODO: deal with beam search here.
                throw new NotImplementedException();
            }

            List<CompletionOutput> outputs = new();
            int index = 0;
            foreach (var seq in seqs)
            {
                outputs.Add(new CompletionOutput(index, seq.OutputText, seq.OutputTokens,
                    seq.Status.GetFinishedReason(), seq.StoppingString, seq.StoppingTokenId));
                index++;
            }

            if (seqGroup.IsFinished)
            {
                seqGroup.SetFinishedTime(DateTime.Now);
            }
            return new RequestOutput(seqGroup.RequestId, seqGroup.Prompt, seqGroup.PromptTokenIds,
                outputs, seqGroup.IsFinished, seqGroup.Metrics);
        }

        /// <inheritdoc/>
        public override string ToString()
        {
            return ClassStringFormatter.Format(this);
        }
    }

    /// <summary>
    /// The output data of one completion output of a request.
    /// </summary>
    /// <param name="Index">The index of the output in the request.</param>
    /// <param name="Text">The generated output text.</param>
    /// <param name="TokenIds">The token IDs of the generated output text.</param>
    /// <param name="FinishReason">The reason why the sequence is finished.</param>
    /// <param name="StoppingString">
    /// The stop string that caused the completion to stop;
    /// null if the completion finished for some other reason.
    /// </param>
    /// <param name="StoppingToken">
    /// The stop token that caused the completion to stop;
    /// null if the completion finished for some other reason.
    /// </param>
    public record class CompletionOutput(
        int Index,
        string Text,
        IList<int> TokenIds,
        string FinishReason,
        string? StoppingString,
        int? StoppingToken
    )
    {
        /// <summary>
        /// Whether the completion has finished.
        /// </summary>
        public bool IsFinished => !string.IsNullOrEmpty(FinishReason);

        /// <inheritdoc/>
        public override string ToString()
        {
            return ClassStringFormatter.Format(this);
        }
    }
}
```
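A hedged sketch of how a caller might consume these records; `output` is assumed to be a `RequestOutput` yielded by the engine's generation loop:

```csharp
// Print each finished completion together with why it stopped.
if (output.Finished)
{
    foreach (var completion in output.Outputs)
    {
        Console.WriteLine($"[{completion.Index}] ({completion.FinishReason}) {completion.Text}");
        if (completion.StoppingString is not null)
            Console.WriteLine($"  stopped on string: \"{completion.StoppingString}\"");
    }
}
```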
**Review discussion:**

> This looks similar to `ISamplingPipeline`, is it basically that?

> Yes. I renamed it because I was not sure whether it would end up slightly different. Using a different name avoids introducing modifications in the non-experimental part.