diff --git a/LLama/Experimental/Abstractions/IModelRunner.cs b/LLama/Experimental/Abstractions/IModelRunner.cs new file mode 100644 index 000000000..d89f012f0 --- /dev/null +++ b/LLama/Experimental/Abstractions/IModelRunner.cs @@ -0,0 +1,20 @@ +using LLama.Experimental.Common; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Abstractions +{ + /// + /// It defines how to execute the model. + /// + public interface IModelRunner: IDisposable + { + /// + /// Deal with the scheduled sequences to get the output. + /// + /// + /// + SamplerOutput ExecuteModel(IEnumerable seqGroupMetadataList); + } +} diff --git a/LLama/Experimental/Abstractions/ISamplingMethod.cs b/LLama/Experimental/Abstractions/ISamplingMethod.cs new file mode 100644 index 000000000..694468a5c --- /dev/null +++ b/LLama/Experimental/Abstractions/ISamplingMethod.cs @@ -0,0 +1,44 @@ +using LLama.Experimental.Common; +using LLama.Experimental.Runner.LLamaCpp; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Abstractions +{ + /// + /// Method to sample the model output. + /// + public interface ISamplingMethod + // TODO: We should reconsider this design. Maybe it's better to use `SamplingParams` to let user set, + // and choose the actual sampler internally according to the params. + { + /// + ///The maximum number of sequences running in parallel. + /// + /// If you don't know what to return, you can return the default value. + /// + /// Generally, if you want to select several results from n results, you need + /// to set the maximum number of sequences to run. + /// + /// + /// + /// + int GetMaxNumRunningSeqs(int defaultValue, int currentNumSeqs); + + /// + /// Whether to skip special tokens. + /// + bool SkipSpecialTokens { get; } + + /// + /// Sample the sequence logits to get the token. + /// + /// + /// + /// + /// + SequenceOutput SampleSequence(Span logits, int seqId, SamplingMetadata samplingMetadata); + // TODO: maybe we shouldn't expose all the samplingMetadata to users here. + } +} diff --git a/LLama/Experimental/Abstractions/ISchedulingPolicy.cs b/LLama/Experimental/Abstractions/ISchedulingPolicy.cs new file mode 100644 index 000000000..e4d4bbb5e --- /dev/null +++ b/LLama/Experimental/Abstractions/ISchedulingPolicy.cs @@ -0,0 +1,21 @@ +using LLama.Experimental.Common; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Abstractions +{ + /// + /// Define the scheduling policy, which decides the priority orders of sequences. + /// + public interface ISchedulingPolicy + { + /// + /// Get the priority of a sequence group. + /// + /// + /// + /// + int GetPriority(DateTime now, SequenceGroup seqGroup); + } +} diff --git a/LLama/Experimental/Abstractions/IStoppingCriteria.cs b/LLama/Experimental/Abstractions/IStoppingCriteria.cs new file mode 100644 index 000000000..4a86f94d6 --- /dev/null +++ b/LLama/Experimental/Abstractions/IStoppingCriteria.cs @@ -0,0 +1,30 @@ +using LLama.Experimental.Common; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Abstractions +{ + /// + /// Stopping criteria that can be applied during generation. + /// + public interface IStoppingCriteria + { + /// + /// Check if the sequence should be stopped and return the status. + /// + /// If it's not supposed to be stopped, be sure to return its current status. 
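// Illustrative sketch (not part of this diff): a minimal first-come-first-served
// implementation of the ISchedulingPolicy abstraction introduced above. The class name
// and namespace are hypothetical. It ranks sequence groups by how long they have been
// waiting (older arrival => higher priority), which matches how
// SchedulingPolicyExtensions.SortByPriority orders groups descending by priority.
using System;
using LLama.Experimental.Abstractions;
using LLama.Experimental.Common;

namespace LLama.Experimental.Core
{
    public sealed class FcfsSchedulingPolicy : ISchedulingPolicy
    {
        public int GetPriority(DateTime now, SequenceGroup seqGroup)
        {
            // Longer waiting time => larger value => scheduled earlier.
            return (int)(now - seqGroup.Metrics.ArrivalTime).TotalMilliseconds;
        }
    }
}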
+ /// + /// + /// + StoppingCriteriaOutput CheckStop(Sequence seq); // TODO: include other params? + } + + /// + /// The output of + /// + /// The sequence status. + /// If the sequence stops because of the appearance of a string, please set it here. + /// If the sequence stops because of the appearance of a token, please set it here. + public record class StoppingCriteriaOutput(SequenceStatus Status, string? StoppingString = null, int? StoppingTokenId = null); +} diff --git a/LLama/Experimental/Abstractions/ITokenizer.cs b/LLama/Experimental/Abstractions/ITokenizer.cs new file mode 100644 index 000000000..4e74973da --- /dev/null +++ b/LLama/Experimental/Abstractions/ITokenizer.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Abstractions +{ + /// + /// The interface for tokenizer in LLamaSharp. It's responsible for converting text to token ids, or vice versa. + /// + public interface ITokenizer + { + // TODO: `ApplyChatTemplate` API + + // TODO: Batched Encode? + + /// + /// Get the token ids from the text + /// + /// + /// + IList Tokenize(string input); + + /// + /// Convert the token ids to text. + /// + /// + /// + /// + /// The consumed tokens for decoding. + int ConvertIdsToText(IEnumerable tokenIds, out string result, bool skipSpecialTokens = false); + + // TODO: decode from Logprobs + } +} diff --git a/LLama/Experimental/Common/ModelRunnerInput.cs b/LLama/Experimental/Common/ModelRunnerInput.cs new file mode 100644 index 000000000..ad02145b9 --- /dev/null +++ b/LLama/Experimental/Common/ModelRunnerInput.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// The input prepared for model runner. + /// + /// The tokens to feed to the model. + /// The positions of these tokens. + /// The sequence ids of these tokens. + /// Whether the logits need to be computed for the token. + /// The lengths of the prompts if the input is at prefill stage, otherwise empty. + /// The lengths of the subqueries if the input is at prefill stage, otherwise empty. + public record class ModelRunnerInput( + int[] TokenIds, + int[] Positions, + int[] SeqIds, + bool[] WithLogits, + int[] PromptLengths, + int[] SubqueryLengths + ); +} diff --git a/LLama/Experimental/Common/RequestMetrics.cs b/LLama/Experimental/Common/RequestMetrics.cs new file mode 100644 index 000000000..405b201d3 --- /dev/null +++ b/LLama/Experimental/Common/RequestMetrics.cs @@ -0,0 +1,42 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// Metrics associated with a request. + /// + public class RequestMetrics + { + /// + /// The time when the request arrived. + /// + public DateTime ArrivalTime { get; set; } + + /// + /// The time when the request was first scheduled. + /// + public DateTime? FirstScheduledTime { get; set; } + + /// + /// The time when the first token was generated. + /// + public DateTime? FirstTokenTime { get; set; } + + /// + /// The time when the last token was generated. + /// + public DateTime? LastTokenTime { get; set; } + + /// + /// The time the request spent in the queue. + /// + public TimeSpan? TimeInQueue { get; set; } + + /// + /// The time when the request was finished. + /// + public DateTime? 
FinishedTime { get; set; } + } +} diff --git a/LLama/Experimental/Common/RequestOutput.cs b/LLama/Experimental/Common/RequestOutput.cs new file mode 100644 index 000000000..095105fcf --- /dev/null +++ b/LLama/Experimental/Common/RequestOutput.cs @@ -0,0 +1,101 @@ +using LLama.Experimental.Utils; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// The output data of a request to the LLM. + /// + /// The unique ID of the request. + /// The prompt string of the request. + /// The token IDs of the prompt. + /// The output sequences of the request. + /// Whether the whole request is finished. + /// Metrics associated with the request. + public record class RequestOutput( + string RequestId, + string? Prompt, + IList PromptTokenIds, + IList Outputs, + bool Finished, + RequestMetrics Metrics + ) + { + /// + /// Create an instance from . + /// + /// + /// + /// + public static RequestOutput FromSeqGroup(SequenceGroup seqGroup) + { + var seqs = seqGroup.GetAllSeqs(); + if(seqs.Count() != 1) + { + // TODO: deal with beam search here. + throw new NotImplementedException(); + } + + List outputs = new(); + int index = 0; + foreach(var seq in seqs) + { + outputs.Add(new CompletionOutput(index, seq.OutputText, seq.OutputTokens, + seq.Status.GetFinishedReason(), seq.StoppingString, seq.StoppingTokenId)); + index++; + } + + if (seqGroup.IsFinished) + { + seqGroup.SetFinishedTime(DateTime.Now); + } + return new RequestOutput(seqGroup.RequestId, seqGroup.Prompt, seqGroup.PromptTokenIds, + outputs, seqGroup.IsFinished, seqGroup.Metrics); + } + + /// + public override string ToString() + { + return ClassStringFormatter.Format(this); + } + } + + /// + /// The output data of one completion output of a request. + /// + /// The index of the output in the request. + /// The generated output text. + /// The token IDs of the generated output text. + /// The reason why the sequence is finished. + /// + /// The stop string that caused the completion to stop, + /// Null if the completion finished for some other reason. + /// + /// + /// The stop string that caused the completion to stop, + /// Null if the completion finished for some other reason. + /// + public record class CompletionOutput( + int Index, + string Text, + IList TokenIds, + string FinishReason, + string? StoppingString, + int? StoppingToken + ) + { + /// + /// Whether the completion has finished. + /// + public bool IsFinished => !string.IsNullOrEmpty(FinishReason); + + /// + public override string ToString() + { + return ClassStringFormatter.Format(this); + } + } +} diff --git a/LLama/Experimental/Common/SamplerOutput.cs b/LLama/Experimental/Common/SamplerOutput.cs new file mode 100644 index 000000000..83f9238d8 --- /dev/null +++ b/LLama/Experimental/Common/SamplerOutput.cs @@ -0,0 +1,88 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// For each sequence group, we generate a list of SequenceOutput object, + /// each of which contains one possible candidate for the next token. + /// + /// This datastructure implements methods so it can be used like a list. + /// + public class SamplerOutput: IList + { + /// + /// The list of objects, which are the outputs of the LLM model. 
+ /// + public List Outputs { get; } + + /// + /// + /// + /// + public SamplerOutput(List outputs) + { + Outputs = outputs; + } + + #region IList Implementation + /// + public SequenceGroupOutput this[int index] { get => Outputs[index]; set => Outputs[index] = value; } + /// + public int Count => Outputs.Count; + /// + public bool IsReadOnly => false; + /// + public void Add(SequenceGroupOutput item) + { + Outputs.Add(item); + } + /// + public void Clear() + { + Outputs.Clear(); + } + /// + public bool Contains(SequenceGroupOutput item) + { + return Outputs.Contains(item); + } + /// + public void CopyTo(SequenceGroupOutput[] array, int arrayIndex) + { + Outputs.CopyTo(array, arrayIndex); + } + /// + public IEnumerator GetEnumerator() + { + return Outputs.GetEnumerator(); + } + /// + public int IndexOf(SequenceGroupOutput item) + { + return Outputs.IndexOf(item); + } + /// + public void Insert(int index, SequenceGroupOutput item) + { + Outputs.Insert(index, item); + } + /// + public bool Remove(SequenceGroupOutput item) + { + return Outputs.Remove(item); + } + /// + public void RemoveAt(int index) + { + Outputs.RemoveAt(index); + } + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() + { + return Outputs.GetEnumerator(); + } + #endregion + } +} diff --git a/LLama/Experimental/Common/SamplingMetadata.cs b/LLama/Experimental/Common/SamplingMetadata.cs new file mode 100644 index 000000000..737e8fc83 --- /dev/null +++ b/LLama/Experimental/Common/SamplingMetadata.cs @@ -0,0 +1,28 @@ +using LLama.Experimental.Utils; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// Metadata for input sequences. Used in sampler. + /// + /// List of seq ids. + /// Seq_id -> SequenceData. + /// Lengths of prompts. + /// Token indices selected for sampling. + public record class SamplingMetadata( + IList SeqIds, + IDictionary SeqData, + IList PromptLengths, + IList SelectedTokenIndices + ) + { + /// + public override string ToString() + { + return ClassStringFormatter.Format(this); + } + } +} diff --git a/LLama/Experimental/Common/SchedulerOutputs.cs b/LLama/Experimental/Common/SchedulerOutputs.cs new file mode 100644 index 000000000..710846c8b --- /dev/null +++ b/LLama/Experimental/Common/SchedulerOutputs.cs @@ -0,0 +1,111 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// The scheduling decision made from a scheduler. + /// + /// Scheduled sequence groups. + /// Number of prefill groups scheduled. + /// Total number of batched tokens. + /// Sequence groups that are going to be ignored. + public record class SchedulerOutputs( + IEnumerable ScheduledSeqGroups, + int NumPrefillGroups, + int NumBatchedTokens, + IEnumerable IgnoredSeqGroups + ) + { + /// + /// Whether the scheduler output is empty. + /// + public bool IsEmpty => ScheduledSeqGroups.Count() == 0; + } + + /// + /// + /// + /// A sequence group that's scheduled. + /// + /// The total chunk size (number of tokens) to process for next iteration. + /// 1 for decoding. Same as prompt tokens for prefill, but if prefill is + /// chunked, it can be smaller than that. + /// + public record class ScheduledSequenceGroup(SequenceGroup SeqGroup, int TokenChunkSize); + + /// + /// The requests that are scheduled from a waiting queue. 
+ /// + /// + /// + /// + public record class SchedulerPrefillOutputs( + LinkedList RemainingWaitingQueue, + List SeqGroups, + List IgnoredSeqGroups + ) + { + /// + public static SchedulerPrefillOutputs CreateEmpty() + { + return new SchedulerPrefillOutputs( + new LinkedList(), + new List(), + new List() + ); + } + } + + /// + /// The requests that are scheduled from a running queue. + /// + /// Could contain prefill (prefill that's chunked) or decodes. If there's not + /// enough memory, it can be preempted (for recompute) or swapped out. + /// + /// + /// + /// + /// + /// + public record class SchedulerRunningOutputs( + LinkedList RemainingRunningQueue, + List DecodeSeqGroups, + List PrefillSeqGroups, + List PreemptedSeqGroups, + List SwappedOutSeqGroups + ) + { + /// + public static SchedulerRunningOutputs CreateEmpty() + { + return new SchedulerRunningOutputs( + new LinkedList(), + new List(), + new List(), + new List(), + new List() + ); + } + } + + /// + /// The requests that are scheduled from a swap queue. + /// Could contain prefill (prefill that's chunked) or decodes. + /// + /// + /// + public record class SchedulerSwappedInOutputs( + List DecodeSeqGroups, + List PrefillSeqGroups + ) + { + /// + public static SchedulerSwappedInOutputs CreateEmpty() + { + return new SchedulerSwappedInOutputs(new List(), new List()); + } + } +} diff --git a/LLama/Experimental/Common/SchedulingBudget.cs b/LLama/Experimental/Common/SchedulingBudget.cs new file mode 100644 index 000000000..45822f43f --- /dev/null +++ b/LLama/Experimental/Common/SchedulingBudget.cs @@ -0,0 +1,85 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// The available slots for scheduling. 
+ /// + internal class SchedulingBudget + { + private HashSet _requestIdsNumBatchedTokens; + + private HashSet _requestIdsNumCurrentSeqs; + + public int TokenBudget { get; set; } + + public int MaxNumSeqs { get; set; } + + public int RemainingTokenBudget => TokenBudget - NumBatchedTokens; + + internal int NumCurrentSeqs { get; set; } + + internal int NumBatchedTokens { get; set; } + + public SchedulingBudget(int tokenBudget, int maxNumSeqs) + { + TokenBudget = tokenBudget; + MaxNumSeqs = maxNumSeqs; + _requestIdsNumBatchedTokens = new HashSet(); + _requestIdsNumCurrentSeqs = new HashSet(); + NumCurrentSeqs = 0; + NumBatchedTokens = 0; + } + + public bool CanSchedule(int numNewTokens, int numNewSeqs) + { + Debug.Assert(numNewTokens >= 0); + Debug.Assert(numNewSeqs >= 0); + return NumBatchedTokens + numNewTokens <= TokenBudget + && NumCurrentSeqs + numNewSeqs <= MaxNumSeqs; + } + + public void AddNumBatchedTokens(string requestId, int numBatchedTokens) + { + if (_requestIdsNumBatchedTokens.Contains(requestId)) + { + return; + } + + _requestIdsNumBatchedTokens.Add(requestId); + NumBatchedTokens += numBatchedTokens; + } + + public void SubtractNumBatchedTokens(string requestId, int numBatchedTokens) + { + if (_requestIdsNumBatchedTokens.Contains(requestId)) + { + _requestIdsNumBatchedTokens.Remove(requestId); + NumBatchedTokens -= numBatchedTokens; + } + } + + public void AddNumSeqs(string requestId, int numCurrentSeqs) + { + if (_requestIdsNumCurrentSeqs.Contains(requestId)) + { + return; + } + + _requestIdsNumCurrentSeqs.Add(requestId); + NumCurrentSeqs += numCurrentSeqs; + } + + public void SubtractNumSeqs(string requestId, int numCurrentSeqs) + { + if (_requestIdsNumCurrentSeqs.Contains(requestId)) + { + _requestIdsNumCurrentSeqs.Remove(requestId); + NumCurrentSeqs -= numCurrentSeqs; + } + } + } +} diff --git a/LLama/Experimental/Common/Sequence.cs b/LLama/Experimental/Common/Sequence.cs new file mode 100644 index 000000000..d7e4b8e99 --- /dev/null +++ b/LLama/Experimental/Common/Sequence.cs @@ -0,0 +1,151 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// Stores the data, status, and other information of a sequence. + /// + public sealed class Sequence + { + /// + /// The ID of the sequence. + /// + public int Id { get; private set; } + + /// + /// The prompt of the sequence. + /// + public string? Prompt { get; } + + /// + /// Data used for computation of this sequence. + /// + public SequenceData Data { get; private set; } + + /// + /// The output text of the sequence. + /// Note that it should only be set when you want to implement some interfaces yourself. + /// + public string OutputText { get; internal set; } + + /// + /// Input + output token IDs. + /// Note that it should only be set when you want to implement some interfaces yourself. + /// + public IEnumerable TokenIds => Data. TokenIds; + + /// + /// Length of the sequence data. + /// + public int Length => Data.Length; + + /// + /// The status of the sequence. + /// + public SequenceStatus Status { get; internal set; } + + /// + /// The stopping string of the sequence if it stops because of this string. + /// + public string? StoppingString { get; internal set; } + + /// + /// The stopping token of the sequence if it stops because of this token. + /// + public int? StoppingTokenId { get; internal set; } + + /// + /// The offset of the sequence in the decoding process. 
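// Illustrative sketch (not part of this diff): how a scheduler loop is expected to
// consume SchedulingBudget. The request id and token counts are made-up values.
var budget = new SchedulingBudget(tokenBudget: 2048, maxNumSeqs: 8);

int numNewTokens = 512;
int numNewSeqs = 1;
if (budget.CanSchedule(numNewTokens, numNewSeqs))
{
    // Both counters are keyed by request id, so adding the same request twice is a no-op.
    budget.AddNumBatchedTokens("request-0", numNewTokens);
    budget.AddNumSeqs("request-0", numNewSeqs);
}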
+ /// It's useful when the tokenizer may use more than 1 token id to represent a token. + /// + public int IncrementalDecodingOffset { get; internal set; } + + /// + /// Whether the sequence has finished. + /// + public bool IsFinished + { + get + { + return Status is SequenceStatus.FinishStopped + or SequenceStatus.FinishLengthCapped + or SequenceStatus.FinishAborted + or SequenceStatus.FinishIgnored; + } + } + + /// + /// The output token ids of the sequence. + /// + public IList OutputTokens => Data.OutputTokenIds; + + /// + /// Whether the sequence is at prefill stage. + /// + public bool IsPrefill => Data.Stage == SequenceStage.Prefill; + + /// + /// Get the number of new tokens to be computed. + /// + public int NumNewTokens + { + get + { + if (Data.Stage == SequenceStage.Decode) + { + return 1; + } + return Data.NumUncomputedTokens; + } + } + + /// + /// + /// + /// + /// + /// + public Sequence(int id, string? prompt, IList promptTokens) + { + Id = id; + Prompt = prompt; + Data = new SequenceData(promptTokens); + OutputText = ""; + Status = SequenceStatus.Waiting; + IncrementalDecodingOffset = Data.PromptTokenIds.Count; + + // TODO: deal with incremental detokenization. + } + + /// + /// Add a token id to the output ids of the sequence data. + /// + /// + public void AppendToken(int tokenId) + // TODO: logprobs? + { + Data.AppendToken(tokenId); + } + + /// + /// Get a new sequence with same data but new id. + /// + /// + /// + public Sequence Fork(int newSeqId) + { + // clone the current data. + var clone = (Sequence)MemberwiseClone(); + clone.Data = new SequenceData( + new List(Data.PromptTokenIds), + new List(Data.OutputTokenIds) + ); + // set new id + clone.Id = newSeqId; + return clone; + } + } +} diff --git a/LLama/Experimental/Common/SequenceData.cs b/LLama/Experimental/Common/SequenceData.cs new file mode 100644 index 000000000..0cd284b36 --- /dev/null +++ b/LLama/Experimental/Common/SequenceData.cs @@ -0,0 +1,120 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// Data associated with a sequence. + /// + public class SequenceData + { + /// + /// The token IDs of the prompt. + /// + public IList PromptTokenIds { get; set; } + + /// + /// The token IDs of the output. Set to an empty list if None. + /// + public List OutputTokenIds { get; set; } + + /// + /// The stage of the sequence data. + /// + public SequenceStage Stage { get; private set; } + + /// + /// The number of all the tokens in the sequence, including prompt and output. + /// + public int Length => OutputTokenIds.Count + PromptTokenIds.Count; + + /// + /// All the token IDs, including prompt and output. + /// + public IEnumerable TokenIds => PromptTokenIds.Concat(OutputTokenIds); + + /// + /// The number of prefill tokens that are already computed. + /// + public int NumComputedTokens { get; private set; } + + /// + /// The number of prefil tokens that are not computed. + /// + public int NumUncomputedTokens => Length - NumComputedTokens; + + /// + /// The last token ID in the sequence. + /// + public int LastTokenId + { + get + { + if(OutputTokenIds.Count == 0) + { + return PromptTokenIds[PromptTokenIds.Count - 1]; + } + return OutputTokenIds[OutputTokenIds.Count - 1]; + } + } + + /// + /// + /// + /// + /// + public SequenceData(IList promptTokens, IEnumerable? outputTokens = null) + { + OutputTokenIds = outputTokens is not null ? 
new List(outputTokens) : new List(); + PromptTokenIds = promptTokens; + + // TODO: cumulative_logprob? + NumComputedTokens = 0; + Stage = SequenceStage.Prefill; + } + + /// + /// Add a token id to the output token ids. + /// + /// + public void AppendToken(int tokenId) + { + OutputTokenIds.Add(tokenId); + } + + /// + /// Update number of tokens computed so far. + /// + /// + public void UpdateNumComputedTokens(int numNewComputedTokens) + { + NumComputedTokens += numNewComputedTokens; + Debug.Assert(NumComputedTokens <= Length); + if(NumUncomputedTokens == 0) + { + Stage = SequenceStage.Decode; + } + } + + /// + /// Reset the number of computed tokens from this sequence. It is + /// supposed to be called when a sequence needs to be started from + /// the beginning again(e.g., sequence is preempted). + /// + public void ResetStageForRecompute() + { + NumComputedTokens = 0; + Stage = SequenceStage.Prefill; + } + + /// + public override string ToString() + { + return $"SequenceData(\n PromptTokens: {string.Join(", ", PromptTokenIds)}, \n " + + $"OutputTokens: {string.Join(", ", OutputTokenIds)}, Stage: {Stage}\n)"; + } + } +} diff --git a/LLama/Experimental/Common/SequenceGroup.cs b/LLama/Experimental/Common/SequenceGroup.cs new file mode 100644 index 000000000..eb5a14b3c --- /dev/null +++ b/LLama/Experimental/Common/SequenceGroup.cs @@ -0,0 +1,268 @@ +using LLama.Experimental.Abstractions; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace LLama.Experimental.Common +{ + /// + /// A group of sequences that are generated from the same prompt. + /// + public sealed class SequenceGroup + // TODO: Multi-modal data + { + /// + /// The ID of the request. + /// + public string RequestId { get; } + + /// + /// Mapping from seq ID to the sequence. + /// + public IDictionary SeqDict { get; } + + /// + /// The sampling method to do the sampling. + /// + public ISamplingMethod SamplingMethod { get; set; } + + /// + /// The stopping criteria to decide whether the generation of the sequence should be stopped. + /// + public IStoppingCriteria StoppingCriteria { get; set; } + + /// + /// The metrics for the scheduling and inference of this sequence group. + /// + public RequestMetrics Metrics { get; } + + /// + /// The common prompt of the sequences in this sequence group. + /// + public string? Prompt + { + get + { + // All sequences in the group should have the same prompt. + // We use the prompt of an arbitrary sequence. + return SeqDict.First().Value.Prompt; + } + } + + /// + /// The prompt tokens of the sequences in this sequence group. + /// + public IList PromptTokenIds + { + get + { + return SeqDict.First().Value.Data.PromptTokenIds; + } + } + + /// + /// Whether the request of this sequence group has been finished. + /// + public bool IsFinished + { + get + { + return SeqDict.Values.All(seq => seq.IsFinished); + } + } + + /// + /// Whether this sequence group is at prefill stage. + /// + public bool IsPrefill + { + get + { + return SeqDict.Values.First().IsPrefill; + } + } + + /// + /// The number of sequences in this sequence group. + /// + public int NumSeqs => SeqDict.Count; + + /// + /// The number of unfinished sequences in this sequence group. + /// + public int NumUnfinishedSeqs => GetUnfinishedSeqs().Count(); + + /// + /// Initializes a new instance of the class. 
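// Illustrative sketch (not part of this diff): the prefill -> decode transition of
// SequenceData. The token ids are arbitrary example values.
var data = new SequenceData(new List<int> { 1, 15043, 3186 }); // 3 prompt tokens, Stage == Prefill
data.UpdateNumComputedTokens(3);                               // all prompt tokens computed -> Stage == Decode
data.AppendToken(29889);                                       // generated token is appended to OutputTokenIds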
+ /// + /// + /// + /// + /// + /// + /// + public SequenceGroup(string requestId, Sequence[] sequences, ISamplingMethod samplingMethod, IStoppingCriteria stoppingCriteria, DateTime arrivalTime) + { + if(sequences.Length == 0) + { + throw new ArgumentException($"The sequences bypassed to SequenceGroup cannot be empty."); + } + + RequestId = requestId; + SeqDict = sequences.ToDictionary(Sequence => Sequence.Id, Sequence => Sequence); + SamplingMethod = samplingMethod; + StoppingCriteria = stoppingCriteria; + Metrics = new RequestMetrics() + { + ArrivalTime = arrivalTime + }; + } + + /// + /// Sets the first token time for Request level timings. + /// + /// + public void MaybeSetFirstTokenTime(DateTime time) + { + if (Metrics.FirstTokenTime is null) + { + Metrics.FirstTokenTime = time; + } + } + + /// + /// Sets the first scheduled time and time in queue for Request level timings + /// + /// + public void MaybeSetFirstScheduledTime(DateTime time) + { + if (Metrics.FirstScheduledTime is null) + { + Metrics.FirstScheduledTime = time; + Metrics.TimeInQueue = time - Metrics.ArrivalTime; + } + } + + /// + /// Sets the finished time for Request level timings. + /// + /// + public void SetFinishedTime(DateTime time) + { + Metrics.FinishedTime = time; + } + + /// + /// Get all sequences with the given status. + /// + /// + /// + public IEnumerable GetSeqsWithStatus(SequenceStatus status) + { + return SeqDict.Values.Where(seq => seq.Status == status); + } + + /// + /// Get all sequences in this sequence group. + /// + /// + public IEnumerable GetAllSeqs() + { + return SeqDict.Values; + } + + /// + /// Get all unfinished sequences in this sequence group. + /// + /// + public IEnumerable GetUnfinishedSeqs() + { + return SeqDict.Values.Where(seq => !seq.IsFinished); + } + + /// + /// Get all finished sequences in this sequence group. + /// + /// + public IEnumerable GetFinishedSeqs() + { + return SeqDict.Values.Where(seq => seq.IsFinished); + } + + /// + /// The maximum number of sequences running in parallel in the remaining + /// lifetime of the request. + /// + /// + public int GetMaxNumRunningSeqs() + { + int defaultValue = NumUnfinishedSeqs; + return SamplingMethod.GetMaxNumRunningSeqs(defaultValue, NumSeqs); + } + + /// + /// Add a new sequence. + /// + /// + /// + public void Add(Sequence seq) + { + if (SeqDict.ContainsKey(seq.Id)) + { + throw new ArgumentException($"Sequence {seq.Id} already exists."); + } + SeqDict[seq.Id] = seq; + } + + /// + /// Remove the sequence of seq id. + /// + /// + /// + public void Remove(int seqId) + { + if (!SeqDict.ContainsKey(seqId)) + { + throw new ArgumentException($"Sequence {seqId} not found."); + } + SeqDict.Remove(seqId); + } + + /// + /// Get the number of tokens to be computed in this sequence group. + /// + public int GetNumComputedTokens() + { + int numUncomputedTokens = 0; + foreach(var seq in GetAllSeqs()) + { + numUncomputedTokens += seq.Data.NumUncomputedTokens; + } + return numUncomputedTokens; + } + + /// + /// Update number of tokens computed so far. 
+ /// + /// + public void UpdateNumComputedTokens(int numNewComputedTokens) + { + foreach(var seq in SeqDict.Values) + { + if (!seq.IsFinished) + { + seq.Data.UpdateNumComputedTokens(numNewComputedTokens); + } + } + } + + /// + public override string ToString() + { + return $"SequenceGroup(RequestId = {RequestId}, \n " + + $"SamplingMethod = ({SamplingMethod.GetType().Name}), \n " + + $"StoppingCriteria = ({StoppingCriteria.GetType().Name}), \n " + + $"NumSeqs = {SeqDict.Count}\n)"; + } + } +} diff --git a/LLama/Experimental/Common/SequenceGroupMetadata.cs b/LLama/Experimental/Common/SequenceGroupMetadata.cs new file mode 100644 index 000000000..bf9ff1d0a --- /dev/null +++ b/LLama/Experimental/Common/SequenceGroupMetadata.cs @@ -0,0 +1,79 @@ +using LLama.Experimental.Abstractions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// Metadata for a sequence group. + /// + public class SequenceGroupMetadata + { + /// + /// The ID of the request. + /// + public string RequestId { get; set; } + + /// + /// Whether the request is at prompt stage. + /// + public bool IsPrompt { get; set; } + + /// + /// The sequence data. (Seq id -> sequence data) + /// + public Dictionary SeqData { get; set; } + + /// + /// The sampling method used to generate the outputs. + /// + public ISamplingMethod SamplingMethod { get; set; } + + /// + /// The stopping criteria to decide whether the generation of the sequence should be stopped. + /// + public IStoppingCriteria StoppingCriteria { get; set; } + + /// + /// The number of tokens to be processed (per sequence). + /// + public int TokenChunkSize { get; set; } + + /// + /// + /// + /// + /// + /// + /// + /// + /// + public SequenceGroupMetadata(string requestId, bool isPrompt, Dictionary seqData, + ISamplingMethod samplingMethod, IStoppingCriteria stoppingCriteria, int? tokenChunkSize) + { + RequestId = requestId; + IsPrompt = isPrompt; + SeqData = seqData; + SamplingMethod = samplingMethod; + StoppingCriteria = stoppingCriteria; + + if(tokenChunkSize is null) + { + if (isPrompt) + { + TokenChunkSize = seqData.Values.First().Length; + } + else + { + TokenChunkSize = 1; + } + } + else + { + TokenChunkSize = tokenChunkSize.Value; + } + } + } +} diff --git a/LLama/Experimental/Common/SequenceGroupOutput.cs b/LLama/Experimental/Common/SequenceGroupOutput.cs new file mode 100644 index 000000000..bdacf70d6 --- /dev/null +++ b/LLama/Experimental/Common/SequenceGroupOutput.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// The model output associated with a sequence group. + /// + /// + public record class SequenceGroupOutput(List Samples) + { + + } +} diff --git a/LLama/Experimental/Common/SequenceOutput.cs b/LLama/Experimental/Common/SequenceOutput.cs new file mode 100644 index 000000000..76bab7791 --- /dev/null +++ b/LLama/Experimental/Common/SequenceOutput.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// The model output associated with a sequence. + /// + public class SequenceOutput + // TODO: Beam search + { + /// + /// The output token ID. + /// + public int OutputTokenId { get; init; } + + /// + /// The ID of the parent sequence (for forking in beam search). + /// + public int ParentSeqId { get; init; } + + /// + /// The logprobs of the output token. 
+ /// (Token id -> logP(x_i+1 | x_0, ..., x_i)) + /// + public float[]? Logprobs { get; init; } + } +} diff --git a/LLama/Experimental/Common/SequenceStage.cs b/LLama/Experimental/Common/SequenceStage.cs new file mode 100644 index 000000000..2fd76d35e --- /dev/null +++ b/LLama/Experimental/Common/SequenceStage.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// The sequence stage for . + /// + public enum SequenceStage + { + /// + /// The prefill stage, in which the model is processing your prompt. + /// + Prefill, + + /// + /// The decode stage, in which the model is generating the output. + /// + Decode + } +} diff --git a/LLama/Experimental/Common/SequenceStatus.cs b/LLama/Experimental/Common/SequenceStatus.cs new file mode 100644 index 000000000..7dadf5cba --- /dev/null +++ b/LLama/Experimental/Common/SequenceStatus.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Common +{ + /// + /// Status of a sequence. + /// + public enum SequenceStatus + { + /// + /// The sequence is waiting for scheduling. + /// + Waiting, + + /// + /// The sequence is running. + /// + Running, + + /// + /// The sequence has been swapped out due to some reasons. + /// + Swapped, + + /// + /// The sequence has been finished because it's stopped by a stopping criteria. + /// + FinishStopped, + + /// + /// The sequence has been finished because it reaches the maximum length. + /// + FinishLengthCapped, + + /// + /// The sequence has been finished because it's aborted. + /// + FinishAborted, + + /// + /// The sequence will never be processed for some reasons. Please check if the prompt length is too long. + /// + FinishIgnored + } + + /// + public static class SequenceStatusExtensions + { + /// + /// Get the finished reason in OpenAI style + /// + /// + /// + public static string GetFinishedReason(this SequenceStatus status) + { + return status switch + { + SequenceStatus.FinishStopped => "stop", + SequenceStatus.FinishLengthCapped => "length", + SequenceStatus.FinishAborted => "abort", + SequenceStatus.FinishIgnored => "length", + _ => "" + }; + } + } +} diff --git a/LLama/Experimental/Config/DeviceConfig.cs b/LLama/Experimental/Config/DeviceConfig.cs new file mode 100644 index 000000000..b2b7d589b --- /dev/null +++ b/LLama/Experimental/Config/DeviceConfig.cs @@ -0,0 +1,44 @@ +using LLama.Native; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Config +{ + /// + /// Device configuration for using LLM. + /// + public class DeviceConfig + { + /// + /// main_gpu interpretation depends on split_mode: + /// + /// + /// None + /// The GPU that is used for the entire mode. + /// + /// + /// Row + /// The GPU that is used for small tensors and intermediate results. + /// + /// + /// Layer + /// Ignored. 
+ /// + /// + /// + public int MainGpu { get; set; } = 0; + + /// + /// How to split the model across multiple GPUs + /// + public GPUSplitMode SplitMode { get; set; } = GPUSplitMode.None; + + /// + /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) + /// + public int GpuLayerCount { get; set; } = 20; + + // TODO: Add a static method/property "Auto" to return a default DeviceConfig + } +} diff --git a/LLama/Experimental/Config/KvCacheConfig.cs b/LLama/Experimental/Config/KvCacheConfig.cs new file mode 100644 index 000000000..59f271a7b --- /dev/null +++ b/LLama/Experimental/Config/KvCacheConfig.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Config +{ + /// + /// Configuration for the KV cache. + /// + public class KvCacheConfig + { + /// + /// The maximum CPU memory space used for saving kv cache swapped from GPU. + /// + public int MaxSwapSpace { get; set; } + } +} diff --git a/LLama/Experimental/Config/SchedulerConfig.cs b/LLama/Experimental/Config/SchedulerConfig.cs new file mode 100644 index 000000000..2d4c0eecb --- /dev/null +++ b/LLama/Experimental/Config/SchedulerConfig.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Config +{ + /// + /// Scheduler configuration. + /// + public class SchedulerConfig + { + /// + /// Maximum number of tokens to be processed in a single iteration. + /// + public int MaxNumBatchedTokens { get; set; } + + /// + /// Maximum number of sequences to be processed in a single iteration. + /// + public int MaxNumSequences { get; set; } + + /// + /// Maximum length of a sequence (including prompt and generated text). + /// + public int MaxSequenceLength { get; set; } + + /// + /// If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. + /// + public bool EnableChunkedPrefill { get; set; } + + /// + /// Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt. + /// + public float DelayFactor { get; set; } + + public SchedulerConfig(int maxNumBatchedTokens, int maxNumSequences, int maxSequenceLength, bool enableChunkedPrefill = false, float delayFactor = .0f) + { + MaxNumBatchedTokens = maxNumBatchedTokens; + MaxNumSequences = maxNumSequences; + MaxSequenceLength = maxSequenceLength; + EnableChunkedPrefill = enableChunkedPrefill; + DelayFactor = delayFactor; + } + + + + /// + /// Verify if this configuration is valid and throw an exception if it's invalid. + /// + /// + public void ThrowIfInvalid() + { + if (MaxNumBatchedTokens < MaxSequenceLength && !EnableChunkedPrefill) + { + throw new ArgumentException($"MaxNumBatchedTokens ({MaxNumBatchedTokens}) is smaller than " + + $"MaxSequenceLength ({MaxSequenceLength}). This effectively limits the maximum sequence length to " + + $"MaxNumBatchedTokens. 
Please increase MaxNumBatchedTokens, decrease MaxSequenceLength or enable chunked prefill."); + } + + if (MaxNumBatchedTokens < MaxNumSequences) + { + throw new ArgumentException($"MaxNumBatchedTokens ({MaxNumBatchedTokens}) must be greater than or equal to " + + $"MaxNumSequences ({MaxNumSequences})."); + } + } + } +} diff --git a/LLama/Experimental/Core/AntipromptStoppingCriteria.cs b/LLama/Experimental/Core/AntipromptStoppingCriteria.cs new file mode 100644 index 000000000..b4edcd983 --- /dev/null +++ b/LLama/Experimental/Core/AntipromptStoppingCriteria.cs @@ -0,0 +1,31 @@ +using LLama.Experimental.Abstractions; +using LLama.Experimental.Common; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Core +{ + // TODO: This is only the most simple implementation to run the test now. We should replace it in the future. + public class AntipromptStoppingCriteria: IStoppingCriteria + { + private string[] _antiprompts; + + public AntipromptStoppingCriteria(string[] antiprompts) + { + _antiprompts = antiprompts; + } + + public StoppingCriteriaOutput CheckStop(Sequence seq) + { + foreach (var antiprompt in _antiprompts) + { + if (seq.OutputText.EndsWith(antiprompt)) + { + return new StoppingCriteriaOutput(SequenceStatus.FinishStopped, antiprompt, null); + } + } + return new StoppingCriteriaOutput(seq.Status, null, null); + } + } +} diff --git a/LLama/Experimental/Core/LLamaCpp/LLamaGreedySamplingMethod.cs b/LLama/Experimental/Core/LLamaCpp/LLamaGreedySamplingMethod.cs new file mode 100644 index 000000000..ef3ffa525 --- /dev/null +++ b/LLama/Experimental/Core/LLamaCpp/LLamaGreedySamplingMethod.cs @@ -0,0 +1,41 @@ +using LLama.Experimental.Abstractions; +using LLama.Experimental.Common; +using LLama.Native; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Core.LLamaCpp +{ + // TODO: This is only the most simple implementation to run the example. It should be replaced in the future. + public class LLamaGreedySamplingMethod: ISamplingMethod + { + private LLamaContext _context; + + public LLamaGreedySamplingMethod(LLamaContext context) + { + _context = context; + } + + public int GetMaxNumRunningSeqs(int defaultValue, int currentNumSeqs) + { + return defaultValue; + } + + /// + /// Whether to skip special tokens. + /// + public bool SkipSpecialTokens => false; + + public SequenceOutput SampleSequence(Span logits, int seqId, SamplingMetadata samplingMetadata) + { + // Process token data array to select a final token + var candidates = LLamaTokenDataArray.Create(logits); + return new SequenceOutput() + { + OutputTokenId = (int)candidates.SampleTokenGreedy(_context.NativeHandle), + ParentSeqId = seqId + }; + } + } +} diff --git a/LLama/Experimental/Core/LLamaCpp/LLamaTokenizer.cs b/LLama/Experimental/Core/LLamaCpp/LLamaTokenizer.cs new file mode 100644 index 000000000..a798af822 --- /dev/null +++ b/LLama/Experimental/Core/LLamaCpp/LLamaTokenizer.cs @@ -0,0 +1,37 @@ +using LLama.Experimental.Abstractions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace LLama.Experimental.Core.LLamaCpp +{ + /// + /// llama.cpp tokenizer. + /// + public sealed class LLamaTokenizer: ITokenizer + { + private LLamaContext _context; + + public LLamaTokenizer(LLamaContext context) + { + _context = context; + } + + /// + public IList Tokenize(string input) + { + // TODO: refactor this!! 
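// Illustrative sketch (not part of this diff): the validity rules enforced by
// SchedulerConfig.ThrowIfInvalid. The numbers are arbitrary.
var ok = new SchedulerConfig(maxNumBatchedTokens: 4096, maxNumSequences: 16, maxSequenceLength: 4096);
ok.ThrowIfInvalid(); // fine: the batched-token budget covers a full-length sequence

var tooSmall = new SchedulerConfig(maxNumBatchedTokens: 512, maxNumSequences: 16, maxSequenceLength: 4096);
// Would throw ArgumentException unless EnableChunkedPrefill is set, because a
// 4096-token prompt could never fit into a 512-token batch:
// tooSmall.ThrowIfInvalid();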
+ return _context.Tokenize(input).Select(x => ((int)x)).ToArray(); + } + + /// + public int ConvertIdsToText(IEnumerable tokenIds, out string result, bool skipSpecialTokens = false) + { + // TODO: integrate `StreamingDecoder` here. Currently only English has been supported. + // We should add a byte array to `sequence`. + result = _context.DeTokenize(tokenIds.Select(x => (Native.LLamaToken)x).ToArray()); + return tokenIds.Count(); + } + } +} diff --git a/LLama/Experimental/Core/SequenceLengthStoopingCriteria.cs b/LLama/Experimental/Core/SequenceLengthStoopingCriteria.cs new file mode 100644 index 000000000..3901d9df6 --- /dev/null +++ b/LLama/Experimental/Core/SequenceLengthStoopingCriteria.cs @@ -0,0 +1,28 @@ +using LLama.Experimental.Abstractions; +using LLama.Experimental.Common; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Core +{ + // TODO: This is only the most simple implementation to run the test now. We should replace it in the future. + public class SequenceLengthStoopingCriteria: IStoppingCriteria + { + private int _maxSequenceLength; + + public SequenceLengthStoopingCriteria(int maxSequenceLength) + { + _maxSequenceLength = maxSequenceLength; + } + + public StoppingCriteriaOutput CheckStop(Sequence seq) + { + if(seq.Length >= _maxSequenceLength) + { + return new StoppingCriteriaOutput(SequenceStatus.FinishLengthCapped, null, null); + } + return new StoppingCriteriaOutput(seq.Status, null, null); + } + } +} diff --git a/LLama/Experimental/DeTokenizer.cs b/LLama/Experimental/DeTokenizer.cs new file mode 100644 index 000000000..6d8b8fa0e --- /dev/null +++ b/LLama/Experimental/DeTokenizer.cs @@ -0,0 +1,46 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.Design; +using System.Linq; +using System.Text; +using LLama.Experimental.Abstractions; +using LLama.Experimental.Common; +using LLama.Extensions; + +namespace LLama.Experimental +{ + /// + /// Defines the process of converting sequence output to text. + /// + /// We should not expose this class to users. Implementing + /// should be the only thing the user need to concern to customize the tokenizing and detokenizing. + /// + internal static class DeTokenizer + { + private static int INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5; + + /// + /// Decodes the new token for a sequence. In-place operation. + /// + /// + /// + /// + public static void DecodeSequenceInplace(Sequence seq, ITokenizer tokenizer, ISamplingMethod samplingMethod) + { + var allInputIds = seq.TokenIds; + var (offset, text) = DetokenizeIncrementally(tokenizer, allInputIds, seq.IncrementalDecodingOffset, skipSpecialTokens: true); + + // TODO: deal with logprobs. + + seq.IncrementalDecodingOffset = offset; + seq.OutputText += text; + } + + private static (int, string) DetokenizeIncrementally(ITokenizer tokenizer, IEnumerable allInputIds, int offset, bool skipSpecialTokens = false) + { + var consumedTokens = tokenizer.ConvertIdsToText(allInputIds.Skip(offset), out var text, skipSpecialTokens); + offset += consumedTokens; + return (offset, text); + } + } +} diff --git a/LLama/Experimental/Extensions/DequeExtensions.cs b/LLama/Experimental/Extensions/DequeExtensions.cs new file mode 100644 index 000000000..8a923bd1c --- /dev/null +++ b/LLama/Experimental/Extensions/DequeExtensions.cs @@ -0,0 +1,90 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Extensions +{ + /// + /// Extension to use as a deque. 
+ /// + public static class DequeExtensions + { + /// + public static void AddFront(this LinkedList deque, T item) + { + deque.AddFirst(item); + } + + /// + public static void AddBack(this LinkedList deque, T item) + { + deque.AddLast(item); + } + + /// + public static T RemoveFront(this LinkedList deque) + { + if (deque.Count == 0) + { + throw new InvalidOperationException("The deque is empty."); + } + + T item = deque.First!.Value; + deque.RemoveFirst(); + return item; + } + + /// + public static T RemoveBack(this LinkedList deque) + { + if (deque.Count == 0) + { + throw new InvalidOperationException("The deque is empty."); + } + + T item = deque.Last!.Value; + deque.RemoveLast(); + return item; + } + + /// + public static T PeekFront(this LinkedList deque) + { + if (deque.Count == 0) + { + throw new InvalidOperationException("The deque is empty."); + } + + return deque.First!.Value; + } + + /// + public static T PeekBack(this LinkedList deque) + { + if (deque.Count == 0) + { + throw new InvalidOperationException("The deque is empty."); + } + + return deque.Last!.Value; + } + + /// + public static void ExtendFront(this LinkedList deque, IEnumerable items) + { + foreach (var item in items) + { + deque.AddFront(item); + } + } + + /// + public static void ExtendBack(this LinkedList deque, IEnumerable items) + { + foreach (var item in items) + { + deque.AddBack(item); + } + } + } +} diff --git a/LLama/Experimental/Extensions/SchedulingPolicyExtensions.cs b/LLama/Experimental/Extensions/SchedulingPolicyExtensions.cs new file mode 100644 index 000000000..c168df392 --- /dev/null +++ b/LLama/Experimental/Extensions/SchedulingPolicyExtensions.cs @@ -0,0 +1,25 @@ +using LLama.Experimental.Abstractions; +using LLama.Experimental.Common; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace LLama.Experimental.Extensions +{ + /// + public static class SchedulingPolicyExtensions + { + /// + /// Sorts the sequence groups by priority. + /// + /// + /// + /// + /// + public static LinkedList SortByPriority(this ISchedulingPolicy policy, DateTime now, LinkedList seqGroups) + { + return new LinkedList(seqGroups.OrderByDescending(seqGroups => policy.GetPriority(now, seqGroups))); + } + } +} diff --git a/LLama/Experimental/GlobalConfig.cs b/LLama/Experimental/GlobalConfig.cs new file mode 100644 index 000000000..3286047c4 --- /dev/null +++ b/LLama/Experimental/GlobalConfig.cs @@ -0,0 +1,29 @@ +using LLama.Experimental.Abstractions; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental +{ + /// + /// Global configuration for LLamaSharp. 
+ /// + public class GlobalConfig + { + public static ISamplingMethod DefaultSamplingMethod + { + get + { + throw new NotImplementedException(); + } + } + + public static IStoppingCriteria DefaultStoppingCriteria + { + get + { + throw new NotImplementedException(); + } + } + } +} diff --git a/LLama/Experimental/KvCacheManager.cs b/LLama/Experimental/KvCacheManager.cs new file mode 100644 index 000000000..b476678c0 --- /dev/null +++ b/LLama/Experimental/KvCacheManager.cs @@ -0,0 +1,31 @@ +using LLama.Experimental.Common; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental +{ + public class KvCacheManager + { + public bool CanAppendSlots(SequenceGroup seqGroup) + { + return true; + } + + public AllocStatus CanAllocate(SequenceGroup seqGroup) + { + return AllocStatus.OK; + } + + public void Allocate(SequenceGroup seqGroup) + { + } + } + + public enum AllocStatus + { + OK, + Later, + Never + } +} diff --git a/LLama/Experimental/LLM.cs b/LLama/Experimental/LLM.cs new file mode 100644 index 000000000..4cbcf1960 --- /dev/null +++ b/LLama/Experimental/LLM.cs @@ -0,0 +1,180 @@ +using LLama.Experimental.Abstractions; +using LLama.Experimental.Common; +using LLama.Experimental.Config; +using LLama.Experimental.Utils; +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Generic; +using System.Data; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace LLama.Experimental +{ +//#if NET8_0_OR_GREATER +// [Experimental("LLM")] +//#endif + public sealed class LLM + { + private LLMEngine _engine; + + private IdCounter _counter; + + /// + /// Get or set the tokenizer used for this . + /// + public ITokenizer Tokenizer + { + get => _engine.Tokenizer; + set => _engine.Tokenizer = value; + } + + public LLM(IModelRunner modelRunner, ITokenizer tokenizer, SchedulerConfig schedulerConfig, ILogger? logger = null) + { + _engine = new LLMEngine(schedulerConfig, modelRunner, tokenizer, logger); + _counter = new(); + } + + /// + /// Generates the completions for the input prompt. If you have multiple inputs, + /// please use , + /// instead of calling this method multiple times. + /// + /// A prompt string. + /// The sampling parameters for text generation. If null, we use the default sampling parameters. + /// + /// The criteria to control whether a sequence generation should be stopped. If null, we use the default stopping criteria. + /// + /// The callback to get the progress of the generation. + /// A list of objects containing the generated completions in the same order as the input prompts. + public RequestOutput[] Generate(string prompt, ISamplingMethod? samplingMethod = null, + IStoppingCriteria? stoppingCriteria = null, ProgressCallback? progressCallback = null) + { + return Generate([prompt], samplingMethod, stoppingCriteria, progressCallback); + } + + /// + /// Generates the completions for the input prompt. If you have multiple inputs, + /// please use + /// instead of calling this method multiple times. + /// + /// Token ids. + /// The sampling parameters for text generation. If null, we use the default sampling parameters. + /// + /// The criteria to control whether a sequence generation should be stopped. If null, we use the default stopping criteria. + /// + /// The callback to get the progress of the generation. + /// A list of objects containing the generated completions in the same order as the input prompt ids. 
+ public RequestOutput[] Generate(IList promptTokenIds, ISamplingMethod? samplingMethod = null, + IStoppingCriteria? stoppingCriteria = null, ProgressCallback? progressCallback = null) + { + return Generate([promptTokenIds], samplingMethod, stoppingCriteria, progressCallback); + } + + /// + /// Generates the completions for the input prompts. + /// This class automatically batches the given prompts, considering + /// the memory constraint. For the best performance, please put all of your prompts + /// into a single list and pass it to this method. + /// + /// A list of prompts to generate completions for. + /// The sampling parameters for text generation. If null, we use the default sampling parameters. + /// + /// The criteria to control whether a sequence generation should be stopped. If null, we use the default stopping criteria. + /// + /// The callback to get the progress of the generation. + /// A list of objects containing the generated completions in the same order as the input prompts. + public RequestOutput[] Generate(IEnumerable prompts, ISamplingMethod? samplingMethod = null, + IStoppingCriteria? stoppingCriteria = null, ProgressCallback? progressCallback = null) + { + if(prompts.Count() == 0) + { + return []; + } + samplingMethod ??= GlobalConfig.DefaultSamplingMethod; + stoppingCriteria ??= GlobalConfig.DefaultStoppingCriteria; + + // Add requests to the engine. + foreach(var prompt in prompts) + { + AddRequest(prompt, samplingMethod, stoppingCriteria); + } + return RunEngine(progressCallback); + } + + /// + /// Generates the completions for the input prompt ids list. + /// This class automatically batches the given prompts, considering + /// the memory constraint. For the best performance, please put all of your prompts + /// into a single list and pass it to this method. + /// + /// A list of token ids to generate completion for. + /// The sampling parameters for text generation. If null, we use the default sampling parameters. + /// + /// The criteria to control whether a sequence generation should be stopped. If null, we use the default stopping criteria. + /// + /// The callback to get the progress of the generation. + /// A list of objects containing the generated completions in the same order as the input prompt ids. + public RequestOutput[] Generate(IList> promptTokenIds, ISamplingMethod? samplingMethod = null, + IStoppingCriteria? stoppingCriteria = null, ProgressCallback? progressCallback = null) + { + if(promptTokenIds.Count == 0) + { + return []; + } + samplingMethod ??= GlobalConfig.DefaultSamplingMethod; + stoppingCriteria ??= GlobalConfig.DefaultStoppingCriteria; + + // Add requests to the engine. + foreach(var prompt in promptTokenIds) + { + AddRequest(null, samplingMethod, stoppingCriteria); + } + return RunEngine(progressCallback); + } + + private void AddRequest(string? prompt, ISamplingMethod samplingMethod, IStoppingCriteria stoppingCriteria, IList? promptTokenIds = null) + { + var requestId = _counter.Next().ToString(); + _engine.AddRequest(requestId, prompt, samplingMethod, stoppingCriteria, promptTokenIds, DateTime.Now); + } + + private RequestOutput[] RunEngine(ProgressCallback? 
callback) + { + float numRequests = _engine.NumUnfinishedRequests; + Debug.Assert(numRequests - 0 > 0.0001f); // assert the number of requests is not zero + int completedRequests = 0; + List outputs = new(); + while (_engine.HasUnfinishedRequests) + { + var stepOutputs = _engine.Step(); + foreach(var output in stepOutputs) + { + if (output.Finished) + { + outputs.Add(output); + if(callback is not null) + { + completedRequests++; + callback(completedRequests / numRequests); + } + } + } + } + // Sort the outputs by request ID. + // This is necessary because some requests may be finished earlier than its previous requests. + return outputs.OrderBy(o => o.RequestId).ToArray(); + } + + + /// + /// A callback function to used for reporting the progress of the generation. + /// It will be called every time a new request is completed. + /// + /// The progress in percentage. + public delegate void ProgressCallback(float progress); + } + +} + diff --git a/LLama/Experimental/LLMEngine.cs b/LLama/Experimental/LLMEngine.cs new file mode 100644 index 000000000..ce568d80d --- /dev/null +++ b/LLama/Experimental/LLMEngine.cs @@ -0,0 +1,229 @@ +using System; +using System.Collections.Generic; +using LLama.Experimental.Abstractions; +using Microsoft.Extensions.Logging; +using LLama.Experimental.Common; +using LLama.Experimental.Utils; +using LLama.Extensions; +using System.Linq; +using System.Diagnostics; +using LLama.Experimental.Config; + +namespace LLama.Experimental +{ + /// + /// An LLM engine that receives requests and generates texts. + /// + /// It receives requests + /// from clients and generates texts from the LLM.It includes a tokenizer, a + /// language model, and GPU memory space allocated for intermediate states(aka KV cache). + /// This class utilizes iteration-level scheduling and efficient memory management + /// to maximize the serving throughput. + /// + internal sealed class LLMEngine + { + private ILogger? _logger; + + private IdCounter _seqCounter; + + public Scheduler Scheduler { get; } + + public IModelRunner ModelRunner { get; } + + public ITokenizer Tokenizer { get; set; } + + /// + /// Gets the number of unfinished requests. + /// + public int NumUnfinishedRequests => Scheduler.GetNumUnfinishedSeqGroups(); + + /// + /// Returns True if there are unfinished requests. + /// + public bool HasUnfinishedRequests => Scheduler.HasUnfinishedSeqs(); + + public LLMEngine(SchedulerConfig schedulerConfig, IModelRunner modelRunner, ITokenizer tokenizer, ILogger? logger = null) + { + _seqCounter = new(); + Scheduler = new Scheduler(schedulerConfig, new KvCacheConfig(), logger); + Tokenizer = tokenizer; + _logger = logger; + ModelRunner = modelRunner; + } + + /// + /// Performs one decoding iteration and returns newly generated results. + /// + /// Details: + /// - Step 1: Schedules the sequences to be executed in the next + /// iteration and the token blocks to be swapped in/out/copy. + /// + /// - Depending on the scheduling policy, + /// sequences may be `preempted/reordered`. + /// - A Sequence Group(SG) refer to a group of sequences + /// that are generated from the same prompt. + /// + /// - Step 2: Calls the distributed executor to execute the model. + /// - Step 3: Processes the model output. This mainly includes: + /// + /// - Decodes the relevant outputs. + /// - Updates the scheduled sequence groups with model outputs + /// based on its `sampling parameters` (`use_beam_search` or not). + /// - Frees the finished sequence groups. 
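// Illustrative sketch (not part of this diff): intended end-to-end usage of the
// experimental LLM facade. `modelRunner`, `tokenizer`, `samplingMethod` and
// `stoppingCriteria` stand for whatever IModelRunner / ITokenizer / ISamplingMethod /
// IStoppingCriteria implementations are available (e.g. the llama.cpp-based ones added
// elsewhere in this PR); the configuration numbers are arbitrary.
var schedulerConfig = new SchedulerConfig(maxNumBatchedTokens: 2048, maxNumSequences: 8, maxSequenceLength: 2048);
var llm = new LLM(modelRunner, tokenizer, schedulerConfig);

RequestOutput[] results = llm.Generate(
    new[] { "What is the capital of France?", "Write a haiku about the sea." },
    samplingMethod,
    stoppingCriteria,
    progress => Console.WriteLine($"progress: {progress:P0}"));

foreach (var result in results)
{
    // Each RequestOutput carries one or more CompletionOutputs (one, until beam search lands).
    Console.WriteLine(result.Outputs[0].Text);
}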
+ /// + /// - Finally, it creates and returns the newly generated results. + /// + /// + public List Step() + { + var (seqGroupMetadataList, schedulerOutputs) = Scheduler.Schedule(); + var output = !schedulerOutputs.IsEmpty ? ModelRunner.ExecuteModel(seqGroupMetadataList) : new SamplerOutput([]); + return ProcessModelOutputs(output, schedulerOutputs); + } + + /// + /// Add a request to the engine's request pool. + /// + /// The request is added to the request pool and will be processed by the + /// scheduler as `engine.step()` is called.The exact scheduling policy is + /// determined by the scheduler. + /// + /// The unique ID of the request. + /// The prompt string. Can be Null or empty if prompt_token_ids is provided. + /// The sampling parameters for text generation. + /// The stopping criteria to decide whether the generation should be stopped. + /// The token IDs of the prompt. If Null, we use the tokenizer to convert the prompts to token IDs. + /// The arrival time of the request. If Null, we use the current monotonic time. + public void AddRequest(string requestId, string? prompt, ISamplingMethod samplingMethod, IStoppingCriteria stoppingCriteria, IList? promptTokenIds = null, DateTime? arrivalTime = null) + { + arrivalTime ??= DateTime.Now; + if(promptTokenIds is null) + { + Debug.Assert(prompt is not null); + promptTokenIds = Tokenizer.Tokenize(prompt!); + } + else if (!string.IsNullOrEmpty(prompt)) + { + _logger?.LogWarning("Both prompt and prompt_token_ids are provided. The prompt will be ignored."); + } + + var seqId = _seqCounter.Next(); + var seq = new Sequence(seqId, prompt, promptTokenIds); + var seqGroup = new SequenceGroup(requestId, [seq], samplingMethod, stoppingCriteria, arrivalTime.Value); + + // Add the sequence group to the scheduler. + Scheduler.AddSeqGroup(seqGroup); + } + + private List ProcessModelOutputs(SamplerOutput outputs, SchedulerOutputs schedulerOutputs) + { + var now = DateTime.Now; + // Update the scheduled sequence groups with the model outputs. + var scheduledSeqGroups = schedulerOutputs.ScheduledSeqGroups; + Debug.Assert(scheduledSeqGroups.Count() == outputs.Count); + int i = 0; + foreach(var scheduledSeqGroup in scheduledSeqGroups) + { + var output = outputs[i]; + var seqGroup = scheduledSeqGroup.SeqGroup; + seqGroup.UpdateNumComputedTokens(scheduledSeqGroup.TokenChunkSize); + ProcessSequenceGroupOutputs(seqGroup, output); + i++; + } + + // Free the finished sequence groups. + Scheduler.FreeFinishedSeqGroups(); + + // Create the outputs. + List requestOutputs = new(); + foreach(var scheduledSeqGroup in scheduledSeqGroups) + { + var seqGroup = scheduledSeqGroup.SeqGroup; + seqGroup.MaybeSetFirstTokenTime(now); + requestOutputs.Add(RequestOutput.FromSeqGroup(seqGroup)); + } + foreach(var seqGroup in schedulerOutputs.IgnoredSeqGroups) + { + requestOutputs.Add(RequestOutput.FromSeqGroup(seqGroup)); + } + + // TODO: log stats here. 
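A minimal sketch of how AddRequest and Step, shown above, are meant to be driven. LLMEngine is internal, so code like this lives inside the library (the Generate/RunEngine methods earlier in this diff are the real callers); the class name, the request id, and the namespaces of RequestOutput are illustrative assumptions, and the sampling/stopping objects are whatever concrete implementations the caller has configured.

using System.Collections.Generic;
using LLama.Experimental;
using LLama.Experimental.Abstractions;
using LLama.Experimental.Common;

internal static class EngineLoopSketch
{
    internal static List<RequestOutput> RunToCompletion(LLMEngine engine, string prompt,
        ISamplingMethod sampling, IStoppingCriteria stopping)
    {
        // Queue a single request; the scheduler decides when it actually runs.
        engine.AddRequest("0", prompt, sampling, stopping);

        // Each Step() call is one schedule + decode + sample iteration.
        var finished = new List<RequestOutput>();
        while (engine.HasUnfinishedRequests)
        {
            foreach (var output in engine.Step())
            {
                if (output.Finished)
                    finished.Add(output);
            }
        }
        return finished;
    }
}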
+ return requestOutputs; + } + + private void ProcessSequenceGroupOutputs(SequenceGroup seqGroup, SequenceGroupOutput outputs) + { + // TODO: support using logprobs + var samples = outputs.Samples; + var parentSeqs = seqGroup.GetSeqsWithStatus(SequenceStatus.Running); + var existingFinishedSeqs = seqGroup.GetFinishedSeqs(); + var parentChildDict = parentSeqs.ToDictionary(x => x.Id, _ => new List()); + foreach(var sample in samples) + { + parentChildDict[sample.ParentSeqId].Add(sample); + } + // List of (child, parent) + List<(Sequence, Sequence)> childSeqs = new(); + + foreach(var parent in parentSeqs) + { + var childSamples = parentChildDict[parent.Id]; + if(childSamples.Count == 0) + { + // This parent sequence has no children samples. Remove the parent sequence + // from the sequence group since it will not be used in the future iterations. + parent.Status = SequenceStatus.FinishAborted; + seqGroup.Remove(parent.Id); + Scheduler.FreeSeq(parent); + continue; + } + foreach(var childSample in childSamples.SkipLast(1)) + { + var newChildSeqId = _seqCounter.Next(); + var child = parent.Fork(newChildSeqId); + child.AppendToken(childSample.OutputTokenId); + childSeqs.Add((child, parent)); + } + // Continue the parent sequence for the last child sample. + // We reuse the parent sequence here to reduce redundant memory + // copies, especially when using non-beam search sampling methods. + var lastChildSample = childSamples.Last(); + parent.AppendToken(lastChildSample.OutputTokenId); + childSeqs.Add((parent, parent)); + } + + foreach(var (seq, _) in childSeqs) + { + DeTokenizer.DecodeSequenceInplace(seq, Tokenizer, seqGroup.SamplingMethod); + var stoppingCriteriaOutput = seqGroup.StoppingCriteria.CheckStop(seq); + seq.Status = stoppingCriteriaOutput.Status; + seq.StoppingTokenId = stoppingCriteriaOutput.StoppingTokenId; + seq.StoppingString = stoppingCriteriaOutput.StoppingString; + } + + // Only implement non beam-search case here now. + // TODO: deal with beam search. + { + // For newly created child sequences, add them to the sequence group. + foreach(var (seq, parent) in childSeqs) + { + if(seq != parent) // if the reference are not the same + { + seqGroup.Add(seq); + } + // TODO: see if we need to do the fork in the scheduler. + } + + // NOTE: be careful of this logic. + foreach(var (seq, parent) in childSeqs) + { + if(seq == parent && seq.IsFinished) + { + Scheduler.FreeSeq(seq); + } + } + return; + } + } + } +} diff --git a/LLama/Experimental/LLamaModelRunner.cs b/LLama/Experimental/LLamaModelRunner.cs new file mode 100644 index 000000000..f29ac5559 --- /dev/null +++ b/LLama/Experimental/LLamaModelRunner.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental +{ + internal class LLamaModelRunner + { + } +} diff --git a/LLama/Experimental/Runner/LLamaCpp/LLamaCppRunnerInput.cs b/LLama/Experimental/Runner/LLamaCpp/LLamaCppRunnerInput.cs new file mode 100644 index 000000000..32595ba3b --- /dev/null +++ b/LLama/Experimental/Runner/LLamaCpp/LLamaCppRunnerInput.cs @@ -0,0 +1,92 @@ +using LLama.Experimental.Common; +using LLama.Native; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; + +namespace LLama.Experimental.Runner.LLamaCpp +{ + /// + /// Input special for . + /// + public class LLamaCppRunnerInput + // TODO: get this from a pool? 
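ProcessSequenceGroupOutputs above delegates the stop decision to the group's IStoppingCriteria. For illustration only, a custom criterion can be as small as the sketch below: the hard length cap and the choice of the FinishAborted status are assumptions made for this example rather than behaviour defined by the PR, and the criterion hands back the sequence's current status when generation should continue, as the interface asks.

using LLama.Experimental.Abstractions;
using LLama.Experimental.Common;

// Illustrative only: abort a sequence once it reaches a fixed total length.
public sealed class MaxLengthStoppingCriteria : IStoppingCriteria
{
    private readonly int _maxLength;

    public MaxLengthStoppingCriteria(int maxLength)
    {
        _maxLength = maxLength;
    }

    public StoppingCriteriaOutput CheckStop(Sequence seq)
    {
        // Sequence.Length counts prompt plus generated tokens, the same length the scheduler uses.
        if (seq.Length >= _maxLength)
            return new StoppingCriteriaOutput(SequenceStatus.FinishAborted);

        // Not finished: report the unchanged current status.
        return new StoppingCriteriaOutput(seq.Status);
    }
}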
+ { + public int[] TokenIds { get; } + + public int[] Positions { get; } + + public int[] SeqIdCount { get; } + + public int[][] SeqIds { get; } + + public bool[] WithLogits { get; } + + public IntPtr[] SeqIdsPtrs { get; } + + /// + /// Construct from . + /// + /// + public LLamaCppRunnerInput(ModelRunnerInput input) + { + Debug.Assert(input.TokenIds.Length == input.Positions.Length); + Debug.Assert(input.TokenIds.Length == input.SeqIds.Length); + Debug.Assert(input.TokenIds.Length == input.WithLogits.Length); + TokenIds = input.TokenIds; + Positions = input.Positions; + + // TODO: Now we never put a token in multiple sequences, + // which may impact on the speed of the model in some cases. + // We should consider to support this in the future. + SeqIdCount = Enumerable.Repeat(1, TokenIds.Length).ToArray(); + SeqIds = new int[TokenIds.Length][]; + for(int i = 0; i < input.SeqIds.Length; i++) + { + SeqIds[i] = [input.SeqIds[i]]; + } + WithLogits = input.WithLogits; + SeqIdsPtrs = new IntPtr[SeqIds.Length]; + } + + /// + /// Convert to . + /// + /// [WARNING] You must hold the pin holder until the returned value will no longer be used. + /// + /// + /// + /// + internal LLamaNativeBatch ToLLamaNativeBatch(out GroupDisposable pinHolder) + { + pinHolder = new GroupDisposable(); + + unsafe + { + var batch = new LLamaNativeBatch + { + n_tokens = TokenIds.Length, + logits = (byte*)pinHolder.Add(WithLogits.AsMemory().Pin()).Pointer, + + n_seq_id = (int*)pinHolder.Add(SeqIdCount.AsMemory().Pin()).Pointer, + pos = (LLamaPos*)pinHolder.Add(Positions.AsMemory().Pin()).Pointer, + seq_id = (LLamaSeqId**)pinHolder.Add(SeqIdsPtrs.AsMemory().Pin()).Pointer, + + // embd is not currently supported, so this is always null! + embd = null, + + // Note that if embd is **not null** then this will be null! + tokens = (LLamaToken*)pinHolder.Add(TokenIds.AsMemory().Pin()).Pointer, + }; + + // Create pointers to each of the arrays in turns + for (var i = 0; i < SeqIdsPtrs.Length; i++) + SeqIdsPtrs[i] = (IntPtr)pinHolder.Add(SeqIds[i].AsMemory().Pin()).Pointer; + + return batch; + } + } + } +} diff --git a/LLama/Experimental/Runner/LLamaCpp/LogitsGenerator.cs b/LLama/Experimental/Runner/LLamaCpp/LogitsGenerator.cs new file mode 100644 index 000000000..2dad0be12 --- /dev/null +++ b/LLama/Experimental/Runner/LLamaCpp/LogitsGenerator.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Runner.LLamaCpp +{ + /// + /// Since the native API will return for logits, + /// we only get it when it's actually needed. + /// + public class LogitsGenerator + { + private int _pos; + + private LLamaContext _context; + + public LogitsGenerator(int pos, LLamaContext context) + { + _pos = pos; + _context = context; + } + + public Span GetLogits() + { + return _context.NativeHandle.GetLogitsIth(_pos); + } + } +} diff --git a/LLama/Experimental/Runner/LLamaCppRunner.cs b/LLama/Experimental/Runner/LLamaCppRunner.cs new file mode 100644 index 000000000..6cc5656da --- /dev/null +++ b/LLama/Experimental/Runner/LLamaCppRunner.cs @@ -0,0 +1,83 @@ +using LLama.Abstractions; +using LLama.Experimental.Abstractions; +using LLama.Experimental.Common; +using LLama.Experimental.Runner.LLamaCpp; +using LLama.Native; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace LLama.Experimental.Runner +{ + /// + /// Using llama.cpp backend to execute the model. 
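The pin holder returned by ToLLamaNativeBatch is the only thing keeping the managed arrays behind the native batch pinned, so it has to outlive the decode call. A sketch of the intended usage, assuming GroupDisposable is disposable (as its name suggests) and that the ModelRunnerInput comes from an already prepared batch; ToLLamaNativeBatch is internal, so this is library-side code and the helper name is illustrative.

using LLama;
using LLama.Experimental.Common;
using LLama.Experimental.Runner.LLamaCpp;

internal static class PinnedDecodeSketch
{
    internal static void Decode(LLamaContext context, ModelRunnerInput input)
    {
        var runnerInput = new LLamaCppRunnerInput(input);
        var batch = runnerInput.ToLLamaNativeBatch(out var pinHolder);

        // Dispose the pin holder only after llama_decode has finished with the batch.
        using (pinHolder)
        {
            context.Decode(batch);   // the Decode(LLamaNativeBatch) overload added later in this diff
        }
    }
}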
+ /// + public sealed class LLamaCppRunner: ModelRunnerBase, IModelRunner + { + public LLamaWeights ModelWeights { get; } + + public LLamaContext Context { get; } + + public LLamaCppRunner(LLamaWeights modelWeights, IContextParams contextParams) + { + ModelWeights = modelWeights; + Context = new LLamaContext(modelWeights, contextParams); + } + + /// + public SamplerOutput ExecuteModel(IEnumerable seqGroupMetadataList) + { + var modelInput = PrepareInputs(seqGroupMetadataList); + var samplingMetadata = PrepareSample(seqGroupMetadataList, modelInput.PromptLengths, modelInput.SubqueryLengths); + var llamaCppRunnerInput = new LLamaCppRunnerInput(modelInput); + var nativeBatch = llamaCppRunnerInput.ToLLamaNativeBatch(out var pinHolder); + + // TODO: is global lock still necessary? + + // Batched inference + Context.Decode(nativeBatch); + + // Get the logits + Dictionary seqIdToLogits = new(); + for(int i = 0; i < llamaCppRunnerInput.WithLogits.Length; i++) + { + if (llamaCppRunnerInput.WithLogits[i]) + { + for(int j = 0; j < llamaCppRunnerInput.SeqIds[i].Length; j++) + { + if (seqIdToLogits.ContainsKey(llamaCppRunnerInput.SeqIds[i][j])) + { + throw new Exception("Duplicate sequence id found when getting logits."); + } + else + { + seqIdToLogits.Add(llamaCppRunnerInput.SeqIds[i][j], new LogitsGenerator(i, Context)); + } + } + } + } + + // Sample the logits to get output tokens. + List outputs = new(); + foreach(var seqGroupMetadata in seqGroupMetadataList) + { + List sequenceOutputs = new(); + foreach(var seqId in seqGroupMetadata.SeqData.Keys) + { + var output = seqGroupMetadata.SamplingMethod.SampleSequence(seqIdToLogits[seqId].GetLogits(), seqId, samplingMetadata); + sequenceOutputs.Add(output); + } + outputs.Add(new SequenceGroupOutput(sequenceOutputs)); + } + + return new SamplerOutput(outputs); + } + + public void Dispose() + { + // It should dispose context but not model weight. + Context.Dispose(); + } + } +} diff --git a/LLama/Experimental/Runner/ModelRunnerBase.cs b/LLama/Experimental/Runner/ModelRunnerBase.cs new file mode 100644 index 000000000..5ccd865a9 --- /dev/null +++ b/LLama/Experimental/Runner/ModelRunnerBase.cs @@ -0,0 +1,138 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using LLama.Experimental.Abstractions; +using LLama.Experimental.Common; +using LLama.Extensions; + +namespace LLama.Experimental.Runner +{ + /// + /// A class that provides some commonly used method when running the model. + /// + /// Note that you could certainly not use this helper class and implement from scratch. + /// + public abstract class ModelRunnerBase + { + protected ModelRunnerInput PrepareInputs(IEnumerable seqGroupMetadataList) + { + Debug.Assert(seqGroupMetadataList.Count() > 0); + if (seqGroupMetadataList.First().IsPrompt) + { + return PreparePrompt(seqGroupMetadataList); + } + else + { + return PrepareDecode(seqGroupMetadataList); + } + } + + protected SamplingMetadata PrepareSample(IEnumerable seqGroupMetadataList, int[] promptLengths, int[] subqueryLengths) + { + // TODO: implement it. + return null; + } + + /// + /// Prepare input for sequences at prefill stage. 
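Because Dispose above frees only the context, loaded weights can be shared across runner instances. A rough wiring sketch, with a placeholder model path; ModelParams is used because it implements IContextParams as well as IModelParams.

using LLama;
using LLama.Common;
using LLama.Experimental.Runner;

var parameters = new ModelParams("path/to/model.gguf");   // placeholder path
using var weights = LLamaWeights.LoadFromFile(parameters);

using (var runner = new LLamaCppRunner(weights, parameters))
{
    // ... hand the runner to an engine and drive it with Step() ...
}
// The context owned by the runner is gone here, but `weights` can back another runner.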
+ /// + /// + /// + protected ModelRunnerInput PreparePrompt(IEnumerable seqGroupMetadataList) + { + Debug.Assert(seqGroupMetadataList.Count() > 0); + List inputTokenIds = new(); + List inputPositions = new(); + List sequenceIdMapping = new(); // sequennce id of corresponding tokens + List withLogits = new(); + + List promptLengths = new(); + List contextLengths = new(); + List subqueryLengths = new(); + + foreach(var seqGroupMetadata in seqGroupMetadataList) + { + Debug.Assert(seqGroupMetadata.IsPrompt); + var seqIds = seqGroupMetadata.SeqData.Keys.ToList(); + Debug.Assert(seqIds.Count == 1); + var seqId = seqIds[0]; + + var tokenChunkSize = seqGroupMetadata.TokenChunkSize; + var seqData = seqGroupMetadata.SeqData[seqId]; + var computedLength = seqData.NumComputedTokens; + // We should use `Length` here because in case of preemption it contains output tokens. + var prefillEnd = Math.Min(seqData.Length, computedLength + tokenChunkSize); + var prompTokenIds = seqData.TokenIds.Take(prefillEnd).Skip(computedLength); + var promptLength = prompTokenIds.Count(); + // Right now, the prefill_end is always same as the length of sequence. + // However, once chunked prefill is introduced, this assumption can be changed. + Debug.Assert(prefillEnd == seqData.Length); + promptLengths.Add(promptLength); + + // TODO: check the logic here, related with blocks? + + // actual prompt lens + contextLengths.Add(computedLength); + subqueryLengths.Add(promptLength - computedLength); + + inputTokenIds.AddRange(prompTokenIds); + // NOTE: Here we assume that the first token in the prompt is always the first token in the sequence. + inputPositions.AddRange(Enumerable.Range(computedLength, prefillEnd)); + + // TODO: deal with sliding window here? + sequenceIdMapping.AddRange(Enumerable.Repeat(seqId, promptLength)); + + withLogits.AddRange(Enumerable.Repeat(false, promptLength - 1)); + withLogits.Add(true); + } + + int maxSubqueryLength = subqueryLengths.Max(); + int maxPromptLength = promptLengths.Max(); + int numPromptTokens = inputTokenIds.Count; + Debug.Assert(maxSubqueryLength > 0); + + return new ModelRunnerInput(inputTokenIds.ToArray(), inputPositions.ToArray(), sequenceIdMapping.ToArray(), + withLogits.ToArray(), promptLengths.ToArray(), subqueryLengths.ToArray()); + } + + /// + /// Prepare input for sequences at decode stage. 
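To make the flattening above concrete, this is the shape PreparePrompt is expected to produce for a single fresh 4-token prompt belonging to sequence 7 (hand-worked from the code, with invented token values, not captured output). At the decode stage, handled next, the arrays instead hold one token per running sequence and the prompt/subquery lengths stay empty.

using LLama.Experimental.Common;

var expected = new ModelRunnerInput(
    new[] { 101, 102, 103, 104 },            // TokenIds: the prompt tokens, flattened
    new[] { 0, 1, 2, 3 },                    // Positions: starts at 0, nothing computed yet
    new[] { 7, 7, 7, 7 },                    // SeqIds: one entry per token
    new[] { false, false, false, true },     // WithLogits: only the last prompt token is sampled
    new[] { 4 },                             // PromptLengths
    new[] { 4 });                            // SubqueryLengths: prompt length minus computed (0)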
+ /// + /// + /// + protected ModelRunnerInput PrepareDecode(IEnumerable seqGroupMetadataList) + { + Debug.Assert(seqGroupMetadataList.Count() > 0); + List inputTokenIds = new(); + List inputPositions = new(); + List sequenceIdMapping = new(); // sequennce id of corresponding tokens + List withLogits = new(); + + foreach (var seqGroupMetadata in seqGroupMetadataList) + { + Debug.Assert(!seqGroupMetadata.IsPrompt); + Debug.Assert(seqGroupMetadata.TokenChunkSize == 1); + var seqIds = seqGroupMetadata.SeqData.Keys.ToList(); + + foreach(var seqId in seqIds) + { + var seqData = seqGroupMetadata.SeqData[seqId]; + var generationToken = seqData.LastTokenId; + inputTokenIds.Add(generationToken); + + var seqLength = seqData.Length; + var position = seqLength - 1; + inputPositions.Add(position); + + sequenceIdMapping.Add(seqId); + withLogits.Add(true); + } + } + + return new ModelRunnerInput(inputTokenIds.ToArray(), inputPositions.ToArray(), + sequenceIdMapping.ToArray(), withLogits.ToArray(), [], []); + } + } +} diff --git a/LLama/Experimental/Scheduler.cs b/LLama/Experimental/Scheduler.cs new file mode 100644 index 000000000..47180ea6c --- /dev/null +++ b/LLama/Experimental/Scheduler.cs @@ -0,0 +1,601 @@ +using LLama.Experimental.Abstractions; +using LLama.Experimental.Common; +using LLama.Experimental.Config; +using LLama.Experimental.Extensions; +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +namespace LLama.Experimental +{ + /// + /// The scheduler to schedule the requests for model inference. + /// + public sealed class Scheduler + // TODO: LORA + { + private ILogger? _logger; + + /// + /// Whether we schedule a prompt at previous step. + /// + private bool _prevIsPrompt; + + /// + /// Latency of the last prompt step + /// + private float _lastPromptLatency; + + /// + /// Time at previous scheduling step + /// + private DateTime _prevTime; + + /// + /// Scheduler configuration. + /// + public SchedulerConfig SchedulerConfig { get; set; } + + /// + /// KV cache configuration. + /// + public KvCacheConfig KvCacheConfig { get; set; } + + /// + /// The maximumum prompt length that can be used. + /// It's deduced from the scheduler configuration. + /// + public int MaxPromptLength { get; set; } + + /// + /// Sequence groups in the WAITING state. It contain new prefill or preempted requests. + /// + public LinkedList Waiting { get; set; } + + /// + /// Sequence groups in the RUNNING state. It contains the requests that is being decoded. + /// + public LinkedList Running { get; set; } + + /// + /// Sequence groups in the SWAPPED state. It contains decode requests that are swapped out. + /// + public LinkedList Swapped { get; set; } + + public KvCacheManager KvCacheManager { get; } + + + /// + /// Create a scheduler. Note that this is not a high-level API. If you are an user, please + /// read the documentation and ensure you know what it does. + /// + /// + /// + /// + public Scheduler(SchedulerConfig schedulerConfig, KvCacheConfig kvCacheConfig, ILogger? 
logger = null) + { + SchedulerConfig = schedulerConfig; + KvCacheConfig = kvCacheConfig; + + if (SchedulerConfig.EnableChunkedPrefill) + { + MaxPromptLength = SchedulerConfig.MaxSequenceLength; + } + else + { + MaxPromptLength = Math.Min(SchedulerConfig.MaxSequenceLength, SchedulerConfig.MaxNumBatchedTokens); + } + + Waiting = new LinkedList(); + Running = new LinkedList(); + Swapped = new LinkedList(); + + _logger = logger; + + // TODO: init with config + KvCacheManager = new(); + } + + /// + /// Add sequence groups to the waiting queue. + /// + /// + /// + public Scheduler AddSeqGroup(SequenceGroup seqGroup) + { + _logger?.LogDebug($"Added seq group {seqGroup.RequestId}"); + Waiting.AddBack(seqGroup); + return this; + } + + /// + /// Aborts a sequence group with the given IDs. + /// Check if the sequence group with the given ID + /// is present in any of the state queue. + ///If present, remove the sequence group from the state queue. + /// Also, if any of the sequences in the sequence group is not finished, + /// free the sequence with status `FINISHED_ABORTED`. + ///Otherwise, do nothing. + /// + /// + /// + public Scheduler AbortSeqGroup(IEnumerable requestIds) + { + var requestIdSet = new HashSet(requestIds.Distinct()); + + AbortInternal(Waiting, requestIdSet); + AbortInternal(Running, requestIdSet); + AbortInternal(Swapped, requestIdSet); + return this; + } + + /// + /// Whether all sequences has been finished at this moment. + /// + /// + public bool HasUnfinishedSeqs() + { + return Waiting.Count != 0 || Running.Count != 0 || Swapped.Count != 0; + } + + /// + /// Get the number of unfinished sequence groups. + /// + /// + public int GetNumUnfinishedSeqGroups() + { + return Waiting.Count + Running.Count + Swapped.Count; + } + + /// + /// Free the sequence resource that managed by the scheduler. + /// It's actually an empty method now and may be implemented in the future if needed. + /// + /// + public void FreeSeq(Sequence seq) + { + // TODO: implement it if needed. + } + + /// + /// Schedule sequence groups. + /// This function call changes the internal states of the scheduler, + /// such as this.Running, this.Wwapped, and this.Waiting. + /// + /// + public (List, SchedulerOutputs) Schedule() + { + var schedulerOutputs = ScheduleInternal(); + var now = DateTime.Now; + + // Create input data structures. + List seqGroupMetadataList = new(); + int i = 0; + foreach(var scheduledSeqGroup in schedulerOutputs.ScheduledSeqGroups) + { + var seqGroup = scheduledSeqGroup.SeqGroup; + var tokenChunkSize = scheduledSeqGroup.TokenChunkSize; + seqGroup.MaybeSetFirstScheduledTime(now); + + Dictionary seqData = new(); + + foreach(var seq in seqGroup.GetSeqsWithStatus(SequenceStatus.Running)) + { + var seqId = seq.Id; + seqData[seqId] = seq.Data; + } + + // It assumes the scheduled_seq_groups is ordered by prefill < decoding. + bool isPrompt = i < schedulerOutputs.NumPrefillGroups; + var seqGroupMetadata = new SequenceGroupMetadata( + seqGroup.RequestId, + isPrompt, + seqData, + seqGroup.SamplingMethod, + seqGroup.StoppingCriteria, + tokenChunkSize + ); + seqGroupMetadataList.Add(seqGroupMetadata); + + i++; + } + + return (seqGroupMetadataList, schedulerOutputs); + } + + /// + /// Free finished sequence groups. + /// + public void FreeFinishedSeqGroups() + { + Running = new LinkedList(Running.Where(x => !x.IsFinished)); + } + + /// + /// Schedule queued requests. 
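Before the scheduling internals below, a quick note on the cancellation path shown above: AbortSeqGroup takes request ids (the stringified counter values assigned when the requests were added), and it works regardless of which queue currently holds the groups. A minimal sketch, with placeholder ids and an illustrative helper name:

using LLama.Experimental;

internal static class AbortSketch
{
    internal static void CancelRequests(Scheduler scheduler)
    {
        // Unfinished sequences in the aborted groups end up in the FinishAborted state.
        scheduler.AbortSeqGroup(new[] { "12", "13" });
    }
}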
+ /// + /// + /// + private SchedulerOutputs ScheduleInternal() + { + if (SchedulerConfig.EnableChunkedPrefill) + { + // TODO: allow chunked prefill. + throw new NotImplementedException(); + } + else + { + return ScheduleDefault(); + } + } + + /// + /// Schedule queued requests. + /// + /// The current policy is designed to opimimize the throughput. First, + /// it batches as many prefill requests as possible.And it schedules + /// decodes.If there's a pressure on GPU memory, decode requests can + /// be swapped or preempted. + /// + /// + private SchedulerOutputs ScheduleDefault() + { + // Include running requests to the budget. + var budget = new SchedulingBudget(SchedulerConfig.MaxNumBatchedTokens, SchedulerConfig.MaxNumSequences); + // Make sure we include num running seqs before scheduling prefill, + // so that we don't schedule beyond max_num_seqs for prefill. + foreach(var seqGroup in Running) + { + budget.AddNumSeqs(seqGroup.RequestId, seqGroup.GetMaxNumRunningSeqs()); + } + + var remainingWaiting = Waiting; + var prefills = SchedulerPrefillOutputs.CreateEmpty(); + var remainingRunning = Running; + var runningScheduled = SchedulerRunningOutputs.CreateEmpty(); + var remainingSwapped = Swapped; + var swappedIn = SchedulerSwappedInOutputs.CreateEmpty(); + + if(Swapped.Count == 0) + { + prefills = SchedulePrefills(Waiting, budget, false); + remainingWaiting = prefills.RemainingWaitingQueue; + } + + var policy = PolicyFactory.DefaultPolicy; + // Don't schedule decodes if prefills are scheduled. + // NOTE: If `SchedulePrefills` doesn't enable chunking, this.Running + // only contains decode requests, not chunked prefills. + + if(prefills.SeqGroups.Count == 0) + { + runningScheduled = ScheduleRunning(Running, budget, policy, false); + remainingRunning = runningScheduled.RemainingRunningQueue; + + // If any sequence group is preempted, do not swap in any sequence group. + // Because it means there's no slot for new running requests. + if(runningScheduled.PreemptedSeqGroups.Count + runningScheduled.SwappedOutSeqGroups.Count == 0) + { + // TODO: implement the swapping. + } + } + + Debug.Assert(budget.NumBatchedTokens <= SchedulerConfig.MaxNumBatchedTokens); + Debug.Assert(budget.NumCurrentSeqs <= SchedulerConfig.MaxNumSequences); + + // Update waiting requests. + Waiting = remainingWaiting; + Waiting.ExtendFront(runningScheduled.PreemptedSeqGroups); + // Update new running requests. + Running = remainingRunning; + Running.ExtendFront(prefills.SeqGroups.Select(x => x.SeqGroup)); + Running.ExtendBack(runningScheduled.DecodeSeqGroups.Select(x => x.SeqGroup)); + Running.ExtendBack(swappedIn.DecodeSeqGroups.Select(x => x.SeqGroup)); + // Update swapped requests. + Swapped = remainingSwapped; + Swapped.ExtendBack(runningScheduled.SwappedOutSeqGroups); + + // There should be no prefill from running queue because this policy + // doesn't allow chunked prefills. 
+ Debug.Assert(runningScheduled.PrefillSeqGroups.Count == 0); + Debug.Assert(swappedIn.PrefillSeqGroups.Count == 0); + return new SchedulerOutputs( + ScheduledSeqGroups: prefills.SeqGroups.Concat(runningScheduled.DecodeSeqGroups).Concat(swappedIn.DecodeSeqGroups), + NumPrefillGroups: prefills.SeqGroups.Count, + NumBatchedTokens: budget.NumBatchedTokens, + IgnoredSeqGroups: prefills.IgnoredSeqGroups + ); + } + + private SchedulerSwappedInOutputs ScheduleSwapped(LinkedList swappedQueue, SchedulingBudget budget, ISchedulingPolicy policy, bool enableChunking) + { + throw new NotImplementedException(); + } + + /// + /// Schedule sequence groups that are in prefill stage. + /// + /// Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE + /// as a new prefill(that starts from beginning -> most recently generated + /// tokens). + /// + /// It schedules waiting requests as long as it fits `budget` and + /// curr_loras smaller than or equal with max_lora from the scheduling config.The input arguments + /// `budget` and `curr_loras` are updated based on scheduled seq_groups. + /// + /// The queue that contains prefill requests. The given arguments are NOT in-place modified. + /// The scheduling budget. The argument is in-place updated when any requests are scheduled. + /// + /// If True, seq group can be chunked and only a chunked number of tokens are scheduled if + /// has not enough capacity to schedule all tokens. + /// + /// + private SchedulerPrefillOutputs SchedulePrefills(LinkedList waiting, SchedulingBudget budget, bool enableChunking = false) + { + List ignoredSeqGroups = new(); + List seqGroups = new(); + // We don't sort waiting queue because we assume it is sorted. + // Copy the queue so that the input queue is not modified. + var waitingQueue = new LinkedList(waiting); + + LinkedList leftoverWaitingSequences = new(); + while(PassedDelay(DateTime.Now) && waitingQueue.Count > 0) + { + var seqGroup = waitingQueue.PeekFront(); + + var waitingSeqs = seqGroup.GetSeqsWithStatus(SequenceStatus.Waiting); + Debug.Assert(waitingSeqs.Count() == 1, "Waiting sequence group should have only one prompt sequence."); + var numNewTokens = GetNumNewTokens(seqGroup, SequenceStatus.Waiting, enableChunking, budget); + if (!enableChunking) + { + var numPromptTokens = waitingSeqs.First().Length; + Debug.Assert(numNewTokens == numPromptTokens); + } + + if (numNewTokens > MaxPromptLength) + { + _logger?.LogWarning($"Input prompt ({numNewTokens} tokens) is too long " + + $"and exceeds limit of {MaxPromptLength}."); + foreach(var seq in waitingSeqs) + { + seq.Status = SequenceStatus.FinishIgnored; + } + ignoredSeqGroups.Add(seqGroup); + waitingQueue.RemoveFront(); + continue; + } + + // If the sequence group cannot be allocated, stop. + var canAlloc = KvCacheManager.CanAllocate(seqGroup); + if(canAlloc == AllocStatus.Later) + { + break; + } + else if(canAlloc == AllocStatus.Never) + { + _logger?.LogWarning($"Input prompt ({numNewTokens} tokens) is too long" + + " and exceeds the capacity of block_manager"); + foreach(var seq in waitingSeqs) + { + seq.Status = SequenceStatus.FinishIgnored; + } + ignoredSeqGroups.Add(seqGroup); + waitingQueue.RemoveFront(); + continue; + } + + var numNewSeqs = seqGroup.GetMaxNumRunningSeqs(); + if(numNewTokens == 0 || !budget.CanSchedule(numNewTokens, numNewSeqs)) + { + break; + } + + // Can schedule this request. 
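The admission check that just passed is plain bookkeeping on SchedulingBudget. Roughly, it behaves like the sketch below; the constructor and method shapes are inferred from their call sites in this file, so treat the exact signatures as assumptions. Once the check passes, SchedulePrefills dequeues the group and marks it running, which is what the next lines do.

// Budget created once per scheduling pass from the scheduler config limits.
var budget = new SchedulingBudget(512 /* max batched tokens */, 8 /* max sequences */);

if (budget.CanSchedule(128, 1))            // 128 new tokens, 1 new sequence
{
    budget.AddNumBatchedTokens("req-0", 128);
    budget.AddNumSeqs("req-0", 1);
}

// 512 - 128 = 384 tokens left for the remaining groups in this pass.
var remaining = budget.RemainingTokenBudget;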
+ waitingQueue.RemoveFront(); + AllocateAndSetRunning(seqGroup, numNewTokens); + seqGroups.Add(new ScheduledSequenceGroup(seqGroup, numNewTokens)); + budget.AddNumBatchedTokens(seqGroup.RequestId, numNewTokens); + budget.AddNumSeqs(seqGroup.RequestId, numNewSeqs); + } + + waitingQueue.ExtendFront(leftoverWaitingSequences); + if(seqGroups.Count > 0) + { + _prevIsPrompt = true; + } + + return new SchedulerPrefillOutputs(waitingQueue, seqGroups, ignoredSeqGroups); + } + + /// + /// Schedule sequence groups that are running. + /// + /// Running queue should include decode and chunked prefill requests. + /// + /// + /// The queue that contains running requests (i.e., decodes). + /// The given arguments are NOT in-place modified. + /// + /// + /// The scheduling budget. The argument is in-place updated + /// when any decodes are preempted. + /// + /// The sorting policy to sort running_queue. + /// + /// If True, seq group can be chunked and only a chunked number of tokens are scheduled if + /// `budget.num_batched_tokens` has not enough capacity to schedule all tokens. + /// + /// + private SchedulerRunningOutputs ScheduleRunning(LinkedList runningQueue, SchedulingBudget budget, + ISchedulingPolicy policy, bool enableChunking) + { + List decodeSeqGroups = new(); + List prefillSeqGroups = new(); + List preempted = new(); + List swappedOut = new(); + + //NOTE: Preemption happens only when there is no available slot + //to keep all the sequence groups in the RUNNING state. + //In this case, the policy is responsible for deciding which sequence + //groups to preempt. + var now = DateTime.Now; + runningQueue = policy.SortByPriority(now, runningQueue); + + while(runningQueue.Count > 0) + { + var seqGroup = runningQueue.PeekFront(); + var numRunningTokens = GetNumNewTokens(seqGroup, SequenceStatus.Running, enableChunking, budget); + + // We can have up to 1 running prefill at any given time in running + // queue, which means we can guarantee chunk size is at least 1. 
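ScheduleRunning sorts the running queue through the ISchedulingPolicy abstraction (the FCFS implementation appears near the end of this diff). A custom policy only has to rank groups; an illustrative alternative that ranks by total waiting time is sketched below, using TotalMilliseconds deliberately because TimeSpan.Milliseconds exposes only the 0-999 component. The assertion that follows resumes the original ScheduleRunning logic.

using System;
using LLama.Experimental.Abstractions;
using LLama.Experimental.Common;

// Illustrative alternative policy: the longest-waiting group gets the highest priority.
public sealed class LongestWaitFirstPolicy : ISchedulingPolicy
{
    public int GetPriority(DateTime now, SequenceGroup seqGroup)
    {
        return (int)(now - seqGroup.Metrics.ArrivalTime).TotalMilliseconds;
    }
}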
+ Debug.Assert(numRunningTokens != 0); + var numRunningSeqs = seqGroup.GetMaxNumRunningSeqs(); + + runningQueue.RemoveFront(); + bool appended = true; + while (!CanAppendSlots(seqGroup)) + { + // TODO: implement the preemption logic + Debug.Assert(false); + } + + if (appended) + { + _logger?.LogDebug($"append slot for {seqGroup}"); + AppendSlots(seqGroup); + if (seqGroup.IsPrefill) + { + prefillSeqGroups.Add(new ScheduledSequenceGroup(seqGroup, numRunningTokens)); + } + else + { + decodeSeqGroups.Add(new ScheduledSequenceGroup(seqGroup, 1)); + } + budget.AddNumBatchedTokens(seqGroup.RequestId, numRunningTokens); + budget.AddNumSeqs(seqGroup.RequestId, numRunningSeqs); + } + } + + Debug.Assert(runningQueue.Count == 0); + return new SchedulerRunningOutputs(runningQueue, decodeSeqGroups, prefillSeqGroups, preempted, swappedOut); + } + + private void AllocateAndSetRunning(SequenceGroup seqGroup, int numNewTokens) + { + KvCacheManager.Allocate(seqGroup); + foreach (var seq in seqGroup.GetSeqsWithStatus(SequenceStatus.Waiting)) + { + seq.Status = SequenceStatus.Running; + } + } + + private bool PassedDelay(DateTime now) + { + if (_prevIsPrompt) + { + _lastPromptLatency = (now - _prevTime).Milliseconds; + } + _prevTime = now; + _prevIsPrompt = false; + // Delay scheduling prompts to let waiting queue fill up + if (SchedulerConfig.DelayFactor > 0 && Waiting.Count > 0) + { + var earliestArrivalTime = Waiting.Select(x => x.Metrics.ArrivalTime).Min(); + return (now - earliestArrivalTime).Milliseconds > (SchedulerConfig.DelayFactor * _lastPromptLatency) || Running.Count == 0; + } + return true; + } + + private bool CanAppendSlots(SequenceGroup seqGroup) + { + return KvCacheManager.CanAppendSlots(seqGroup); + } + + private void AppendSlots(SequenceGroup seqGroup) + { + // TODO: Implement this method + } + + /// + /// Get the next new tokens to compute for a given sequence group that's in a given `status`. + /// + /// The API could chunk the number of tokens to compute based on `budget` + /// if `enable_chunking` is True.If a sequence group has multiple + /// sequences(e.g., running beam search), it means it is in decoding + /// phase, so chunking doesn't happen. + /// + /// + /// + /// + /// + /// + /// + private int GetNumNewTokens(SequenceGroup seqGroup, SequenceStatus status, bool enableChunking, SchedulingBudget budget) + { + int numNewTokens = 0; + var seqs = seqGroup.GetSeqsWithStatus(status); + foreach(var seq in seqs) + { + numNewTokens += seq.NumNewTokens; + } + // Chunk if a running request cannot fit in. + // If number of seq > 1, it means it is doing beam search in a + // decode phase. Do not chunk in that case. + if(enableChunking && seqs.Count() == 1) + { + numNewTokens = Math.Min(numNewTokens, budget.RemainingTokenBudget); + } + return numNewTokens; + } + + private void AbortInternal(LinkedList queue, HashSet requestIds) + { + Queue abortedGroups = new(); + foreach (var seqGroup in queue) + { + if (requestIds.Count == 0) + { + break; + } + if (requestIds.Contains(seqGroup.RequestId)) + { + _logger?.LogDebug($"Aborted seq group {seqGroup.RequestId}"); + abortedGroups.Enqueue(seqGroup); + requestIds.Remove(seqGroup.RequestId); + } + } + foreach(var abortGroup in abortedGroups) + { + queue.Remove(abortGroup); + foreach(var seq in abortGroup.GetAllSeqs()) + { + if (seq.IsFinished) + { + continue; + } + seq.Status = SequenceStatus.FinishAborted; + } + } + } + } + + /// + /// The mode of preemption. 
+ /// + public enum PreemptionMode + { + /// + /// Swap out the blocks of the preempted sequences to CPU memory + /// and swap them back in when the sequences are resumed. + /// + Swap, + + /// + /// Discard the blocks of the preempted sequences and recompute them + /// when the sequences are resumed, treating the sequences as new prompts. + /// + Recompute + } +} diff --git a/LLama/Experimental/SchedulingPolicies.cs b/LLama/Experimental/SchedulingPolicies.cs new file mode 100644 index 000000000..c22ff3b69 --- /dev/null +++ b/LLama/Experimental/SchedulingPolicies.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Text; +using LLama.Experimental.Abstractions; +using LLama.Experimental.Common; + +namespace LLama.Experimental +{ + /// + /// First in first out policy. + /// + public class FCFS: ISchedulingPolicy + { + /// + public int GetPriority(DateTime now, SequenceGroup seqGroup) + { + return (now - seqGroup.Metrics.ArrivalTime).Milliseconds; + } + } + + public class PolicyFactory + { + public static ISchedulingPolicy DefaultPolicy { get; set; } = new FCFS(); + } +} diff --git a/LLama/Experimental/Utils/ClassStringFormatter.cs b/LLama/Experimental/Utils/ClassStringFormatter.cs new file mode 100644 index 000000000..6ae2cc117 --- /dev/null +++ b/LLama/Experimental/Utils/ClassStringFormatter.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Reflection; +using System.Text; + +namespace LLama.Experimental.Utils +{ + internal static class ClassStringFormatter + { + public static string Format(T obj) + { + Type type = obj.GetType(); + PropertyInfo[] properties = type.GetProperties(); + + string res = $"{type.Name}("; + foreach (var property in properties) + { + object? value = property.GetValue(obj); + res += $"\n {property.Name} = {value},"; + } + if(properties.Length == 0) + { + res += ")"; + } + else + { + res += "\n)"; + } + return res; + } + } +} diff --git a/LLama/Experimental/Utils/IdCounter.cs b/LLama/Experimental/Utils/IdCounter.cs new file mode 100644 index 000000000..6ff4fd558 --- /dev/null +++ b/LLama/Experimental/Utils/IdCounter.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Experimental.Utils +{ + internal class IdCounter + { + private int _number; + + public IdCounter(int start = 0) + { + _number = start; + } + + public int Next() + { + return _number++; + } + + public void Reset() + { + _number = 0; + } + } +} diff --git a/LLama/Extensions/IEnumerableExtensions.cs b/LLama/Extensions/IEnumerableExtensions.cs index 3d1e2e814..c94335e1b 100644 --- a/LLama/Extensions/IEnumerableExtensions.cs +++ b/LLama/Extensions/IEnumerableExtensions.cs @@ -10,8 +10,13 @@ public static IEnumerable TakeLast(this IEnumerable source, int count) { return TakeLastImpl(source, count); } + + public static IEnumerable SkipLast(this IEnumerable source, int count) + { + return SkipLastImpl(source, count); + } #elif !NET6_0_OR_GREATER && !NETSTANDARD2_1_OR_GREATER - #error Target framework not supported! +#error Target framework not supported! 
#endif internal static IEnumerable TakeLastImpl(IEnumerable source, int count) @@ -24,5 +29,10 @@ internal static IEnumerable TakeLastImpl(IEnumerable source, int count) list.RemoveRange(0, list.Count - count); return list; } + + internal static IEnumerable SkipLastImpl(IEnumerable source, int count) + { + return source.Take(source.Count() - count); + } } } diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 9517965e6..12ea24740 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -400,6 +400,11 @@ public DecodeResult Decode(LLamaBatch batch) return (DecodeResult)NativeHandle.Decode(batch); } + public DecodeResult Decode(LLamaNativeBatch batch) + { + return (DecodeResult)NativeHandle.Decode(batch); + } + /// /// /// diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index 3947b7c31..3b4493ec9 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -3,7 +3,7 @@ netstandard2.0;net6.0;net7.0;net8.0 LLama enable - 10 + 12 AnyCPU;x64;Arm64 True diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 2f881fa5d..630c7bba2 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -219,6 +219,18 @@ public DecodeResult Decode(LLamaBatch batch) return (DecodeResult)NativeApi.llama_decode(this, nb); } + /// + /// + /// + /// + /// + public DecodeResult Decode(LLamaNativeBatch batch) + { + // TODO: is global lock still necessary? + lock (GlobalInferenceLock) + return (DecodeResult)NativeApi.llama_decode(this, batch); + } + /// /// Decode a set of tokens in batch-size chunks. ///
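A side note on the csproj change above: bumping LangVersion from 10 to 12 is what allows the collection expressions used throughout the new files (return [];, [seq], new SamplerOutput([])). For example:

using System.Collections.Generic;

int[] empty = [];                             // C# 12 collection expression
IList<IList<int>> prompts = [[1, 2, 3]];      // nested form, as in Generate([promptTokenIds], ...)
List<int> copy = [.. prompts[0]];             // spread element, also new in C# 12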