feat: add experimental refactorings. #683

Closed
20 changes: 20 additions & 0 deletions LLama/Experimental/Abstractions/IModelRunner.cs
@@ -0,0 +1,20 @@
using LLama.Experimental.Common;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Abstractions
{
/// <summary>
/// Defines how to execute the model.
/// </summary>
public interface IModelRunner : IDisposable
{
/// <summary>
/// Process the scheduled sequences and produce the sampled output.
/// </summary>
/// <param name="seqGroupMetadataList"></param>
/// <returns></returns>
SamplerOutput ExecuteModel(IEnumerable<SequenceGroupMetadata> seqGroupMetadataList);
}
}
44 changes: 44 additions & 0 deletions LLama/Experimental/Abstractions/ISamplingMethod.cs
@@ -0,0 +1,44 @@
using LLama.Experimental.Common;
using LLama.Experimental.Runner.LLamaCpp;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Abstractions
{
/// <summary>
/// Method to sample the model output.
/// </summary>
public interface ISamplingMethod

Member:
This looks similar to ISamplingPipeline, is it basically that?

Collaborator Author:
Yes, I renamed it because I wasn't sure whether it would end up slightly different from ISamplingPipeline. Using a different name avoids introducing modifications in the non-experimental part.

Contributor:
I like the flexibility of SamplingMetaData replacing current ISamplingPipeline's lastTokens! A lot of info could be passed that way. A slightly more flexible design could be this:

public interface ISampler<T> { T Sample(in Span<float> logits, SamplingMetadata metadata = null); }
public class BaseTokenSampler : ISampler<LLamaToken> { ... }
public interface ISamplerOption {
    bool ShouldApply(in Span<float> logits, SamplingMetadata data) => true;
    void Apply(ref Span<float> logits, SamplingMetadata data);
}

And for the highest level something like samplingParams.CreateSampler() could unify these.
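
As a rough sketch of how those pieces might compose (this builds on the proposed interfaces above; the option type and the composition loop are invented for illustration and are not part of this PR):

// Illustration only: one concrete ISamplerOption under the proposal above.
public sealed class TemperatureOption : ISamplerOption
{
    private readonly float _temperature;
    public TemperatureOption(float temperature) => _temperature = temperature;

    // Skip the stage when it would be a no-op.
    public bool ShouldApply(in Span<float> logits, SamplingMetadata data) => _temperature != 1f;

    // Scale the logits in place before the final sampling step.
    public void Apply(ref Span<float> logits, SamplingMetadata data)
    {
        for (int i = 0; i < logits.Length; i++)
            logits[i] /= _temperature;
    }
}

// A sampler produced by samplingParams.CreateSampler() could then just walk an ordered list of options:
//     foreach (var option in options)
//         if (option.ShouldApply(logits, metadata))
//             option.Apply(ref logits, metadata);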

Collaborator Author:
Thanks a lot for your suggestions! I like the ILLamaSymbol approach you mentioned. Actually, Martin has made some similar changes before (for example LLamaToken and LLamaSeqId). I skipped these abstractions to rush this prototype out and will add them later.

Add abstractions for Modals

This looks a bit confusing to me; could you please tell me more about your idea? Who is supposed to call NeedsEvaluation, and when?

IReadOnlyList batches; // Name 'Batch' would make more sense instead of 'Sequence' imo

From my point of view they are two entirely different things. A Sequence is a collection of semantically continuous information whose length grows during inference, while a Batch is the data fed into the model in one inference step. A Batch might contain data from multiple sequences, as you can see in llama.cpp's llama_batch.
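
To illustrate the distinction with made-up numbers (simplified tuples, not the PR's actual types):

// Two sequences, each growing as generation proceeds.
int[] seqA = { 101, 7, 42, 9 };
int[] seqB = { 101, 13 };

// One decode step feeds the latest token of each sequence, tagged with its
// sequence id and position -- roughly what llama.cpp's llama_batch holds.
var batch = new (int Token, int SeqId, int Pos)[]
{
    (seqA[seqA.Length - 1], 0, seqA.Length - 1),
    (seqB[seqB.Length - 1], 1, seqB.Length - 1),
};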

Contributor:
I’m glad you like it! I actually made a POC for the modals with some draft code; I’ll share it tomorrow, but one idea is to call them from the high-level generator, which has access to model.modals, to digest the prompt.
Modals don’t need to have state and would mostly be logic containers (e.g. ImageModal would process pending image links in the LLamaBatch, and TextModal would finally process the tokens).

Regarding Batch and Sequence: from an inference perspective it just makes more sense to me for a ‘Batch’ to be a batch for the user to process/consume, rather than a batch for the model’s internals, but I get that it’s an established term for that right now :p Anyway, whenever I wrote ‘batch’ above I meant what’s established as ‘sequence’!

Collaborator Author (@AsakusaRinne, Apr 23, 2024):
It's certainly okay to expose the batch to users, but I prefer to expose it in low-level APIs (ModelRunner in this draft) because the batch is directly related to llama.cpp.

Member (@martindevans, Apr 23, 2024):
SamplingMetaData

I like the idea of SamplingMetaData containing more info than just the raw lastTokens that ISamplingPipeline has at the moment (although I'd probably try to come up with another name). It would be a nice, simple PR if someone wants to adjust the ISamplingPipeline interface to accept something like that (even if it only wraps up lastTokens for now, it's a good place to add more things in the future).
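
A minimal sketch of what that wrapper could look like (the names are hypothetical, not the existing ISamplingPipeline API):

// Starts as a thin wrapper around lastTokens; more context (sequence id,
// per-request parameters, grammar state, ...) can be added later without
// changing the pipeline signature again.
public sealed class SamplingContext
{
    public SamplingContext(IReadOnlyList<int> lastTokens) => LastTokens = lastTokens;

    public IReadOnlyList<int> LastTokens { get; }
}

// The pipeline method would then accept the wrapper instead of the raw token list:
//     LLamaToken Sample(Span<float> logits, SamplingContext context);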

samplingParams.CreateSampler()

I don't like this idea though. One of the motivations of creating ISamplingPipeline was that we have/had configurable sampling and it's a mess - see ISamplingParams. Many of those properties don't mean anything depending on the value of other properties. Adding new sampling stages needs even more properties. Stages are not re-orderable so even with all this it's not powerful enough. New/custom sampling stages cannot be added at all.

In general I don't think designs that try to have a single all-powerful config object are very good. I think it's almost always simpler and more powerful to expose primitives that can be chained together in a nice simple chain of calls like this.
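
For instance, sampling primitives chained as plain calls might look roughly like this (the stage names and signatures are illustrative, not the actual LLamaSharp sampler API):

// Illustration only: each stage is a small primitive, applied in whatever order
// the caller wants, and a custom stage is just another function in the chain.
static void ApplyRepetitionPenalty(Span<float> logits, IReadOnlyList<int> lastTokens, float penalty)
{
    foreach (var token in lastTokens)
        logits[token] = logits[token] > 0 ? logits[token] / penalty : logits[token] * penalty;
}

static void ApplyTemperature(Span<float> logits, float temperature)
{
    for (int i = 0; i < logits.Length; i++)
        logits[i] /= temperature;
}

static int GreedySample(Span<float> logits)
{
    int best = 0;
    for (int i = 1; i < logits.Length; i++)
        if (logits[i] > logits[best]) best = i;
    return best;
}

// A "pipeline" is then just an explicit, re-orderable chain of calls:
//     ApplyRepetitionPenalty(logits, lastTokens, 1.1f);
//     ApplyTemperature(logits, 0.8f);
//     var token = GreedySample(logits);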

Contributor (@Lyrcaxis, Apr 23, 2024):
@AsakusaRinne It's certainly okay to expose batch to users, but I prefer to expose it in low-level APIs

We agree ^^ I now regret the naming I used lol.. I meant to suggest we wrap it all up inside a Context class for all the other low-level and middle-level classes.

@martindevans I don't like this idea though. [..] - see ISamplingParams. [..] it's not powerful enough.

You're right! ISamplerOption in a dynamic list would fix that in my mind, but I get where you're coming from.
I also had a chain design in mind, but it would shift control to creation time, somewhat like so:
var sampler = p.With(new CustomTemperatureSamplerOption()).With(new BaseTopPSamplerOption(0.95f)).ToSampler();

But I understand how that might be considered boilerplate/bloat and leaving it up to a pipeline might be best here.

Member:
Before I pushed the current sampler API I experimented with a few different APIs. One was encapsulating every sampler operation into an object (with all the relevant parameters), so a pipeline would just be a list of sampler stage objects. I didn't find it very good to work with: in reality sampler stages are not always completely isolated, so it was a very leaky abstraction.

// TODO: We should reconsider this design. Maybe it's better to use `SamplingParams` to let user set,
// and choose the actual sampler internally according to the params.
{
/// <summary>
/// The maximum number of sequences running in parallel.
///
/// If you don't know what to return, you can return the default value.
///
/// Generally, if you want to select several results from n results, you need
/// to set the maximum number of sequences to run.
/// </summary>
/// <param name="defaultValue"></param>
/// <param name="currentNumSeqs"></param>
/// <returns></returns>
int GetMaxNumRunningSeqs(int defaultValue, int currentNumSeqs);

/// <summary>
/// Whether to skip special tokens.
/// </summary>
bool SkipSpecialTokens { get; }

/// <summary>
/// Sample the sequence logits to get the token.
/// </summary>
/// <param name="logits"></param>
/// <param name="seqId"></param>
/// <param name="samplingMetadata"></param>
/// <returns></returns>
SequenceOutput SampleSequence(Span<float> logits, int seqId, SamplingMetadata samplingMetadata);
// TODO: maybe we shouldn't expose all the samplingMetadata to users here.
}
}
21 changes: 21 additions & 0 deletions LLama/Experimental/Abstractions/ISchedulingPolicy.cs
@@ -0,0 +1,21 @@
using LLama.Experimental.Common;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Abstractions
{
/// <summary>
/// Defines the scheduling policy, which decides the priority order of sequences.
/// </summary>
public interface ISchedulingPolicy
{
/// <summary>
/// Get the priority of a sequence group.
/// </summary>
/// <param name="now"></param>
/// <param name="seqGroup"></param>
/// <returns></returns>
int GetPriority(DateTime now, SequenceGroup seqGroup);
}
}
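
As an illustration of how this interface might be implemented (a sketch only: it assumes that a larger value means higher priority, and uses the Metrics.ArrivalTime property that SequenceGroup exposes elsewhere in this PR):

// First-come-first-served: requests that arrived earlier get a higher priority.
public sealed class FcfsSchedulingPolicy : ISchedulingPolicy
{
    public int GetPriority(DateTime now, SequenceGroup seqGroup)
    {
        var waited = now - seqGroup.Metrics.ArrivalTime;
        return (int)Math.Min(int.MaxValue, waited.TotalSeconds);
    }
}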
30 changes: 30 additions & 0 deletions LLama/Experimental/Abstractions/IStoppingCriteria.cs
@@ -0,0 +1,30 @@
using LLama.Experimental.Common;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Abstractions
{
/// <summary>
/// Stopping criteria that can be applied during generation.
/// </summary>
public interface IStoppingCriteria
{
/// <summary>
/// Check if the sequence should be stopped and return the status.
///
/// If it's not supposed to be stopped, be sure to return its current status.
/// </summary>
/// <param name="seq"></param>
/// <returns></returns>
StoppingCriteriaOutput CheckStop(Sequence seq); // TODO: include other params?
}

/// <summary>
/// The output of <see cref="IStoppingCriteria.CheckStop(Sequence)"/>
/// </summary>
/// <param name="Status">The sequence status.</param>
/// <param name="StoppingString">If the sequence stops because of the appearance of a string, please set it here.</param>
/// <param name="StoppingTokenId">If the sequence stops because of the appearance of a token, please set it here.</param>
public record class StoppingCriteriaOutput(SequenceStatus Status, string? StoppingString = null, int? StoppingTokenId = null);
}
34 changes: 34 additions & 0 deletions LLama/Experimental/Abstractions/ITokenizer.cs
@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Abstractions
{
/// <summary>
/// The interface for a tokenizer in LLamaSharp. It's responsible for converting text to token ids and vice versa.
/// </summary>
public interface ITokenizer
{
// TODO: `ApplyChatTemplate` API

// TODO: Batched Encode?

Member:
As far as I know there are no batched tokenization APIs in llama.cpp that we could use.

Contributor:
I'd be all in for allowing llama.cpp's built-in tokenizer to be bypassed -- no reason to lock the architecture to that dependency.


/// <summary>
/// Get the token ids from the text
/// </summary>
/// <param name="input"></param>
/// <returns></returns>
IList<int> Tokenize(string input);

/// <summary>
/// Convert the token ids to text.
/// </summary>
/// <param name="tokenIds"></param>
/// <param name="result"></param>
/// <param name="skipSpecialTokens"></param>
/// <returns>The consumed tokens for decoding.</returns>
int ConvertIdsToText(IEnumerable<int> tokenIds, out string result, bool skipSpecialTokens = false);

Member:
We can't have an API like this. It's the same problem as the old DeTokenize method, solved by the StreamingTokenDecoder.

A single character may be several tokens. So, for example, say the tokens [1, 2, 3] together produce the character A: decoding [1, 2] will produce a broken string, and then decoding [3] will produce another broken string.
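
A small self-contained illustration of the problem (the token-to-byte mapping below is made up): a multi-byte character can be split across tokens, so only a stateful decoder that buffers incomplete byte sequences, which is what StreamingTokenDecoder does, can produce valid text incrementally.

using System;
using System.Text;

class PartialDecodeDemo
{
    // Hypothetical mapping: tokens 1 and 2 together encode '€' (UTF-8 bytes E2 82 AC).
    static byte[] Bytes(int token) => token switch
    {
        1 => new byte[] { 0xE2 },
        2 => new byte[] { 0x82, 0xAC },
        _ => Array.Empty<byte>()
    };

    static void Main()
    {
        // A naive per-chunk decode of token 1 alone would yield a broken string,
        // because 0xE2 is an incomplete UTF-8 sequence.
        var decoder = Encoding.UTF8.GetDecoder(); // stateful: buffers partial sequences
        var chars = new char[8];
        int n1 = decoder.GetChars(Bytes(1), 0, 1, chars, 0, flush: false);  // n1 == 0, byte buffered
        int n2 = decoder.GetChars(Bytes(2), 0, 2, chars, n1, flush: false); // n2 == 1, '€' completed
        Console.WriteLine(new string(chars, 0, n1 + n2)); // prints "€"
    }
}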

Collaborator Author:
There is an IncrementalDecodingOffset member on Sequence. I'm trying to find a way to integrate the streaming decoder better into this design but haven't managed it yet. The biggest problem with using it directly is that it rents some memory which needs to be released. Although adding a StreamingTokenDecoder as a member of Sequence would make it auto-released, it's a bit weird for Sequence to own a StreamingTokenDecoder, because that should be part of the tokenizer. Besides, Sequence, as a mid-level class which may be used frequently by users, is not supposed to deal with logic related to LLamaContext, while an LLamaContext is required to initialize a StreamingTokenDecoder.
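
A hedged sketch of how that could eventually be wired up, assuming Sequence exposes OutputTokens, OutputText and a writable IncrementalDecodingOffset (only the offset is mentioned above; the other members are taken from how Sequence is used elsewhere in this PR):

// Hypothetical glue code: the tokenizer reports how many tokens it could safely
// decode into complete text, and the sequence advances its offset by that much.
// (Requires using System.Linq for Skip.)
static void AppendNewText(ITokenizer tokenizer, Sequence seq)
{
    var pending = seq.OutputTokens.Skip(seq.IncrementalDecodingOffset);
    int consumed = tokenizer.ConvertIdsToText(pending, out var text, skipSpecialTokens: true);
    seq.IncrementalDecodingOffset += consumed;
    seq.OutputText += text;
}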


// TODO: decode from Logprobs
}
}
24 changes: 24 additions & 0 deletions LLama/Experimental/Common/ModelRunnerInput.cs
@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Common
{
/// <summary>
/// The input prepared for model runner.
/// </summary>
/// <param name="TokenIds">The tokens to feed to the model.</param>
/// <param name="Positions">The positions of these tokens.</param>
/// <param name="SeqIds">The sequence ids of these tokens.</param>
/// <param name="WithLogits">Whether logits need to be computed for each token.</param>
/// <param name="PromptLengths">The lengths of the prompts if the input is at prefill stage, otherwise empty.</param>
/// <param name="SubqueryLengths">The lengths of the subqueries if the input is at prefill stage, otherwise empty.</param>
public record class ModelRunnerInput(
int[] TokenIds,
int[] Positions,
int[] SeqIds,
bool[] WithLogits,
int[] PromptLengths,
int[] SubqueryLengths
);
}
42 changes: 42 additions & 0 deletions LLama/Experimental/Common/RequestMetrics.cs
@@ -0,0 +1,42 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Experimental.Common
{
/// <summary>
/// Metrics associated with a request.
/// </summary>
public class RequestMetrics
{
/// <summary>
/// The time when the request arrived.
/// </summary>
public DateTime ArrivalTime { get; set; }

/// <summary>
/// The time when the request was first scheduled.
/// </summary>
public DateTime? FirstScheduledTime { get; set; }

/// <summary>
/// The time when the first token was generated.
/// </summary>
public DateTime? FirstTokenTime { get; set; }

/// <summary>
/// The time when the last token was generated.
/// </summary>
public DateTime? LastTokenTime { get; set; }

/// <summary>
/// The time the request spent in the queue.
/// </summary>
public TimeSpan? TimeInQueue { get; set; }

/// <summary>
/// The time when the request was finished.
/// </summary>
public DateTime? FinishedTime { get; set; }
}
}
101 changes: 101 additions & 0 deletions LLama/Experimental/Common/RequestOutput.cs
@@ -0,0 +1,101 @@
using LLama.Experimental.Utils;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace LLama.Experimental.Common
{
/// <summary>
/// The output data of a request to the LLM.
/// </summary>
/// <param name="RequestId">The unique ID of the request.</param>
/// <param name="Prompt">The prompt string of the request.</param>
/// <param name="PromptTokenIds">The token IDs of the prompt.</param>
/// <param name="Outputs">The output sequences of the request.</param>
/// <param name="Finished">Whether the whole request is finished.</param>
/// <param name="Metrics">Metrics associated with the request.</param>
public record class RequestOutput(
string RequestId,
string? Prompt,
IList<int> PromptTokenIds,
IList<CompletionOutput> Outputs,
bool Finished,
RequestMetrics Metrics
)
{
/// <summary>
/// Create an instance from <see cref="SequenceGroup"/>.
/// </summary>
/// <param name="seqGroup"></param>
/// <returns></returns>
/// <exception cref="NotImplementedException"></exception>
public static RequestOutput FromSeqGroup(SequenceGroup seqGroup)
{
var seqs = seqGroup.GetAllSeqs();
if(seqs.Count() != 1)
{
// TODO: deal with beam search here.
throw new NotImplementedException();
}

List<CompletionOutput> outputs = new();
int index = 0;
foreach(var seq in seqs)
{
outputs.Add(new CompletionOutput(index, seq.OutputText, seq.OutputTokens,
seq.Status.GetFinishedReason(), seq.StoppingString, seq.StoppingTokenId));
index++;
}

if (seqGroup.IsFinished)
{
seqGroup.SetFinishedTime(DateTime.Now);
}
return new RequestOutput(seqGroup.RequestId, seqGroup.Prompt, seqGroup.PromptTokenIds,
outputs, seqGroup.IsFinished, seqGroup.Metrics);
}

/// <inheritdoc/>
public override string ToString()
{
return ClassStringFormatter.Format(this);
}
}

/// <summary>
/// The output data of one completion output of a request.
/// </summary>
/// <param name="Index">The index of the output in the request.</param>
/// <param name="Text">The generated output text.</param>
/// <param name="TokenIds">The token IDs of the generated output text.</param>
/// <param name="FinishReason">The reason why the sequence is finished.</param>
/// <param name="StoppingString">
/// The stop string that caused the completion to stop,
/// or null if the completion finished for some other reason.
/// </param>
/// <param name="StoppingToken">
/// The stop token that caused the completion to stop,
/// or null if the completion finished for some other reason.
/// </param>
public record class CompletionOutput(
int Index,
string Text,
IList<int> TokenIds,
string FinishReason,
string? StoppingString,
int? StoppingToken
)
{
/// <summary>
/// Whether the completion has finished.
/// </summary>
public bool IsFinished => !string.IsNullOrEmpty(FinishReason);

/// <inheritdoc/>
public override string ToString()
{
return ClassStringFormatter.Format(this);
}
}
}