diff --git a/LLama/Experimental/Abstractions/IModelRunner.cs b/LLama/Experimental/Abstractions/IModelRunner.cs
new file mode 100644
index 000000000..d89f012f0
--- /dev/null
+++ b/LLama/Experimental/Abstractions/IModelRunner.cs
@@ -0,0 +1,20 @@
+using LLama.Experimental.Common;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Abstractions
+{
+ /// <summary>
+ /// It defines how to execute the model.
+ /// </summary>
+ public interface IModelRunner: IDisposable
+ {
+ /// <summary>
+ /// Deal with the scheduled sequences to get the output.
+ /// </summary>
+ /// <param name="seqGroupMetadataList"></param>
+ /// <returns></returns>
+ SamplerOutput ExecuteModel(IEnumerable<SequenceGroupMetadata> seqGroupMetadataList);
+ }
+}
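
Since `IModelRunner` is the seam between the scheduling loop and the model backend, a stub implementation is enough to exercise the scheduler and engine without a real model. A minimal sketch (not part of this PR; the type and namespace names are illustrative) that emits a fixed token for every scheduled sequence:

```csharp
using System.Collections.Generic;
using LLama.Experimental.Abstractions;
using LLama.Experimental.Common;

namespace Examples
{
    /// <summary>
    /// A stub runner useful for testing the scheduler/engine plumbing:
    /// it ignores the model entirely and "generates" a fixed token for
    /// every scheduled sequence.
    /// </summary>
    public sealed class FixedTokenModelRunner : IModelRunner
    {
        private readonly int _tokenId;

        public FixedTokenModelRunner(int tokenId) => _tokenId = tokenId;

        public SamplerOutput ExecuteModel(IEnumerable<SequenceGroupMetadata> seqGroupMetadataList)
        {
            var outputs = new List<SequenceGroupOutput>();
            foreach (var metadata in seqGroupMetadataList)
            {
                // One SequenceOutput per scheduled sequence in the group.
                var samples = new List<SequenceOutput>();
                foreach (var seqId in metadata.SeqData.Keys)
                    samples.Add(new SequenceOutput { OutputTokenId = _tokenId, ParentSeqId = seqId });
                outputs.Add(new SequenceGroupOutput(samples));
            }
            return new SamplerOutput(outputs);
        }

        public void Dispose() { }
    }
}
```
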
diff --git a/LLama/Experimental/Abstractions/ISamplingMethod.cs b/LLama/Experimental/Abstractions/ISamplingMethod.cs
new file mode 100644
index 000000000..694468a5c
--- /dev/null
+++ b/LLama/Experimental/Abstractions/ISamplingMethod.cs
@@ -0,0 +1,44 @@
+using LLama.Experimental.Common;
+using LLama.Experimental.Runner.LLamaCpp;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Abstractions
+{
+ /// <summary>
+ /// Method to sample the model output.
+ /// </summary>
+ public interface ISamplingMethod
+ // TODO: We should reconsider this design. Maybe it's better to use `SamplingParams` to let the user set parameters,
+ // and choose the actual sampler internally according to the params.
+ {
+ /// <summary>
+ /// The maximum number of sequences running in parallel.
+ ///
+ /// If you don't know what to return, you can return the default value.
+ ///
+ /// Generally, if you want to select several results from n candidates, you need
+ /// to set the maximum number of sequences to run.
+ /// </summary>
+ /// <param name="defaultValue"></param>
+ /// <param name="currentNumSeqs"></param>
+ /// <returns></returns>
+ int GetMaxNumRunningSeqs(int defaultValue, int currentNumSeqs);
+
+ /// <summary>
+ /// Whether to skip special tokens.
+ /// </summary>
+ bool SkipSpecialTokens { get; }
+
+ /// <summary>
+ /// Sample the sequence logits to get the next token.
+ /// </summary>
+ /// <param name="logits"></param>
+ /// <param name="seqId"></param>
+ /// <param name="samplingMetadata"></param>
+ /// <returns></returns>
+ SequenceOutput SampleSequence(Span<float> logits, int seqId, SamplingMetadata samplingMetadata);
+ // TODO: maybe we shouldn't expose all the samplingMetadata to users here.
+ }
+}
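
For reference, a backend-agnostic implementation of `ISamplingMethod` only has to pick a token id from the logits span. A minimal greedy (argmax) sketch, purely illustrative and independent of llama.cpp (the PR's own `LLamaGreedySamplingMethod` delegates to the native sampler instead):

```csharp
using System;
using LLama.Experimental.Abstractions;
using LLama.Experimental.Common;

namespace Examples
{
    /// <summary>
    /// A backend-agnostic greedy sampler: picks the token with the highest logit.
    /// </summary>
    public sealed class ArgMaxSamplingMethod : ISamplingMethod
    {
        // Greedy decoding produces a single deterministic sequence,
        // so the default parallelism is fine.
        public int GetMaxNumRunningSeqs(int defaultValue, int currentNumSeqs) => defaultValue;

        public bool SkipSpecialTokens => false;

        public SequenceOutput SampleSequence(Span<float> logits, int seqId, SamplingMetadata samplingMetadata)
        {
            int best = 0;
            for (int i = 1; i < logits.Length; i++)
            {
                if (logits[i] > logits[best])
                    best = i;
            }

            return new SequenceOutput { OutputTokenId = best, ParentSeqId = seqId };
        }
    }
}
```
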
diff --git a/LLama/Experimental/Abstractions/ISchedulingPolicy.cs b/LLama/Experimental/Abstractions/ISchedulingPolicy.cs
new file mode 100644
index 000000000..e4d4bbb5e
--- /dev/null
+++ b/LLama/Experimental/Abstractions/ISchedulingPolicy.cs
@@ -0,0 +1,21 @@
+using LLama.Experimental.Common;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Abstractions
+{
+ /// <summary>
+ /// Defines the scheduling policy, which decides the priority order of sequence groups.
+ /// </summary>
+ public interface ISchedulingPolicy
+ {
+ /// <summary>
+ /// Get the priority of a sequence group.
+ /// </summary>
+ /// <param name="now"></param>
+ /// <param name="seqGroup"></param>
+ /// <returns></returns>
+ int GetPriority(DateTime now, SequenceGroup seqGroup);
+ }
+}
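
As a concrete example of how a policy plugs in: `SchedulingPolicyExtensions.SortByPriority` (later in this diff) orders groups by descending priority, so a first-come-first-served policy can simply return how long the group has been waiting. A minimal sketch, assuming one-second resolution is sufficient:

```csharp
using System;
using LLama.Experimental.Abstractions;
using LLama.Experimental.Common;

namespace Examples
{
    /// <summary>
    /// First-come-first-served: the longer a request has waited, the higher
    /// its priority. The scheduler sorts by descending priority, so earlier
    /// arrivals are scheduled first.
    /// </summary>
    public sealed class FcfsSchedulingPolicy : ISchedulingPolicy
    {
        public int GetPriority(DateTime now, SequenceGroup seqGroup)
        {
            // Requests that arrived earlier have waited longer and therefore win.
            return (int)(now - seqGroup.Metrics.ArrivalTime).TotalSeconds;
        }
    }
}
```
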
diff --git a/LLama/Experimental/Abstractions/IStoppingCriteria.cs b/LLama/Experimental/Abstractions/IStoppingCriteria.cs
new file mode 100644
index 000000000..4a86f94d6
--- /dev/null
+++ b/LLama/Experimental/Abstractions/IStoppingCriteria.cs
@@ -0,0 +1,30 @@
+using LLama.Experimental.Common;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Abstractions
+{
+ /// <summary>
+ /// Stopping criteria that can be applied during generation.
+ /// </summary>
+ public interface IStoppingCriteria
+ {
+ /// <summary>
+ /// Check if the sequence should be stopped and return the status.
+ ///
+ /// If it's not supposed to be stopped, be sure to return its current status.
+ /// </summary>
+ /// <param name="seq"></param>
+ /// <returns></returns>
+ StoppingCriteriaOutput CheckStop(Sequence seq); // TODO: include other params?
+ }
+
+ /// <summary>
+ /// The output of <see cref="IStoppingCriteria.CheckStop(Sequence)"/>.
+ /// </summary>
+ /// <param name="Status">The sequence status.</param>
+ /// <param name="StoppingString">If the sequence stops because of the appearance of a string, please set it here.</param>
+ /// <param name="StoppingTokenId">If the sequence stops because of the appearance of a token, please set it here.</param>
+ public record class StoppingCriteriaOutput(SequenceStatus Status, string? StoppingString = null, int? StoppingTokenId = null);
+}
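
A minimal custom criteria might stop a sequence once it emits a given end-of-sequence token and report that token through `StoppingTokenId`. A sketch (illustrative only; the PR itself ships antiprompt- and length-based criteria under `LLama/Experimental/Core`):

```csharp
using LLama.Experimental.Abstractions;
using LLama.Experimental.Common;

namespace Examples
{
    /// <summary>
    /// Stops a sequence as soon as it emits a given end-of-sequence token id.
    /// </summary>
    public sealed class EosTokenStoppingCriteria : IStoppingCriteria
    {
        private readonly int _eosTokenId;

        public EosTokenStoppingCriteria(int eosTokenId)
        {
            _eosTokenId = eosTokenId;
        }

        public StoppingCriteriaOutput CheckStop(Sequence seq)
        {
            var outputs = seq.OutputTokens;
            if (outputs.Count > 0 && outputs[outputs.Count - 1] == _eosTokenId)
            {
                // Report the token that caused the stop via StoppingTokenId.
                return new StoppingCriteriaOutput(SequenceStatus.FinishStopped, null, _eosTokenId);
            }

            // Not stopped: return the sequence's current status unchanged.
            return new StoppingCriteriaOutput(seq.Status);
        }
    }
}
```
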
diff --git a/LLama/Experimental/Abstractions/ITokenizer.cs b/LLama/Experimental/Abstractions/ITokenizer.cs
new file mode 100644
index 000000000..4e74973da
--- /dev/null
+++ b/LLama/Experimental/Abstractions/ITokenizer.cs
@@ -0,0 +1,34 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Abstractions
+{
+ /// <summary>
+ /// The interface for a tokenizer in LLamaSharp. It's responsible for converting text to token ids, and vice versa.
+ /// </summary>
+ public interface ITokenizer
+ {
+ // TODO: `ApplyChatTemplate` API
+
+ // TODO: Batched Encode?
+
+ /// <summary>
+ /// Get the token ids from the text.
+ /// </summary>
+ /// <param name="input"></param>
+ /// <returns></returns>
+ IList<int> Tokenize(string input);
+
+ /// <summary>
+ /// Convert the token ids to text.
+ /// </summary>
+ /// <param name="tokenIds"></param>
+ /// <param name="result"></param>
+ /// <param name="skipSpecialTokens"></param>
+ /// <returns>The number of tokens consumed for decoding.</returns>
+ int ConvertIdsToText(IEnumerable<int> tokenIds, out string result, bool skipSpecialTokens = false);
+
+ // TODO: decode from Logprobs
+ }
+}
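
The return value of `ConvertIdsToText` is what makes incremental detokenization possible: an implementation may consume fewer ids than it was given when the tail does not yet form a complete character. A sketch of a caller honouring that contract (the helper name is illustrative):

```csharp
using System.Collections.Generic;
using System.Linq;
using System.Text;
using LLama.Experimental.Abstractions;

static class TokenizerUsageExample
{
    // Decode a whole sequence by repeatedly honouring the "consumed tokens"
    // contract of ConvertIdsToText. Ids that only form a partial character
    // (e.g. half of a multi-byte UTF-8 sequence) may be left for a later call.
    public static string DecodeAll(ITokenizer tokenizer, IReadOnlyList<int> tokenIds)
    {
        var text = new StringBuilder();
        int offset = 0;
        while (offset < tokenIds.Count)
        {
            int consumed = tokenizer.ConvertIdsToText(tokenIds.Skip(offset), out var piece);
            if (consumed == 0)
                break; // nothing more can be decoded yet
            text.Append(piece);
            offset += consumed;
        }
        return text.ToString();
    }
}
```
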
diff --git a/LLama/Experimental/Common/ModelRunnerInput.cs b/LLama/Experimental/Common/ModelRunnerInput.cs
new file mode 100644
index 000000000..ad02145b9
--- /dev/null
+++ b/LLama/Experimental/Common/ModelRunnerInput.cs
@@ -0,0 +1,24 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ /// <summary>
+ /// The input prepared for the model runner.
+ /// </summary>
+ /// <param name="TokenIds">The tokens to feed to the model.</param>
+ /// <param name="Positions">The positions of these tokens.</param>
+ /// <param name="SeqIds">The sequence ids of these tokens.</param>
+ /// <param name="WithLogits">Whether the logits need to be computed for each token.</param>
+ /// <param name="PromptLengths">The lengths of the prompts if the input is at prefill stage, otherwise empty.</param>
+ /// <param name="SubqueryLengths">The lengths of the subqueries if the input is at prefill stage, otherwise empty.</param>
+ public record class ModelRunnerInput(
+ int[] TokenIds,
+ int[] Positions,
+ int[] SeqIds,
+ bool[] WithLogits,
+ int[] PromptLengths,
+ int[] SubqueryLengths
+ );
+}
diff --git a/LLama/Experimental/Common/RequestMetrics.cs b/LLama/Experimental/Common/RequestMetrics.cs
new file mode 100644
index 000000000..405b201d3
--- /dev/null
+++ b/LLama/Experimental/Common/RequestMetrics.cs
@@ -0,0 +1,42 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ ///
+ /// Metrics associated with a request.
+ ///
+ public class RequestMetrics
+ {
+ ///
+ /// The time when the request arrived.
+ ///
+ public DateTime ArrivalTime { get; set; }
+
+ ///
+ /// The time when the request was first scheduled.
+ ///
+ public DateTime? FirstScheduledTime { get; set; }
+
+ ///
+ /// The time when the first token was generated.
+ ///
+ public DateTime? FirstTokenTime { get; set; }
+
+ ///
+ /// The time when the last token was generated.
+ ///
+ public DateTime? LastTokenTime { get; set; }
+
+ ///
+ /// The time the request spent in the queue.
+ ///
+ public TimeSpan? TimeInQueue { get; set; }
+
+ ///
+ /// The time when the request was finished.
+ ///
+ public DateTime? FinishedTime { get; set; }
+ }
+}
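
Derived figures such as time-to-first-token or end-to-end latency are not stored here, but they follow directly from these timestamps. A small illustrative helper (not part of the PR); note that every property except `ArrivalTime` may still be null while the request is in flight:

```csharp
using System;
using LLama.Experimental.Common;

namespace Examples
{
    public static class RequestMetricsExtensions
    {
        /// <summary>Time to first token, if the first token has been produced.</summary>
        public static TimeSpan? TimeToFirstToken(this RequestMetrics m)
            => m.FirstTokenTime - m.ArrivalTime;

        /// <summary>End-to-end latency, if the request has finished.</summary>
        public static TimeSpan? TotalLatency(this RequestMetrics m)
            => m.FinishedTime - m.ArrivalTime;
    }
}
```
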
diff --git a/LLama/Experimental/Common/RequestOutput.cs b/LLama/Experimental/Common/RequestOutput.cs
new file mode 100644
index 000000000..095105fcf
--- /dev/null
+++ b/LLama/Experimental/Common/RequestOutput.cs
@@ -0,0 +1,101 @@
+using LLama.Experimental.Utils;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ /// <summary>
+ /// The output data of a request to the LLM.
+ /// </summary>
+ /// <param name="RequestId">The unique ID of the request.</param>
+ /// <param name="Prompt">The prompt string of the request.</param>
+ /// <param name="PromptTokenIds">The token IDs of the prompt.</param>
+ /// <param name="Outputs">The output sequences of the request.</param>
+ /// <param name="Finished">Whether the whole request is finished.</param>
+ /// <param name="Metrics">Metrics associated with the request.</param>
+ public record class RequestOutput(
+ string RequestId,
+ string? Prompt,
+ IList<int> PromptTokenIds,
+ IList<CompletionOutput> Outputs,
+ bool Finished,
+ RequestMetrics Metrics
+ )
+ {
+ /// <summary>
+ /// Create an instance from a <see cref="SequenceGroup"/>.
+ /// </summary>
+ /// <param name="seqGroup"></param>
+ /// <returns></returns>
+ /// <exception cref="NotImplementedException"></exception>
+ public static RequestOutput FromSeqGroup(SequenceGroup seqGroup)
+ {
+ var seqs = seqGroup.GetAllSeqs();
+ if(seqs.Count() != 1)
+ {
+ // TODO: deal with beam search here.
+ throw new NotImplementedException();
+ }
+
+ List<CompletionOutput> outputs = new();
+ int index = 0;
+ foreach(var seq in seqs)
+ {
+ outputs.Add(new CompletionOutput(index, seq.OutputText, seq.OutputTokens,
+ seq.Status.GetFinishedReason(), seq.StoppingString, seq.StoppingTokenId));
+ index++;
+ }
+
+ if (seqGroup.IsFinished)
+ {
+ seqGroup.SetFinishedTime(DateTime.Now);
+ }
+ return new RequestOutput(seqGroup.RequestId, seqGroup.Prompt, seqGroup.PromptTokenIds,
+ outputs, seqGroup.IsFinished, seqGroup.Metrics);
+ }
+
+ ///
+ public override string ToString()
+ {
+ return ClassStringFormatter.Format(this);
+ }
+ }
+
+ /// <summary>
+ /// The output data of one completion output of a request.
+ /// </summary>
+ /// <param name="Index">The index of the output in the request.</param>
+ /// <param name="Text">The generated output text.</param>
+ /// <param name="TokenIds">The token IDs of the generated output text.</param>
+ /// <param name="FinishReason">The reason why the sequence is finished.</param>
+ /// <param name="StoppingString">
+ /// The stop string that caused the completion to stop,
+ /// Null if the completion finished for some other reason.
+ /// </param>
+ /// <param name="StoppingToken">
+ /// The stop token that caused the completion to stop,
+ /// Null if the completion finished for some other reason.
+ /// </param>
+ public record class CompletionOutput(
+ int Index,
+ string Text,
+ IList<int> TokenIds,
+ string FinishReason,
+ string? StoppingString,
+ int? StoppingToken
+ )
+ {
+ ///
+ /// Whether the completion has finished.
+ ///
+ public bool IsFinished => !string.IsNullOrEmpty(FinishReason);
+
+ ///
+ public override string ToString()
+ {
+ return ClassStringFormatter.Format(this);
+ }
+ }
+}
diff --git a/LLama/Experimental/Common/SamplerOutput.cs b/LLama/Experimental/Common/SamplerOutput.cs
new file mode 100644
index 000000000..83f9238d8
--- /dev/null
+++ b/LLama/Experimental/Common/SamplerOutput.cs
@@ -0,0 +1,88 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ /// <summary>
+ /// For each sequence group, we generate a list of <see cref="SequenceOutput"/> objects,
+ /// each of which contains one possible candidate for the next token.
+ ///
+ /// This data structure implements <see cref="IList{T}"/> so it can be used like a list.
+ /// </summary>
+ public class SamplerOutput: IList<SequenceGroupOutput>
+ {
+ /// <summary>
+ /// The list of <see cref="SequenceGroupOutput"/> objects, which are the outputs of the LLM model.
+ /// </summary>
+ public List<SequenceGroupOutput> Outputs { get; }
+
+ /// <summary>
+ /// Create a <see cref="SamplerOutput"/> wrapping the given outputs.
+ /// </summary>
+ /// <param name="outputs"></param>
+ public SamplerOutput(List<SequenceGroupOutput> outputs)
+ {
+ Outputs = outputs;
+ }
+
+ #region IList Implementation
+ ///
+ public SequenceGroupOutput this[int index] { get => Outputs[index]; set => Outputs[index] = value; }
+ ///
+ public int Count => Outputs.Count;
+ ///
+ public bool IsReadOnly => false;
+ ///
+ public void Add(SequenceGroupOutput item)
+ {
+ Outputs.Add(item);
+ }
+ ///
+ public void Clear()
+ {
+ Outputs.Clear();
+ }
+ ///
+ public bool Contains(SequenceGroupOutput item)
+ {
+ return Outputs.Contains(item);
+ }
+ ///
+ public void CopyTo(SequenceGroupOutput[] array, int arrayIndex)
+ {
+ Outputs.CopyTo(array, arrayIndex);
+ }
+ /// <inheritdoc/>
+ public IEnumerator<SequenceGroupOutput> GetEnumerator()
+ {
+ return Outputs.GetEnumerator();
+ }
+ ///
+ public int IndexOf(SequenceGroupOutput item)
+ {
+ return Outputs.IndexOf(item);
+ }
+ ///
+ public void Insert(int index, SequenceGroupOutput item)
+ {
+ Outputs.Insert(index, item);
+ }
+ ///
+ public bool Remove(SequenceGroupOutput item)
+ {
+ return Outputs.Remove(item);
+ }
+ ///
+ public void RemoveAt(int index)
+ {
+ Outputs.RemoveAt(index);
+ }
+
+ System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
+ {
+ return Outputs.GetEnumerator();
+ }
+ #endregion
+ }
+}
diff --git a/LLama/Experimental/Common/SamplingMetadata.cs b/LLama/Experimental/Common/SamplingMetadata.cs
new file mode 100644
index 000000000..737e8fc83
--- /dev/null
+++ b/LLama/Experimental/Common/SamplingMetadata.cs
@@ -0,0 +1,28 @@
+using LLama.Experimental.Utils;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ /// <summary>
+ /// Metadata for input sequences. Used in the sampler.
+ /// </summary>
+ /// <param name="SeqIds">List of seq ids.</param>
+ /// <param name="SeqData">Seq id -> SequenceData.</param>
+ /// <param name="PromptLengths">Lengths of prompts.</param>
+ /// <param name="SelectedTokenIndices">Token indices selected for sampling.</param>
+ public record class SamplingMetadata(
+ IList<int> SeqIds,
+ IDictionary<int, SequenceData> SeqData,
+ IList<int> PromptLengths,
+ IList<int> SelectedTokenIndices
+ )
+ {
+ ///
+ public override string ToString()
+ {
+ return ClassStringFormatter.Format(this);
+ }
+ }
+}
diff --git a/LLama/Experimental/Common/SchedulerOutputs.cs b/LLama/Experimental/Common/SchedulerOutputs.cs
new file mode 100644
index 000000000..710846c8b
--- /dev/null
+++ b/LLama/Experimental/Common/SchedulerOutputs.cs
@@ -0,0 +1,111 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ /// <summary>
+ /// The scheduling decision made by a scheduler.
+ /// </summary>
+ /// <param name="ScheduledSeqGroups">Scheduled sequence groups.</param>
+ /// <param name="NumPrefillGroups">Number of prefill groups scheduled.</param>
+ /// <param name="NumBatchedTokens">Total number of batched tokens.</param>
+ /// <param name="IgnoredSeqGroups">Sequence groups that are going to be ignored.</param>
+ public record class SchedulerOutputs(
+ IEnumerable<ScheduledSequenceGroup> ScheduledSeqGroups,
+ int NumPrefillGroups,
+ int NumBatchedTokens,
+ IEnumerable<SequenceGroup> IgnoredSeqGroups
+ )
+ {
+ ///
+ /// Whether the scheduler output is empty.
+ ///
+ public bool IsEmpty => ScheduledSeqGroups.Count() == 0;
+ }
+
+ /// <summary>
+ /// A sequence group that has been scheduled.
+ /// </summary>
+ /// <param name="SeqGroup">A sequence group that's scheduled.</param>
+ /// <param name="TokenChunkSize">
+ /// The total chunk size (number of tokens) to process for the next iteration.
+ /// 1 for decoding. Same as the number of prompt tokens for prefill, but if prefill is
+ /// chunked, it can be smaller than that.
+ /// </param>
+ public record class ScheduledSequenceGroup(SequenceGroup SeqGroup, int TokenChunkSize);
+
+ /// <summary>
+ /// The requests that are scheduled from a waiting queue.
+ /// </summary>
+ /// <param name="RemainingWaitingQueue"></param>
+ /// <param name="SeqGroups"></param>
+ /// <param name="IgnoredSeqGroups"></param>
+ public record class SchedulerPrefillOutputs(
+ LinkedList<SequenceGroup> RemainingWaitingQueue,
+ List<ScheduledSequenceGroup> SeqGroups,
+ List<SequenceGroup> IgnoredSeqGroups
+ )
+ {
+ ///
+ public static SchedulerPrefillOutputs CreateEmpty()
+ {
+ return new SchedulerPrefillOutputs(
+ new LinkedList<SequenceGroup>(),
+ new List<ScheduledSequenceGroup>(),
+ new List<SequenceGroup>()
+ );
+ }
+ }
+
+ /// <summary>
+ /// The requests that are scheduled from a running queue.
+ ///
+ /// Could contain prefills (prefill that's chunked) or decodes. If there's not
+ /// enough memory, it can be preempted (for recompute) or swapped out.
+ /// </summary>
+ /// <param name="RemainingRunningQueue"></param>
+ /// <param name="DecodeSeqGroups"></param>
+ /// <param name="PrefillSeqGroups"></param>
+ /// <param name="PreemptedSeqGroups"></param>
+ /// <param name="SwappedOutSeqGroups"></param>
+ public record class SchedulerRunningOutputs(
+ LinkedList<SequenceGroup> RemainingRunningQueue,
+ List<ScheduledSequenceGroup> DecodeSeqGroups,
+ List<ScheduledSequenceGroup> PrefillSeqGroups,
+ List<SequenceGroup> PreemptedSeqGroups,
+ List<SequenceGroup> SwappedOutSeqGroups
+ )
+ {
+ ///
+ public static SchedulerRunningOutputs CreateEmpty()
+ {
+ return new SchedulerRunningOutputs(
+ new LinkedList<SequenceGroup>(),
+ new List<ScheduledSequenceGroup>(),
+ new List<ScheduledSequenceGroup>(),
+ new List<SequenceGroup>(),
+ new List<SequenceGroup>()
+ );
+ }
+ }
+
+ /// <summary>
+ /// The requests that are scheduled from a swap queue.
+ /// Could contain prefills (prefill that's chunked) or decodes.
+ /// </summary>
+ /// <param name="DecodeSeqGroups"></param>
+ /// <param name="PrefillSeqGroups"></param>
+ public record class SchedulerSwappedInOutputs(
+ List<ScheduledSequenceGroup> DecodeSeqGroups,
+ List<ScheduledSequenceGroup> PrefillSeqGroups
+ )
+ {
+ ///
+ public static SchedulerSwappedInOutputs CreateEmpty()
+ {
+ return new SchedulerSwappedInOutputs(new List<ScheduledSequenceGroup>(), new List<ScheduledSequenceGroup>());
+ }
+ }
+}
diff --git a/LLama/Experimental/Common/SchedulingBudget.cs b/LLama/Experimental/Common/SchedulingBudget.cs
new file mode 100644
index 000000000..45822f43f
--- /dev/null
+++ b/LLama/Experimental/Common/SchedulingBudget.cs
@@ -0,0 +1,85 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ ///
+ /// The available slots for scheduling.
+ ///
+ internal class SchedulingBudget
+ {
+ private HashSet<string> _requestIdsNumBatchedTokens;
+
+ private HashSet<string> _requestIdsNumCurrentSeqs;
+
+ public int TokenBudget { get; set; }
+
+ public int MaxNumSeqs { get; set; }
+
+ public int RemainingTokenBudget => TokenBudget - NumBatchedTokens;
+
+ internal int NumCurrentSeqs { get; set; }
+
+ internal int NumBatchedTokens { get; set; }
+
+ public SchedulingBudget(int tokenBudget, int maxNumSeqs)
+ {
+ TokenBudget = tokenBudget;
+ MaxNumSeqs = maxNumSeqs;
+ _requestIdsNumBatchedTokens = new HashSet<string>();
+ _requestIdsNumCurrentSeqs = new HashSet<string>();
+ NumCurrentSeqs = 0;
+ NumBatchedTokens = 0;
+ }
+
+ public bool CanSchedule(int numNewTokens, int numNewSeqs)
+ {
+ Debug.Assert(numNewTokens >= 0);
+ Debug.Assert(numNewSeqs >= 0);
+ return NumBatchedTokens + numNewTokens <= TokenBudget
+ && NumCurrentSeqs + numNewSeqs <= MaxNumSeqs;
+ }
+
+ public void AddNumBatchedTokens(string requestId, int numBatchedTokens)
+ {
+ if (_requestIdsNumBatchedTokens.Contains(requestId))
+ {
+ return;
+ }
+
+ _requestIdsNumBatchedTokens.Add(requestId);
+ NumBatchedTokens += numBatchedTokens;
+ }
+
+ public void SubtractNumBatchedTokens(string requestId, int numBatchedTokens)
+ {
+ if (_requestIdsNumBatchedTokens.Contains(requestId))
+ {
+ _requestIdsNumBatchedTokens.Remove(requestId);
+ NumBatchedTokens -= numBatchedTokens;
+ }
+ }
+
+ public void AddNumSeqs(string requestId, int numCurrentSeqs)
+ {
+ if (_requestIdsNumCurrentSeqs.Contains(requestId))
+ {
+ return;
+ }
+
+ _requestIdsNumCurrentSeqs.Add(requestId);
+ NumCurrentSeqs += numCurrentSeqs;
+ }
+
+ public void SubtractNumSeqs(string requestId, int numCurrentSeqs)
+ {
+ if (_requestIdsNumCurrentSeqs.Contains(requestId))
+ {
+ _requestIdsNumCurrentSeqs.Remove(requestId);
+ NumCurrentSeqs -= numCurrentSeqs;
+ }
+ }
+ }
+}
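
A sketch of how a scheduler is expected to consume the budget. Note the class is internal, so this only compiles inside the library, and the numbers are illustrative:

```csharp
// Illustrative only: SchedulingBudget is internal to the library.
var budget = new SchedulingBudget(tokenBudget: 2048, maxNumSeqs: 16);

if (budget.CanSchedule(numNewTokens: 512, numNewSeqs: 1))
{
    // Both Add* calls are idempotent per request id, so accounting the same
    // request twice in one scheduling pass does not double-count its cost.
    budget.AddNumBatchedTokens("request-0", 512);
    budget.AddNumSeqs("request-0", 1);
}

// 2048 - 512 tokens remain for the rest of this scheduling iteration.
int remaining = budget.RemainingTokenBudget;   // 1536
```
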
diff --git a/LLama/Experimental/Common/Sequence.cs b/LLama/Experimental/Common/Sequence.cs
new file mode 100644
index 000000000..d7e4b8e99
--- /dev/null
+++ b/LLama/Experimental/Common/Sequence.cs
@@ -0,0 +1,151 @@
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ ///
+ /// Stores the data, status, and other information of a sequence.
+ ///
+ public sealed class Sequence
+ {
+ ///
+ /// The ID of the sequence.
+ ///
+ public int Id { get; private set; }
+
+ ///
+ /// The prompt of the sequence.
+ ///
+ public string? Prompt { get; }
+
+ ///
+ /// Data used for computation of this sequence.
+ ///
+ public SequenceData Data { get; private set; }
+
+ ///
+ /// The output text of the sequence.
+ /// Note that it should only be set when you want to implement some interfaces yourself.
+ ///
+ public string OutputText { get; internal set; }
+
+ ///
+ /// Input + output token IDs.
+ /// Note that it should only be set when you want to implement some interfaces yourself.
+ ///
+ public IEnumerable<int> TokenIds => Data.TokenIds;
+
+ ///
+ /// Length of the sequence data.
+ ///
+ public int Length => Data.Length;
+
+ ///
+ /// The status of the sequence.
+ ///
+ public SequenceStatus Status { get; internal set; }
+
+ ///
+ /// The stopping string of the sequence if it stops because of this string.
+ ///
+ public string? StoppingString { get; internal set; }
+
+ ///
+ /// The stopping token of the sequence if it stops because of this token.
+ ///
+ public int? StoppingTokenId { get; internal set; }
+
+ ///
+ /// The offset of the sequence in the decoding process.
+ /// It's useful when the tokenizer may use more than 1 token id to represent a token.
+ ///
+ public int IncrementalDecodingOffset { get; internal set; }
+
+ ///
+ /// Whether the sequence has finished.
+ ///
+ public bool IsFinished
+ {
+ get
+ {
+ return Status is SequenceStatus.FinishStopped
+ or SequenceStatus.FinishLengthCapped
+ or SequenceStatus.FinishAborted
+ or SequenceStatus.FinishIgnored;
+ }
+ }
+
+ ///
+ /// The output token ids of the sequence.
+ ///
+ public IList<int> OutputTokens => Data.OutputTokenIds;
+
+ ///
+ /// Whether the sequence is at prefill stage.
+ ///
+ public bool IsPrefill => Data.Stage == SequenceStage.Prefill;
+
+ ///
+ /// Get the number of new tokens to be computed.
+ ///
+ public int NumNewTokens
+ {
+ get
+ {
+ if (Data.Stage == SequenceStage.Decode)
+ {
+ return 1;
+ }
+ return Data.NumUncomputedTokens;
+ }
+ }
+
+ /// <summary>
+ /// Create a new sequence from a prompt.
+ /// </summary>
+ /// <param name="id"></param>
+ /// <param name="prompt"></param>
+ /// <param name="promptTokens"></param>
+ public Sequence(int id, string? prompt, IList<int> promptTokens)
+ {
+ Id = id;
+ Prompt = prompt;
+ Data = new SequenceData(promptTokens);
+ OutputText = "";
+ Status = SequenceStatus.Waiting;
+ IncrementalDecodingOffset = Data.PromptTokenIds.Count;
+
+ // TODO: deal with incremental detokenization.
+ }
+
+ ///
+ /// Add a token id to the output ids of the sequence data.
+ ///
+ ///
+ public void AppendToken(int tokenId)
+ // TODO: logprobs?
+ {
+ Data.AppendToken(tokenId);
+ }
+
+ ///
+ /// Get a new sequence with same data but new id.
+ ///
+ ///
+ ///
+ public Sequence Fork(int newSeqId)
+ {
+ // clone the current data.
+ var clone = (Sequence)MemberwiseClone();
+ clone.Data = new SequenceData(
+ new List<int>(Data.PromptTokenIds),
+ new List<int>(Data.OutputTokenIds)
+ );
+ // set new id
+ clone.Id = newSeqId;
+ return clone;
+ }
+ }
+}
diff --git a/LLama/Experimental/Common/SequenceData.cs b/LLama/Experimental/Common/SequenceData.cs
new file mode 100644
index 000000000..0cd284b36
--- /dev/null
+++ b/LLama/Experimental/Common/SequenceData.cs
@@ -0,0 +1,120 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ ///
+ /// Data associated with a sequence.
+ ///
+ public class SequenceData
+ {
+ ///
+ /// The token IDs of the prompt.
+ ///
+ public IList<int> PromptTokenIds { get; set; }
+
+ ///
+ /// The token IDs of the output. Set to an empty list if None.
+ ///
+ public List<int> OutputTokenIds { get; set; }
+
+ ///
+ /// The stage of the sequence data.
+ ///
+ public SequenceStage Stage { get; private set; }
+
+ ///
+ /// The number of all the tokens in the sequence, including prompt and output.
+ ///
+ public int Length => OutputTokenIds.Count + PromptTokenIds.Count;
+
+ ///
+ /// All the token IDs, including prompt and output.
+ ///
+ public IEnumerable<int> TokenIds => PromptTokenIds.Concat(OutputTokenIds);
+
+ ///
+ /// The number of prefill tokens that are already computed.
+ ///
+ public int NumComputedTokens { get; private set; }
+
+ ///
+ /// The number of prefill tokens that are not computed.
+ ///
+ public int NumUncomputedTokens => Length - NumComputedTokens;
+
+ ///
+ /// The last token ID in the sequence.
+ ///
+ public int LastTokenId
+ {
+ get
+ {
+ if(OutputTokenIds.Count == 0)
+ {
+ return PromptTokenIds[PromptTokenIds.Count - 1];
+ }
+ return OutputTokenIds[OutputTokenIds.Count - 1];
+ }
+ }
+
+ /// <summary>
+ /// Create the data for a sequence from its prompt tokens and, optionally, previously generated output tokens.
+ /// </summary>
+ /// <param name="promptTokens"></param>
+ /// <param name="outputTokens"></param>
+ public SequenceData(IList<int> promptTokens, IEnumerable<int>? outputTokens = null)
+ {
+ OutputTokenIds = outputTokens is not null ? new List(outputTokens) : new List();
+ PromptTokenIds = promptTokens;
+
+ // TODO: cumulative_logprob?
+ NumComputedTokens = 0;
+ Stage = SequenceStage.Prefill;
+ }
+
+ ///
+ /// Add a token id to the output token ids.
+ ///
+ ///
+ public void AppendToken(int tokenId)
+ {
+ OutputTokenIds.Add(tokenId);
+ }
+
+ ///
+ /// Update number of tokens computed so far.
+ ///
+ ///
+ public void UpdateNumComputedTokens(int numNewComputedTokens)
+ {
+ NumComputedTokens += numNewComputedTokens;
+ Debug.Assert(NumComputedTokens <= Length);
+ if(NumUncomputedTokens == 0)
+ {
+ Stage = SequenceStage.Decode;
+ }
+ }
+
+ ///
+ /// Reset the number of computed tokens from this sequence. It is
+ /// supposed to be called when a sequence needs to be started from
+ /// the beginning again (e.g., when the sequence is preempted).
+ ///
+ public void ResetStageForRecompute()
+ {
+ NumComputedTokens = 0;
+ Stage = SequenceStage.Prefill;
+ }
+
+ ///
+ public override string ToString()
+ {
+ return $"SequenceData(\n PromptTokens: {string.Join(", ", PromptTokenIds)}, \n " +
+ $"OutputTokens: {string.Join(", ", OutputTokenIds)}, Stage: {Stage}\n)";
+ }
+ }
+}
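
A short walkthrough of how `UpdateNumComputedTokens` drives the `Prefill` to `Decode` transition, e.g. when a prompt is prefilled in chunks (token ids are arbitrary):

```csharp
using System.Collections.Generic;
using LLama.Experimental.Common;

// A six-token prompt processed in two prefill chunks.
var data = new SequenceData(new List<int> { 11, 12, 13, 14, 15, 16 });

data.UpdateNumComputedTokens(4);   // first prefill chunk
// data.Stage == SequenceStage.Prefill, data.NumUncomputedTokens == 2

data.UpdateNumComputedTokens(2);   // final prefill chunk
// data.Stage == SequenceStage.Decode: the whole prompt has been computed

data.AppendToken(42);              // one generated token per decode step
// data.Length == 7, data.LastTokenId == 42
```
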
diff --git a/LLama/Experimental/Common/SequenceGroup.cs b/LLama/Experimental/Common/SequenceGroup.cs
new file mode 100644
index 000000000..eb5a14b3c
--- /dev/null
+++ b/LLama/Experimental/Common/SequenceGroup.cs
@@ -0,0 +1,268 @@
+using LLama.Experimental.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace LLama.Experimental.Common
+{
+ ///
+ /// A group of sequences that are generated from the same prompt.
+ ///
+ public sealed class SequenceGroup
+ // TODO: Multi-modal data
+ {
+ ///
+ /// The ID of the request.
+ ///
+ public string RequestId { get; }
+
+ ///
+ /// Mapping from seq ID to the sequence.
+ ///
+ public IDictionary<int, Sequence> SeqDict { get; }
+
+ ///
+ /// The sampling method to do the sampling.
+ ///
+ public ISamplingMethod SamplingMethod { get; set; }
+
+ ///
+ /// The stopping criteria to decide whether the generation of the sequence should be stopped.
+ ///
+ public IStoppingCriteria StoppingCriteria { get; set; }
+
+ ///
+ /// The metrics for the scheduling and inference of this sequence group.
+ ///
+ public RequestMetrics Metrics { get; }
+
+ ///
+ /// The common prompt of the sequences in this sequence group.
+ ///
+ public string? Prompt
+ {
+ get
+ {
+ // All sequences in the group should have the same prompt.
+ // We use the prompt of an arbitrary sequence.
+ return SeqDict.First().Value.Prompt;
+ }
+ }
+
+ ///
+ /// The prompt tokens of the sequences in this sequence group.
+ ///
+ public IList<int> PromptTokenIds
+ {
+ get
+ {
+ return SeqDict.First().Value.Data.PromptTokenIds;
+ }
+ }
+
+ ///
+ /// Whether the request of this sequence group has been finished.
+ ///
+ public bool IsFinished
+ {
+ get
+ {
+ return SeqDict.Values.All(seq => seq.IsFinished);
+ }
+ }
+
+ ///
+ /// Whether this sequence group is at prefill stage.
+ ///
+ public bool IsPrefill
+ {
+ get
+ {
+ return SeqDict.Values.First().IsPrefill;
+ }
+ }
+
+ ///
+ /// The number of sequences in this sequence group.
+ ///
+ public int NumSeqs => SeqDict.Count;
+
+ ///
+ /// The number of unfinished sequences in this sequence group.
+ ///
+ public int NumUnfinishedSeqs => GetUnfinishedSeqs().Count();
+
+ /// <summary>
+ /// Initializes a new instance of the <see cref="SequenceGroup"/> class.
+ /// </summary>
+ /// <param name="requestId"></param>
+ /// <param name="sequences"></param>
+ /// <param name="samplingMethod"></param>
+ /// <param name="stoppingCriteria"></param>
+ /// <param name="arrivalTime"></param>
+ /// <exception cref="ArgumentException"></exception>
+ public SequenceGroup(string requestId, Sequence[] sequences, ISamplingMethod samplingMethod, IStoppingCriteria stoppingCriteria, DateTime arrivalTime)
+ {
+ if(sequences.Length == 0)
+ {
+ throw new ArgumentException("The sequences passed to SequenceGroup cannot be empty.");
+ }
+
+ RequestId = requestId;
+ SeqDict = sequences.ToDictionary(seq => seq.Id, seq => seq);
+ SamplingMethod = samplingMethod;
+ StoppingCriteria = stoppingCriteria;
+ Metrics = new RequestMetrics()
+ {
+ ArrivalTime = arrivalTime
+ };
+ }
+
+ ///
+ /// Sets the first token time for Request level timings.
+ ///
+ ///
+ public void MaybeSetFirstTokenTime(DateTime time)
+ {
+ if (Metrics.FirstTokenTime is null)
+ {
+ Metrics.FirstTokenTime = time;
+ }
+ }
+
+ ///
+ /// Sets the first scheduled time and time in queue for Request level timings
+ ///
+ ///
+ public void MaybeSetFirstScheduledTime(DateTime time)
+ {
+ if (Metrics.FirstScheduledTime is null)
+ {
+ Metrics.FirstScheduledTime = time;
+ Metrics.TimeInQueue = time - Metrics.ArrivalTime;
+ }
+ }
+
+ ///
+ /// Sets the finished time for Request level timings.
+ ///
+ ///
+ public void SetFinishedTime(DateTime time)
+ {
+ Metrics.FinishedTime = time;
+ }
+
+ ///
+ /// Get all sequences with the given status.
+ ///
+ ///
+ ///
+ public IEnumerable<Sequence> GetSeqsWithStatus(SequenceStatus status)
+ {
+ return SeqDict.Values.Where(seq => seq.Status == status);
+ }
+
+ ///
+ /// Get all sequences in this sequence group.
+ ///
+ ///
+ public IEnumerable<Sequence> GetAllSeqs()
+ {
+ return SeqDict.Values;
+ }
+
+ ///
+ /// Get all unfinished sequences in this sequence group.
+ ///
+ ///
+ public IEnumerable<Sequence> GetUnfinishedSeqs()
+ {
+ return SeqDict.Values.Where(seq => !seq.IsFinished);
+ }
+
+ ///
+ /// Get all finished sequences in this sequence group.
+ ///
+ ///
+ public IEnumerable<Sequence> GetFinishedSeqs()
+ {
+ return SeqDict.Values.Where(seq => seq.IsFinished);
+ }
+
+ ///
+ /// The maximum number of sequences running in parallel in the remaining
+ /// lifetime of the request.
+ ///
+ ///
+ public int GetMaxNumRunningSeqs()
+ {
+ int defaultValue = NumUnfinishedSeqs;
+ return SamplingMethod.GetMaxNumRunningSeqs(defaultValue, NumSeqs);
+ }
+
+ ///
+ /// Add a new sequence.
+ ///
+ ///
+ ///
+ public void Add(Sequence seq)
+ {
+ if (SeqDict.ContainsKey(seq.Id))
+ {
+ throw new ArgumentException($"Sequence {seq.Id} already exists.");
+ }
+ SeqDict[seq.Id] = seq;
+ }
+
+ ///
+ /// Remove the sequence of seq id.
+ ///
+ ///
+ ///
+ public void Remove(int seqId)
+ {
+ if (!SeqDict.ContainsKey(seqId))
+ {
+ throw new ArgumentException($"Sequence {seqId} not found.");
+ }
+ SeqDict.Remove(seqId);
+ }
+
+ ///
+ /// Get the number of tokens to be computed in this sequence group.
+ ///
+ public int GetNumComputedTokens()
+ {
+ int numUncomputedTokens = 0;
+ foreach(var seq in GetAllSeqs())
+ {
+ numUncomputedTokens += seq.Data.NumUncomputedTokens;
+ }
+ return numUncomputedTokens;
+ }
+
+ ///
+ /// Update number of tokens computed so far.
+ ///
+ ///
+ public void UpdateNumComputedTokens(int numNewComputedTokens)
+ {
+ foreach(var seq in SeqDict.Values)
+ {
+ if (!seq.IsFinished)
+ {
+ seq.Data.UpdateNumComputedTokens(numNewComputedTokens);
+ }
+ }
+ }
+
+ ///
+ public override string ToString()
+ {
+ return $"SequenceGroup(RequestId = {RequestId}, \n " +
+ $"SamplingMethod = ({SamplingMethod.GetType().Name}), \n " +
+ $"StoppingCriteria = ({StoppingCriteria.GetType().Name}), \n " +
+ $"NumSeqs = {SeqDict.Count}\n)";
+ }
+ }
+}
diff --git a/LLama/Experimental/Common/SequenceGroupMetadata.cs b/LLama/Experimental/Common/SequenceGroupMetadata.cs
new file mode 100644
index 000000000..bf9ff1d0a
--- /dev/null
+++ b/LLama/Experimental/Common/SequenceGroupMetadata.cs
@@ -0,0 +1,79 @@
+using LLama.Experimental.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ ///
+ /// Metadata for a sequence group.
+ ///
+ public class SequenceGroupMetadata
+ {
+ ///
+ /// The ID of the request.
+ ///
+ public string RequestId { get; set; }
+
+ ///
+ /// Whether the request is at prompt stage.
+ ///
+ public bool IsPrompt { get; set; }
+
+ ///
+ /// The sequence data. (Seq id -> sequence data)
+ ///
+ public Dictionary<int, SequenceData> SeqData { get; set; }
+
+ ///
+ /// The sampling method used to generate the outputs.
+ ///
+ public ISamplingMethod SamplingMethod { get; set; }
+
+ ///
+ /// The stopping criteria to decide whether the generation of the sequence should be stopped.
+ ///
+ public IStoppingCriteria StoppingCriteria { get; set; }
+
+ ///
+ /// The number of tokens to be processed (per sequence).
+ ///
+ public int TokenChunkSize { get; set; }
+
+ /// <summary>
+ /// Create the metadata of a sequence group.
+ /// </summary>
+ /// <param name="requestId"></param>
+ /// <param name="isPrompt"></param>
+ /// <param name="seqData"></param>
+ /// <param name="samplingMethod"></param>
+ /// <param name="stoppingCriteria"></param>
+ /// <param name="tokenChunkSize"></param>
+ public SequenceGroupMetadata(string requestId, bool isPrompt, Dictionary<int, SequenceData> seqData,
+ ISamplingMethod samplingMethod, IStoppingCriteria stoppingCriteria, int? tokenChunkSize)
+ {
+ RequestId = requestId;
+ IsPrompt = isPrompt;
+ SeqData = seqData;
+ SamplingMethod = samplingMethod;
+ StoppingCriteria = stoppingCriteria;
+
+ if(tokenChunkSize is null)
+ {
+ if (isPrompt)
+ {
+ TokenChunkSize = seqData.Values.First().Length;
+ }
+ else
+ {
+ TokenChunkSize = 1;
+ }
+ }
+ else
+ {
+ TokenChunkSize = tokenChunkSize.Value;
+ }
+ }
+ }
+}
diff --git a/LLama/Experimental/Common/SequenceGroupOutput.cs b/LLama/Experimental/Common/SequenceGroupOutput.cs
new file mode 100644
index 000000000..bdacf70d6
--- /dev/null
+++ b/LLama/Experimental/Common/SequenceGroupOutput.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ /// <summary>
+ /// The model output associated with a sequence group.
+ /// </summary>
+ /// <param name="Samples"></param>
+ public record class SequenceGroupOutput(List<SequenceOutput> Samples)
+ {
+
+ }
+}
diff --git a/LLama/Experimental/Common/SequenceOutput.cs b/LLama/Experimental/Common/SequenceOutput.cs
new file mode 100644
index 000000000..76bab7791
--- /dev/null
+++ b/LLama/Experimental/Common/SequenceOutput.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ ///
+ /// The model output associated with a sequence.
+ ///
+ public class SequenceOutput
+ // TODO: Beam search
+ {
+ ///
+ /// The output token ID.
+ ///
+ public int OutputTokenId { get; init; }
+
+ ///
+ /// The ID of the parent sequence (for forking in beam search).
+ ///
+ public int ParentSeqId { get; init; }
+
+ ///
+ /// The logprobs of the output token.
+ /// (Token id -> logP(x_i+1 | x_0, ..., x_i))
+ ///
+ public float[]? Logprobs { get; init; }
+ }
+}
diff --git a/LLama/Experimental/Common/SequenceStage.cs b/LLama/Experimental/Common/SequenceStage.cs
new file mode 100644
index 000000000..2fd76d35e
--- /dev/null
+++ b/LLama/Experimental/Common/SequenceStage.cs
@@ -0,0 +1,22 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ /// <summary>
+ /// The sequence stage for <see cref="SequenceData"/>.
+ /// </summary>
+ public enum SequenceStage
+ {
+ ///
+ /// The prefill stage, in which the model is processing your prompt.
+ ///
+ Prefill,
+
+ ///
+ /// The decode stage, in which the model is generating the output.
+ ///
+ Decode
+ }
+}
diff --git a/LLama/Experimental/Common/SequenceStatus.cs b/LLama/Experimental/Common/SequenceStatus.cs
new file mode 100644
index 000000000..7dadf5cba
--- /dev/null
+++ b/LLama/Experimental/Common/SequenceStatus.cs
@@ -0,0 +1,68 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Common
+{
+ ///
+ /// Status of a sequence.
+ ///
+ public enum SequenceStatus
+ {
+ ///
+ /// The sequence is waiting for scheduling.
+ ///
+ Waiting,
+
+ ///
+ /// The sequence is running.
+ ///
+ Running,
+
+ ///
+ /// The sequence has been swapped out for some reason.
+ ///
+ Swapped,
+
+ ///
+ /// The sequence has been finished because it's stopped by a stopping criteria.
+ ///
+ FinishStopped,
+
+ ///
+ /// The sequence has been finished because it reaches the maximum length.
+ ///
+ FinishLengthCapped,
+
+ ///
+ /// The sequence has been finished because it's aborted.
+ ///
+ FinishAborted,
+
+ ///
+ /// The sequence will never be processed for some reason. Please check whether the prompt is too long.
+ ///
+ FinishIgnored
+ }
+
+ ///
+ public static class SequenceStatusExtensions
+ {
+ /// <summary>
+ /// Get the finished reason in OpenAI style.
+ /// </summary>
+ /// <param name="status"></param>
+ /// <returns></returns>
+ public static string GetFinishedReason(this SequenceStatus status)
+ {
+ return status switch
+ {
+ SequenceStatus.FinishStopped => "stop",
+ SequenceStatus.FinishLengthCapped => "length",
+ SequenceStatus.FinishAborted => "abort",
+ SequenceStatus.FinishIgnored => "length",
+ _ => ""
+ };
+ }
+ }
+}
diff --git a/LLama/Experimental/Config/DeviceConfig.cs b/LLama/Experimental/Config/DeviceConfig.cs
new file mode 100644
index 000000000..b2b7d589b
--- /dev/null
+++ b/LLama/Experimental/Config/DeviceConfig.cs
@@ -0,0 +1,44 @@
+using LLama.Native;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Config
+{
+ ///
+ /// Device configuration for using LLM.
+ ///
+ public class DeviceConfig
+ {
+ /// <summary>
+ /// main_gpu interpretation depends on split_mode:
+ /// <list type="bullet">
+ /// <item>
+ /// <term>None</term>
+ /// <description>The GPU that is used for the entire model.</description>
+ /// </item>
+ /// <item>
+ /// <term>Row</term>
+ /// <description>The GPU that is used for small tensors and intermediate results.</description>
+ /// </item>
+ /// <item>
+ /// <term>Layer</term>
+ /// <description>Ignored.</description>
+ /// </item>
+ /// </list>
+ /// </summary>
+ public int MainGpu { get; set; } = 0;
+
+ ///
+ /// How to split the model across multiple GPUs
+ ///
+ public GPUSplitMode SplitMode { get; set; } = GPUSplitMode.None;
+
+ ///
+ /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
+ ///
+ public int GpuLayerCount { get; set; } = 20;
+
+ // TODO: Add a static method/property "Auto" to return a default DeviceConfig
+ }
+}
diff --git a/LLama/Experimental/Config/KvCacheConfig.cs b/LLama/Experimental/Config/KvCacheConfig.cs
new file mode 100644
index 000000000..59f271a7b
--- /dev/null
+++ b/LLama/Experimental/Config/KvCacheConfig.cs
@@ -0,0 +1,17 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Config
+{
+ ///
+ /// Configuration for the KV cache.
+ ///
+ public class KvCacheConfig
+ {
+ ///
+ /// The maximum CPU memory space used for saving kv cache swapped from GPU.
+ ///
+ public int MaxSwapSpace { get; set; }
+ }
+}
diff --git a/LLama/Experimental/Config/SchedulerConfig.cs b/LLama/Experimental/Config/SchedulerConfig.cs
new file mode 100644
index 000000000..2d4c0eecb
--- /dev/null
+++ b/LLama/Experimental/Config/SchedulerConfig.cs
@@ -0,0 +1,68 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Config
+{
+ ///
+ /// Scheduler configuration.
+ ///
+ public class SchedulerConfig
+ {
+ ///
+ /// Maximum number of tokens to be processed in a single iteration.
+ ///
+ public int MaxNumBatchedTokens { get; set; }
+
+ ///
+ /// Maximum number of sequences to be processed in a single iteration.
+ ///
+ public int MaxNumSequences { get; set; }
+
+ ///
+ /// Maximum length of a sequence (including prompt and generated text).
+ ///
+ public int MaxSequenceLength { get; set; }
+
+ ///
+ /// If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.
+ ///
+ public bool EnableChunkedPrefill { get; set; }
+
+ ///
+ /// Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt.
+ ///
+ public float DelayFactor { get; set; }
+
+ public SchedulerConfig(int maxNumBatchedTokens, int maxNumSequences, int maxSequenceLength, bool enableChunkedPrefill = false, float delayFactor = .0f)
+ {
+ MaxNumBatchedTokens = maxNumBatchedTokens;
+ MaxNumSequences = maxNumSequences;
+ MaxSequenceLength = maxSequenceLength;
+ EnableChunkedPrefill = enableChunkedPrefill;
+ DelayFactor = delayFactor;
+ }
+
+
+
+ ///
+ /// Verify if this configuration is valid and throw an exception if it's invalid.
+ ///
+ ///
+ public void ThrowIfInvalid()
+ {
+ if (MaxNumBatchedTokens < MaxSequenceLength && !EnableChunkedPrefill)
+ {
+ throw new ArgumentException($"MaxNumBatchedTokens ({MaxNumBatchedTokens}) is smaller than " +
+ $"MaxSequenceLength ({MaxSequenceLength}). This effectively limits the maximum sequence length to " +
+ $"MaxNumBatchedTokens. Please increase MaxNumBatchedTokens, decrease MaxSequenceLength or enable chunked prefill.");
+ }
+
+ if (MaxNumBatchedTokens < MaxNumSequences)
+ {
+ throw new ArgumentException($"MaxNumBatchedTokens ({MaxNumBatchedTokens}) must be greater than or equal to " +
+ $"MaxNumSequences ({MaxNumSequences}).");
+ }
+ }
+ }
+}
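
For illustration, a configuration that accepts prompts up to 4096 tokens without chunked prefill must give the scheduler at least that many batched tokens per iteration, otherwise `ThrowIfInvalid` rejects it (the concrete numbers below are just an example):

```csharp
using LLama.Experimental.Config;

// Up to 16 sequences per iteration, sequences capped at 4096 tokens.
var schedulerConfig = new SchedulerConfig(
    maxNumBatchedTokens: 4096,
    maxNumSequences: 16,
    maxSequenceLength: 4096,
    enableChunkedPrefill: false);

// Throws if MaxNumBatchedTokens < MaxSequenceLength (without chunked prefill)
// or if MaxNumBatchedTokens < MaxNumSequences.
schedulerConfig.ThrowIfInvalid();
```
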
diff --git a/LLama/Experimental/Core/AntipromptStoppingCriteria.cs b/LLama/Experimental/Core/AntipromptStoppingCriteria.cs
new file mode 100644
index 000000000..b4edcd983
--- /dev/null
+++ b/LLama/Experimental/Core/AntipromptStoppingCriteria.cs
@@ -0,0 +1,31 @@
+using LLama.Experimental.Abstractions;
+using LLama.Experimental.Common;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Core
+{
+ // TODO: This is only the most simple implementation to run the test now. We should replace it in the future.
+ public class AntipromptStoppingCriteria: IStoppingCriteria
+ {
+ private string[] _antiprompts;
+
+ public AntipromptStoppingCriteria(string[] antiprompts)
+ {
+ _antiprompts = antiprompts;
+ }
+
+ public StoppingCriteriaOutput CheckStop(Sequence seq)
+ {
+ foreach (var antiprompt in _antiprompts)
+ {
+ if (seq.OutputText.EndsWith(antiprompt))
+ {
+ return new StoppingCriteriaOutput(SequenceStatus.FinishStopped, antiprompt, null);
+ }
+ }
+ return new StoppingCriteriaOutput(seq.Status, null, null);
+ }
+ }
+}
diff --git a/LLama/Experimental/Core/LLamaCpp/LLamaGreedySamplingMethod.cs b/LLama/Experimental/Core/LLamaCpp/LLamaGreedySamplingMethod.cs
new file mode 100644
index 000000000..ef3ffa525
--- /dev/null
+++ b/LLama/Experimental/Core/LLamaCpp/LLamaGreedySamplingMethod.cs
@@ -0,0 +1,41 @@
+using LLama.Experimental.Abstractions;
+using LLama.Experimental.Common;
+using LLama.Native;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Core.LLamaCpp
+{
+ // TODO: This is only the most simple implementation to run the example. It should be replaced in the future.
+ public class LLamaGreedySamplingMethod: ISamplingMethod
+ {
+ private LLamaContext _context;
+
+ public LLamaGreedySamplingMethod(LLamaContext context)
+ {
+ _context = context;
+ }
+
+ public int GetMaxNumRunningSeqs(int defaultValue, int currentNumSeqs)
+ {
+ return defaultValue;
+ }
+
+ ///
+ /// Whether to skip special tokens.
+ ///
+ public bool SkipSpecialTokens => false;
+
+ public SequenceOutput SampleSequence(Span<float> logits, int seqId, SamplingMetadata samplingMetadata)
+ {
+ // Process token data array to select a final token
+ var candidates = LLamaTokenDataArray.Create(logits);
+ return new SequenceOutput()
+ {
+ OutputTokenId = (int)candidates.SampleTokenGreedy(_context.NativeHandle),
+ ParentSeqId = seqId
+ };
+ }
+ }
+}
diff --git a/LLama/Experimental/Core/LLamaCpp/LLamaTokenizer.cs b/LLama/Experimental/Core/LLamaCpp/LLamaTokenizer.cs
new file mode 100644
index 000000000..a798af822
--- /dev/null
+++ b/LLama/Experimental/Core/LLamaCpp/LLamaTokenizer.cs
@@ -0,0 +1,37 @@
+using LLama.Experimental.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace LLama.Experimental.Core.LLamaCpp
+{
+ ///
+ /// llama.cpp tokenizer.
+ ///
+ public sealed class LLamaTokenizer: ITokenizer
+ {
+ private LLamaContext _context;
+
+ public LLamaTokenizer(LLamaContext context)
+ {
+ _context = context;
+ }
+
+ /// <inheritdoc/>
+ public IList<int> Tokenize(string input)
+ {
+ // TODO: refactor this!!
+ return _context.Tokenize(input).Select(x => ((int)x)).ToArray();
+ }
+
+ /// <inheritdoc/>
+ public int ConvertIdsToText(IEnumerable<int> tokenIds, out string result, bool skipSpecialTokens = false)
+ {
+ // TODO: integrate `StreamingDecoder` here. Currently only English is supported.
+ // We should add a byte array to `sequence`.
+ result = _context.DeTokenize(tokenIds.Select(x => (Native.LLamaToken)x).ToArray());
+ return tokenIds.Count();
+ }
+ }
+}
diff --git a/LLama/Experimental/Core/SequenceLengthStoopingCriteria.cs b/LLama/Experimental/Core/SequenceLengthStoopingCriteria.cs
new file mode 100644
index 000000000..3901d9df6
--- /dev/null
+++ b/LLama/Experimental/Core/SequenceLengthStoopingCriteria.cs
@@ -0,0 +1,28 @@
+using LLama.Experimental.Abstractions;
+using LLama.Experimental.Common;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Core
+{
+ // TODO: This is only the most simple implementation to run the test now. We should replace it in the future.
+ public class SequenceLengthStoopingCriteria: IStoppingCriteria
+ {
+ private int _maxSequenceLength;
+
+ public SequenceLengthStoopingCriteria(int maxSequenceLength)
+ {
+ _maxSequenceLength = maxSequenceLength;
+ }
+
+ public StoppingCriteriaOutput CheckStop(Sequence seq)
+ {
+ if(seq.Length >= _maxSequenceLength)
+ {
+ return new StoppingCriteriaOutput(SequenceStatus.FinishLengthCapped, null, null);
+ }
+ return new StoppingCriteriaOutput(seq.Status, null, null);
+ }
+ }
+}
diff --git a/LLama/Experimental/DeTokenizer.cs b/LLama/Experimental/DeTokenizer.cs
new file mode 100644
index 000000000..6d8b8fa0e
--- /dev/null
+++ b/LLama/Experimental/DeTokenizer.cs
@@ -0,0 +1,46 @@
+using System;
+using System.Collections.Generic;
+using System.ComponentModel.Design;
+using System.Linq;
+using System.Text;
+using LLama.Experimental.Abstractions;
+using LLama.Experimental.Common;
+using LLama.Extensions;
+
+namespace LLama.Experimental
+{
+ /// <summary>
+ /// Defines the process of converting sequence output to text.
+ ///
+ /// We should not expose this class to users. Implementing <see cref="ITokenizer"/>
+ /// should be the only thing the user needs to care about to customize tokenizing and detokenizing.
+ /// </summary>
+ internal static class DeTokenizer
+ {
+ private static int INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5;
+
+ ///
+ /// Decodes the new token for a sequence. In-place operation.
+ ///
+ ///
+ ///
+ ///
+ public static void DecodeSequenceInplace(Sequence seq, ITokenizer tokenizer, ISamplingMethod samplingMethod)
+ {
+ var allInputIds = seq.TokenIds;
+ var (offset, text) = DetokenizeIncrementally(tokenizer, allInputIds, seq.IncrementalDecodingOffset, skipSpecialTokens: true);
+
+ // TODO: deal with logprobs.
+
+ seq.IncrementalDecodingOffset = offset;
+ seq.OutputText += text;
+ }
+
+ private static (int, string) DetokenizeIncrementally(ITokenizer tokenizer, IEnumerable<int> allInputIds, int offset, bool skipSpecialTokens = false)
+ {
+ var consumedTokens = tokenizer.ConvertIdsToText(allInputIds.Skip(offset), out var text, skipSpecialTokens);
+ offset += consumedTokens;
+ return (offset, text);
+ }
+ }
+}
diff --git a/LLama/Experimental/Extensions/DequeExtensions.cs b/LLama/Experimental/Extensions/DequeExtensions.cs
new file mode 100644
index 000000000..8a923bd1c
--- /dev/null
+++ b/LLama/Experimental/Extensions/DequeExtensions.cs
@@ -0,0 +1,90 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Extensions
+{
+ /// <summary>
+ /// Extensions to use <see cref="LinkedList{T}"/> as a deque.
+ /// </summary>
+ public static class DequeExtensions
+ {
+ /// <summary>Add an item to the front of the deque.</summary>
+ public static void AddFront<T>(this LinkedList<T> deque, T item)
+ {
+ deque.AddFirst(item);
+ }
+
+ /// <summary>Add an item to the back of the deque.</summary>
+ public static void AddBack<T>(this LinkedList<T> deque, T item)
+ {
+ deque.AddLast(item);
+ }
+
+ /// <summary>Remove and return the item at the front of the deque.</summary>
+ public static T RemoveFront<T>(this LinkedList<T> deque)
+ {
+ if (deque.Count == 0)
+ {
+ throw new InvalidOperationException("The deque is empty.");
+ }
+
+ T item = deque.First!.Value;
+ deque.RemoveFirst();
+ return item;
+ }
+
+ /// <summary>Remove and return the item at the back of the deque.</summary>
+ public static T RemoveBack<T>(this LinkedList<T> deque)
+ {
+ if (deque.Count == 0)
+ {
+ throw new InvalidOperationException("The deque is empty.");
+ }
+
+ T item = deque.Last!.Value;
+ deque.RemoveLast();
+ return item;
+ }
+
+ /// <summary>Return the item at the front of the deque without removing it.</summary>
+ public static T PeekFront<T>(this LinkedList<T> deque)
+ {
+ if (deque.Count == 0)
+ {
+ throw new InvalidOperationException("The deque is empty.");
+ }
+
+ return deque.First!.Value;
+ }
+
+ /// <summary>Return the item at the back of the deque without removing it.</summary>
+ public static T PeekBack<T>(this LinkedList<T> deque)
+ {
+ if (deque.Count == 0)
+ {
+ throw new InvalidOperationException("The deque is empty.");
+ }
+
+ return deque.Last!.Value;
+ }
+
+ /// <summary>Add items to the front of the deque one by one (their relative order is reversed).</summary>
+ public static void ExtendFront<T>(this LinkedList<T> deque, IEnumerable<T> items)
+ {
+ foreach (var item in items)
+ {
+ deque.AddFront(item);
+ }
+ }
+
+ /// <summary>Add items to the back of the deque, preserving their order.</summary>
+ public static void ExtendBack<T>(this LinkedList<T> deque, IEnumerable<T> items)
+ {
+ foreach (var item in items)
+ {
+ deque.AddBack(item);
+ }
+ }
+ }
+}
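
Usage sketch; note that `ExtendFront` pushes items one at a time, so their relative order is reversed at the front (mirroring Python's `deque.extendleft`), while `ExtendBack` preserves order:

```csharp
using System.Collections.Generic;
using LLama.Experimental.Extensions;

var deque = new LinkedList<int>(new[] { 3, 4 });

deque.ExtendFront(new[] { 2, 1 });   // deque is now 1, 2, 3, 4
deque.ExtendBack(new[] { 5, 6 });    // deque is now 1, 2, 3, 4, 5, 6

int first = deque.RemoveFront();     // 1
int last = deque.RemoveBack();       // 6
```
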
diff --git a/LLama/Experimental/Extensions/SchedulingPolicyExtensions.cs b/LLama/Experimental/Extensions/SchedulingPolicyExtensions.cs
new file mode 100644
index 000000000..c168df392
--- /dev/null
+++ b/LLama/Experimental/Extensions/SchedulingPolicyExtensions.cs
@@ -0,0 +1,25 @@
+using LLama.Experimental.Abstractions;
+using LLama.Experimental.Common;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace LLama.Experimental.Extensions
+{
+ ///
+ public static class SchedulingPolicyExtensions
+ {
+ /// <summary>
+ /// Sorts the sequence groups by priority, highest priority first.
+ /// </summary>
+ /// <param name="policy"></param>
+ /// <param name="now"></param>
+ /// <param name="seqGroups"></param>
+ /// <returns></returns>
+ public static LinkedList<SequenceGroup> SortByPriority(this ISchedulingPolicy policy, DateTime now, LinkedList<SequenceGroup> seqGroups)
+ {
+ return new LinkedList<SequenceGroup>(seqGroups.OrderByDescending(seqGroup => policy.GetPriority(now, seqGroup)));
+ }
+ }
+}
diff --git a/LLama/Experimental/GlobalConfig.cs b/LLama/Experimental/GlobalConfig.cs
new file mode 100644
index 000000000..3286047c4
--- /dev/null
+++ b/LLama/Experimental/GlobalConfig.cs
@@ -0,0 +1,29 @@
+using LLama.Experimental.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental
+{
+ ///
+ /// Global configuration for LLamaSharp.
+ ///
+ public class GlobalConfig
+ {
+ public static ISamplingMethod DefaultSamplingMethod
+ {
+ get
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ public static IStoppingCriteria DefaultStoppingCriteria
+ {
+ get
+ {
+ throw new NotImplementedException();
+ }
+ }
+ }
+}
diff --git a/LLama/Experimental/KvCacheManager.cs b/LLama/Experimental/KvCacheManager.cs
new file mode 100644
index 000000000..b476678c0
--- /dev/null
+++ b/LLama/Experimental/KvCacheManager.cs
@@ -0,0 +1,31 @@
+using LLama.Experimental.Common;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental
+{
+ public class KvCacheManager
+ {
+ public bool CanAppendSlots(SequenceGroup seqGroup)
+ {
+ return true;
+ }
+
+ public AllocStatus CanAllocate(SequenceGroup seqGroup)
+ {
+ return AllocStatus.OK;
+ }
+
+ public void Allocate(SequenceGroup seqGroup)
+ {
+ }
+ }
+
+ public enum AllocStatus
+ {
+ OK,
+ Later,
+ Never
+ }
+}
diff --git a/LLama/Experimental/LLM.cs b/LLama/Experimental/LLM.cs
new file mode 100644
index 000000000..4cbcf1960
--- /dev/null
+++ b/LLama/Experimental/LLM.cs
@@ -0,0 +1,180 @@
+using LLama.Experimental.Abstractions;
+using LLama.Experimental.Common;
+using LLama.Experimental.Config;
+using LLama.Experimental.Utils;
+using Microsoft.Extensions.Logging;
+using System;
+using System.Collections.Generic;
+using System.Data;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+
+namespace LLama.Experimental
+{
+//#if NET8_0_OR_GREATER
+// [Experimental("LLM")]
+//#endif
+ public sealed class LLM
+ {
+ private LLMEngine _engine;
+
+ private IdCounter _counter;
+
+ ///
+ /// Get or set the tokenizer used for this <see cref="LLM"/>.
+ ///
+ public ITokenizer Tokenizer
+ {
+ get => _engine.Tokenizer;
+ set => _engine.Tokenizer = value;
+ }
+
+ public LLM(IModelRunner modelRunner, ITokenizer tokenizer, SchedulerConfig schedulerConfig, ILogger? logger = null)
+ {
+ _engine = new LLMEngine(schedulerConfig, modelRunner, tokenizer, logger);
+ _counter = new();
+ }
+
+ ///
+ /// Generates the completions for the input prompt. If you have multiple inputs,
+ /// please use <see cref="Generate(IEnumerable{string}, ISamplingMethod, IStoppingCriteria, ProgressCallback)"/>
+ /// instead of calling this method multiple times.
+ ///
+ /// A prompt string.
+ /// The sampling parameters for text generation. If null, we use the default sampling parameters.
+ ///
+ /// The criteria to control whether a sequence generation should be stopped. If null, we use the default stopping criteria.
+ ///
+ /// The callback to get the progress of the generation.
+ /// A list of objects containing the generated completions in the same order as the input prompts.
+ public RequestOutput[] Generate(string prompt, ISamplingMethod? samplingMethod = null,
+ IStoppingCriteria? stoppingCriteria = null, ProgressCallback? progressCallback = null)
+ {
+ return Generate([prompt], samplingMethod, stoppingCriteria, progressCallback);
+ }
+
+ ///
+ /// Generates the completions for the input prompt. If you have multiple inputs,
+ /// please use <see cref="Generate(IList{IList{int}}, ISamplingMethod, IStoppingCriteria, ProgressCallback)"/>
+ /// instead of calling this method multiple times.
+ ///
+ /// Token ids.
+ /// The sampling parameters for text generation. If null, we use the default sampling parameters.
+ ///
+ /// The criteria to control whether a sequence generation should be stopped. If null, we use the default stopping criteria.
+ ///
+ /// The callback to get the progress of the generation.
+ /// A list of objects containing the generated completions in the same order as the input prompt ids.
+ public RequestOutput[] Generate(IList<int> promptTokenIds, ISamplingMethod? samplingMethod = null,
+ IStoppingCriteria? stoppingCriteria = null, ProgressCallback? progressCallback = null)
+ {
+ return Generate([promptTokenIds], samplingMethod, stoppingCriteria, progressCallback);
+ }
+
+ ///
+ /// Generates the completions for the input prompts.
+ /// This class automatically batches the given prompts, considering
+ /// the memory constraint. For the best performance, please put all of your prompts
+ /// into a single list and pass it to this method.
+ ///
+ /// A list of prompts to generate completions for.
+ /// The sampling parameters for text generation. If null, we use the default sampling parameters.
+ ///
+ /// The criteria to control whether a sequence generation should be stopped. If null, we use the default stopping criteria.
+ ///
+ /// The callback to get the progress of the generation.
+ /// A list of objects containing the generated completions in the same order as the input prompts.
+ public RequestOutput[] Generate(IEnumerable<string> prompts, ISamplingMethod? samplingMethod = null,
+ IStoppingCriteria? stoppingCriteria = null, ProgressCallback? progressCallback = null)
+ {
+ if(prompts.Count() == 0)
+ {
+ return [];
+ }
+ samplingMethod ??= GlobalConfig.DefaultSamplingMethod;
+ stoppingCriteria ??= GlobalConfig.DefaultStoppingCriteria;
+
+ // Add requests to the engine.
+ foreach(var prompt in prompts)
+ {
+ AddRequest(prompt, samplingMethod, stoppingCriteria);
+ }
+ return RunEngine(progressCallback);
+ }
+
+ ///
+ /// Generates the completions for the input prompt ids list.
+ /// This class automatically batches the given prompts, considering
+ /// the memory constraint. For the best performance, please put all of your prompts
+ /// into a single list and pass it to this method.
+ ///
+ /// A list of token ids to generate completion for.
+ /// The sampling parameters for text generation. If null, we use the default sampling parameters.
+ ///
+ /// The criteria to control whether a sequence generation should be stopped. If null, we use the default stopping criteria.
+ ///
+ /// The callback to get the progress of the generation.
+ /// A list of objects containing the generated completions in the same order as the input prompt ids.
+ public RequestOutput[] Generate(IList<IList<int>> promptTokenIds, ISamplingMethod? samplingMethod = null,
+ IStoppingCriteria? stoppingCriteria = null, ProgressCallback? progressCallback = null)
+ {
+ if(promptTokenIds.Count == 0)
+ {
+ return [];
+ }
+ samplingMethod ??= GlobalConfig.DefaultSamplingMethod;
+ stoppingCriteria ??= GlobalConfig.DefaultStoppingCriteria;
+
+ // Add requests to the engine.
+ foreach(var tokenIds in promptTokenIds)
+ {
+ AddRequest(null, samplingMethod, stoppingCriteria, tokenIds);
+ }
+ return RunEngine(progressCallback);
+ }
+
+ private void AddRequest(string? prompt, ISamplingMethod samplingMethod, IStoppingCriteria stoppingCriteria, IList<int>? promptTokenIds = null)
+ {
+ var requestId = _counter.Next().ToString();
+ _engine.AddRequest(requestId, prompt, samplingMethod, stoppingCriteria, promptTokenIds, DateTime.Now);
+ }
+
+ private RequestOutput[] RunEngine(ProgressCallback? callback)
+ {
+ float numRequests = _engine.NumUnfinishedRequests;
+ Debug.Assert(numRequests > 0); // the engine must have at least one request at this point
+ int completedRequests = 0;
+ List<RequestOutput> outputs = new();
+ while (_engine.HasUnfinishedRequests)
+ {
+ var stepOutputs = _engine.Step();
+ foreach(var output in stepOutputs)
+ {
+ if (output.Finished)
+ {
+ outputs.Add(output);
+ if(callback is not null)
+ {
+ completedRequests++;
+ callback(completedRequests / numRequests);
+ }
+ }
+ }
+ }
+ // Sort the outputs by request ID (numeric, assigned by the internal counter).
+ // This is necessary because some requests may finish earlier than requests submitted before them.
+ return outputs.OrderBy(o => int.Parse(o.RequestId)).ToArray();
+ }
+
+
+ ///
+ /// A callback function used for reporting the progress of the generation.
+ /// It will be called every time a new request is completed.
+ ///
+ /// The progress in percentage.
+ public delegate void ProgressCallback(float progress);
+ }
+
+}
+
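
Putting the pieces together, the intended top-level usage looks roughly like the sketch below. `context` and `modelRunner` are assumed to be created elsewhere (the concrete `IModelRunner` implementation is not part of this excerpt), and the configuration numbers are illustrative:

```csharp
using System;
using LLama;
using LLama.Experimental;
using LLama.Experimental.Abstractions;
using LLama.Experimental.Config;
using LLama.Experimental.Core;
using LLama.Experimental.Core.LLamaCpp;

// Assumed to be constructed elsewhere: a loaded LLamaContext and an
// IModelRunner backed by it.
LLamaContext context = /* ... */ null!;
IModelRunner modelRunner = /* ... */ null!;

var llm = new LLM(
    modelRunner,
    new LLamaTokenizer(context),
    new SchedulerConfig(maxNumBatchedTokens: 2048, maxNumSequences: 8, maxSequenceLength: 2048));

var outputs = llm.Generate(
    new[] { "What is the capital of France?", "Write a haiku about spring." },
    samplingMethod: new LLamaGreedySamplingMethod(context),
    stoppingCriteria: new AntipromptStoppingCriteria(new[] { "\n\n" }),
    progressCallback: p => Console.WriteLine($"Progress: {p:P0}"));

foreach (var output in outputs)
    Console.WriteLine($"[{output.RequestId}] {output.Outputs[0].Text}");
```
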
diff --git a/LLama/Experimental/LLMEngine.cs b/LLama/Experimental/LLMEngine.cs
new file mode 100644
index 000000000..ce568d80d
--- /dev/null
+++ b/LLama/Experimental/LLMEngine.cs
@@ -0,0 +1,229 @@
+using System;
+using System.Collections.Generic;
+using LLama.Experimental.Abstractions;
+using Microsoft.Extensions.Logging;
+using LLama.Experimental.Common;
+using LLama.Experimental.Utils;
+using LLama.Extensions;
+using System.Linq;
+using System.Diagnostics;
+using LLama.Experimental.Config;
+
+namespace LLama.Experimental
+{
+ ///
+ /// An LLM engine that receives requests and generates texts.
+ ///
+ /// It receives requests
+ /// from clients and generates texts from the LLM. It includes a tokenizer, a
+ /// language model, and GPU memory space allocated for intermediate states (aka KV cache).
+ /// This class utilizes iteration-level scheduling and efficient memory management
+ /// to maximize the serving throughput.
+ ///
+ internal sealed class LLMEngine
+ {
+ private ILogger? _logger;
+
+ private IdCounter _seqCounter;
+
+ public Scheduler Scheduler { get; }
+
+ public IModelRunner ModelRunner { get; }
+
+ public ITokenizer Tokenizer { get; set; }
+
+ ///
+ /// Gets the number of unfinished requests.
+ ///
+ public int NumUnfinishedRequests => Scheduler.GetNumUnfinishedSeqGroups();
+
+ ///
+ /// Returns True if there are unfinished requests.
+ ///
+ public bool HasUnfinishedRequests => Scheduler.HasUnfinishedSeqs();
+
+ public LLMEngine(SchedulerConfig schedulerConfig, IModelRunner modelRunner, ITokenizer tokenizer, ILogger? logger = null)
+ {
+ _seqCounter = new();
+ Scheduler = new Scheduler(schedulerConfig, new KvCacheConfig(), logger);
+ Tokenizer = tokenizer;
+ _logger = logger;
+ ModelRunner = modelRunner;
+ }
+
+ ///
+ /// Performs one decoding iteration and returns newly generated results.
+ ///
+ /// Details:
+ /// - Step 1: Schedules the sequences to be executed in the next
+ /// iteration and the token blocks to be swapped in, swapped out, or copied.
+ ///
+ /// - Depending on the scheduling policy,
+ /// sequences may be `preempted/reordered`.
+ /// - A Sequence Group (SG) refers to a group of sequences
+ /// that are generated from the same prompt.
+ ///
+ /// - Step 2: Calls the distributed executor to execute the model.
+ /// - Step 3: Processes the model output. This mainly includes:
+ ///
+ /// - Decodes the relevant outputs.
+ /// - Updates the scheduled sequence groups with model outputs
+ /// based on their sampling method (e.g. whether beam search is used).
+ /// - Frees the finished sequence groups.
+ ///
+ /// - Finally, it creates and returns the newly generated results.
+ ///
+ ///
+ public List Step()
+ {
+ var (seqGroupMetadataList, schedulerOutputs) = Scheduler.Schedule();
+ var output = !schedulerOutputs.IsEmpty ? ModelRunner.ExecuteModel(seqGroupMetadataList) : new SamplerOutput([]);
+ return ProcessModelOutputs(output, schedulerOutputs);
+ }
+
+ ///
+ /// Add a request to the engine's request pool.
+ ///
+ /// The request is added to the request pool and will be processed by the
+ /// scheduler as `engine.Step()` is called. The exact scheduling policy is
+ /// determined by the scheduler.
+ ///
+ /// The unique ID of the request.
+ /// The prompt string. Can be null or empty if promptTokenIds is provided.
+ /// The sampling parameters for text generation.
+ /// The stopping criteria to decide whether the generation should be stopped.
+ /// The token IDs of the prompt. If null, we use the tokenizer to convert the prompt to token IDs.
+ /// The arrival time of the request. If null, we use the current time.
+ public void AddRequest(string requestId, string? prompt, ISamplingMethod samplingMethod, IStoppingCriteria stoppingCriteria, IList? promptTokenIds = null, DateTime? arrivalTime = null)
+ {
+ arrivalTime ??= DateTime.Now;
+ if(promptTokenIds is null)
+ {
+ Debug.Assert(prompt is not null);
+ promptTokenIds = Tokenizer.Tokenize(prompt!);
+ }
+ else if (!string.IsNullOrEmpty(prompt))
+ {
+ _logger?.LogWarning("Both prompt and prompt_token_ids are provided. The prompt will be ignored.");
+ }
+
+ var seqId = _seqCounter.Next();
+ var seq = new Sequence(seqId, prompt, promptTokenIds);
+ var seqGroup = new SequenceGroup(requestId, [seq], samplingMethod, stoppingCriteria, arrivalTime.Value);
+
+ // Add the sequence group to the scheduler.
+ Scheduler.AddSeqGroup(seqGroup);
+ }
+
+ private List ProcessModelOutputs(SamplerOutput outputs, SchedulerOutputs schedulerOutputs)
+ {
+ var now = DateTime.Now;
+ // Update the scheduled sequence groups with the model outputs.
+ var scheduledSeqGroups = schedulerOutputs.ScheduledSeqGroups;
+ Debug.Assert(scheduledSeqGroups.Count() == outputs.Count);
+ int i = 0;
+ foreach(var scheduledSeqGroup in scheduledSeqGroups)
+ {
+ var output = outputs[i];
+ var seqGroup = scheduledSeqGroup.SeqGroup;
+ seqGroup.UpdateNumComputedTokens(scheduledSeqGroup.TokenChunkSize);
+ ProcessSequenceGroupOutputs(seqGroup, output);
+ i++;
+ }
+
+ // Free the finished sequence groups.
+ Scheduler.FreeFinishedSeqGroups();
+
+ // Create the outputs.
+ List requestOutputs = new();
+ foreach(var scheduledSeqGroup in scheduledSeqGroups)
+ {
+ var seqGroup = scheduledSeqGroup.SeqGroup;
+ seqGroup.MaybeSetFirstTokenTime(now);
+ requestOutputs.Add(RequestOutput.FromSeqGroup(seqGroup));
+ }
+ foreach(var seqGroup in schedulerOutputs.IgnoredSeqGroups)
+ {
+ requestOutputs.Add(RequestOutput.FromSeqGroup(seqGroup));
+ }
+
+ // TODO: log stats here.
+ return requestOutputs;
+ }
+
+ private void ProcessSequenceGroupOutputs(SequenceGroup seqGroup, SequenceGroupOutput outputs)
+ {
+ // TODO: support using logprobs
+ var samples = outputs.Samples;
+ var parentSeqs = seqGroup.GetSeqsWithStatus(SequenceStatus.Running);
+ var existingFinishedSeqs = seqGroup.GetFinishedSeqs();
+ var parentChildDict = parentSeqs.ToDictionary(x => x.Id, _ => new List());
+ foreach(var sample in samples)
+ {
+ parentChildDict[sample.ParentSeqId].Add(sample);
+ }
+ // List of (child, parent)
+ List<(Sequence, Sequence)> childSeqs = new();
+
+ foreach(var parent in parentSeqs)
+ {
+ var childSamples = parentChildDict[parent.Id];
+ if(childSamples.Count == 0)
+ {
+ // This parent sequence has no child samples. Remove the parent sequence
+ // from the sequence group since it will not be used in future iterations.
+ parent.Status = SequenceStatus.FinishAborted;
+ seqGroup.Remove(parent.Id);
+ Scheduler.FreeSeq(parent);
+ continue;
+ }
+ foreach(var childSample in childSamples.SkipLast(1))
+ {
+ var newChildSeqId = _seqCounter.Next();
+ var child = parent.Fork(newChildSeqId);
+ child.AppendToken(childSample.OutputTokenId);
+ childSeqs.Add((child, parent));
+ }
+ // Continue the parent sequence for the last child sample.
+ // We reuse the parent sequence here to reduce redundant memory
+ // copies, especially when using non-beam search sampling methods.
+ var lastChildSample = childSamples.Last();
+ parent.AppendToken(lastChildSample.OutputTokenId);
+ childSeqs.Add((parent, parent));
+ }
+
+ foreach(var (seq, _) in childSeqs)
+ {
+ DeTokenizer.DecodeSequenceInplace(seq, Tokenizer, seqGroup.SamplingMethod);
+ var stoppingCriteriaOutput = seqGroup.StoppingCriteria.CheckStop(seq);
+ seq.Status = stoppingCriteriaOutput.Status;
+ seq.StoppingTokenId = stoppingCriteriaOutput.StoppingTokenId;
+ seq.StoppingString = stoppingCriteriaOutput.StoppingString;
+ }
+
+ // Only the non-beam-search case is implemented here for now.
+ // TODO: deal with beam search.
+ {
+ // For newly created child sequences, add them to the sequence group.
+ foreach(var (seq, parent) in childSeqs)
+ {
+ if(seq != parent) // if the references are not the same
+ {
+ seqGroup.Add(seq);
+ }
+ // TODO: see if we need to do the fork in the scheduler.
+ }
+
+ // NOTE: be careful of this logic.
+ foreach(var (seq, parent) in childSeqs)
+ {
+ if(seq == parent && seq.IsFinished)
+ {
+ Scheduler.FreeSeq(seq);
+ }
+ }
+ return;
+ }
+ }
+ }
+}
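
To make the engine lifecycle concrete, here is a hypothetical driver loop for `LLMEngine`. The `schedulerConfig`, `modelRunner`, `tokenizer`, `samplingMethod` and `stoppingCriteria` variables are assumed to be constructed elsewhere; nothing in this hunk shows how. Because the engine is currently internal, this only illustrates the call pattern.

```csharp
// Sketch of the engine lifecycle (all dependencies below are assumed to exist elsewhere).
var engine = new LLMEngine(schedulerConfig, modelRunner, tokenizer);

// Requests are only queued here; no work happens until Step() is called.
engine.AddRequest("0", "Hello, world!", samplingMethod, stoppingCriteria);
engine.AddRequest("1", "What is the capital of France?", samplingMethod, stoppingCriteria);

// Iteration-level scheduling: every Step() schedules a batch, runs the model once
// and post-processes the outputs of that single iteration.
while (engine.HasUnfinishedRequests)
{
    foreach (var requestOutput in engine.Step())
    {
        if (requestOutput.Finished)
            Console.WriteLine($"Request {requestOutput.RequestId} finished.");
    }
}
```
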
diff --git a/LLama/Experimental/LLamaModelRunner.cs b/LLama/Experimental/LLamaModelRunner.cs
new file mode 100644
index 000000000..f29ac5559
--- /dev/null
+++ b/LLama/Experimental/LLamaModelRunner.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental
+{
+ internal class LLamaModelRunner
+ {
+ }
+}
diff --git a/LLama/Experimental/Runner/LLamaCpp/LLamaCppRunnerInput.cs b/LLama/Experimental/Runner/LLamaCpp/LLamaCppRunnerInput.cs
new file mode 100644
index 000000000..32595ba3b
--- /dev/null
+++ b/LLama/Experimental/Runner/LLamaCpp/LLamaCppRunnerInput.cs
@@ -0,0 +1,92 @@
+using LLama.Experimental.Common;
+using LLama.Native;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+
+namespace LLama.Experimental.Runner.LLamaCpp
+{
+ ///
+ /// Input specific to the llama.cpp runner.
+ ///
+ public class LLamaCppRunnerInput
+ // TODO: get this from a pool?
+ {
+ public int[] TokenIds { get; }
+
+ public int[] Positions { get; }
+
+ public int[] SeqIdCount { get; }
+
+ public int[][] SeqIds { get; }
+
+ public bool[] WithLogits { get; }
+
+ public IntPtr[] SeqIdsPtrs { get; }
+
+ ///
+ /// Construct from a ModelRunnerInput.
+ ///
+ ///
+ public LLamaCppRunnerInput(ModelRunnerInput input)
+ {
+ Debug.Assert(input.TokenIds.Length == input.Positions.Length);
+ Debug.Assert(input.TokenIds.Length == input.SeqIds.Length);
+ Debug.Assert(input.TokenIds.Length == input.WithLogits.Length);
+ TokenIds = input.TokenIds;
+ Positions = input.Positions;
+
+ // TODO: Now we never put a token in multiple sequences,
+ // which may impact the speed of the model in some cases.
+ // We should consider supporting this in the future.
+ SeqIdCount = Enumerable.Repeat(1, TokenIds.Length).ToArray();
+ SeqIds = new int[TokenIds.Length][];
+ for(int i = 0; i < input.SeqIds.Length; i++)
+ {
+ SeqIds[i] = [input.SeqIds[i]];
+ }
+ WithLogits = input.WithLogits;
+ SeqIdsPtrs = new IntPtr[SeqIds.Length];
+ }
+
+ ///
+ /// Convert to a LLamaNativeBatch.
+ ///
+ /// [WARNING] You must hold the pin holder until the returned value is no longer used.
+ ///
+ ///
+ ///
+ ///
+ internal LLamaNativeBatch ToLLamaNativeBatch(out GroupDisposable pinHolder)
+ {
+ pinHolder = new GroupDisposable();
+
+ unsafe
+ {
+ var batch = new LLamaNativeBatch
+ {
+ n_tokens = TokenIds.Length,
+ logits = (byte*)pinHolder.Add(WithLogits.AsMemory().Pin()).Pointer,
+
+ n_seq_id = (int*)pinHolder.Add(SeqIdCount.AsMemory().Pin()).Pointer,
+ pos = (LLamaPos*)pinHolder.Add(Positions.AsMemory().Pin()).Pointer,
+ seq_id = (LLamaSeqId**)pinHolder.Add(SeqIdsPtrs.AsMemory().Pin()).Pointer,
+
+ // embd is not currently supported, so this is always null!
+ embd = null,
+
+ // Note that if embd is **not null** then this will be null!
+ tokens = (LLamaToken*)pinHolder.Add(TokenIds.AsMemory().Pin()).Pointer,
+ };
+
+ // Create pointers to each of the arrays in turn
+ for (var i = 0; i < SeqIdsPtrs.Length; i++)
+ SeqIdsPtrs[i] = (IntPtr)pinHolder.Add(SeqIds[i].AsMemory().Pin()).Pointer;
+
+ return batch;
+ }
+ }
+ }
+}
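
Because `ToLLamaNativeBatch` only pins the managed arrays for as long as the returned `GroupDisposable` is alive, the caller has to keep that holder alive across the whole native call. A sketch of the intended discipline, assuming a `ModelRunnerInput` named `modelInput` and an `LLamaContext` named `context` already exist:

```csharp
// Keep the pin holder alive for the entire native call, then release the pins.
var runnerInput = new LLamaCppRunnerInput(modelInput);
var nativeBatch = runnerInput.ToLLamaNativeBatch(out var pinHolder);
using (pinHolder)
{
    context.Decode(nativeBatch);   // the pinned arrays back the pointers inside nativeBatch
}
// After disposal the arrays may move again; nativeBatch must not be reused here.
```
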
diff --git a/LLama/Experimental/Runner/LLamaCpp/LogitsGenerator.cs b/LLama/Experimental/Runner/LLamaCpp/LogitsGenerator.cs
new file mode 100644
index 000000000..2dad0be12
--- /dev/null
+++ b/LLama/Experimental/Runner/LLamaCpp/LogitsGenerator.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Runner.LLamaCpp
+{
+ ///
+ /// Since the native API returns a view over the logits buffer,
+ /// we only fetch the logits when they are actually needed.
+ ///
+ public class LogitsGenerator
+ {
+ private int _pos;
+
+ private LLamaContext _context;
+
+ public LogitsGenerator(int pos, LLamaContext context)
+ {
+ _pos = pos;
+ _context = context;
+ }
+
+ public Span GetLogits()
+ {
+ return _context.NativeHandle.GetLogitsIth(_pos);
+ }
+ }
+}
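
A sketch of how a consumer might use the deferred logits, assuming the logits come back as a `Span<float>` and that a `LLamaContext` named `context` is available; the greedy argmax is only for illustration.

```csharp
// The logits are only read from the native context at the moment they are needed.
var generator = new LogitsGenerator(pos: 3, context: context);
Span<float> logits = generator.GetLogits();

int argmax = 0;
for (int i = 1; i < logits.Length; i++)
{
    if (logits[i] > logits[argmax])
        argmax = i;   // greedy pick over the vocabulary, purely for illustration
}
```
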
diff --git a/LLama/Experimental/Runner/LLamaCppRunner.cs b/LLama/Experimental/Runner/LLamaCppRunner.cs
new file mode 100644
index 000000000..6cc5656da
--- /dev/null
+++ b/LLama/Experimental/Runner/LLamaCppRunner.cs
@@ -0,0 +1,83 @@
+using LLama.Abstractions;
+using LLama.Experimental.Abstractions;
+using LLama.Experimental.Common;
+using LLama.Experimental.Runner.LLamaCpp;
+using LLama.Native;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace LLama.Experimental.Runner
+{
+ ///
+ /// Using llama.cpp backend to execute the model.
+ ///
+ public sealed class LLamaCppRunner: ModelRunnerBase, IModelRunner
+ {
+ public LLamaWeights ModelWeights { get; }
+
+ public LLamaContext Context { get; }
+
+ public LLamaCppRunner(LLamaWeights modelWeights, IContextParams contextParams)
+ {
+ ModelWeights = modelWeights;
+ Context = new LLamaContext(modelWeights, contextParams);
+ }
+
+ ///
+ public SamplerOutput ExecuteModel(IEnumerable seqGroupMetadataList)
+ {
+ var modelInput = PrepareInputs(seqGroupMetadataList);
+ var samplingMetadata = PrepareSample(seqGroupMetadataList, modelInput.PromptLengths, modelInput.SubqueryLengths);
+ var llamaCppRunnerInput = new LLamaCppRunnerInput(modelInput);
+ var nativeBatch = llamaCppRunnerInput.ToLLamaNativeBatch(out var pinHolder);
+
+ // TODO: is global lock still necessary?
+
+ // Batched inference
+ Context.Decode(nativeBatch);
+
+ // Get the logits
+ Dictionary seqIdToLogits = new();
+ for(int i = 0; i < llamaCppRunnerInput.WithLogits.Length; i++)
+ {
+ if (llamaCppRunnerInput.WithLogits[i])
+ {
+ for(int j = 0; j < llamaCppRunnerInput.SeqIds[i].Length; j++)
+ {
+ if (seqIdToLogits.ContainsKey(llamaCppRunnerInput.SeqIds[i][j]))
+ {
+ throw new Exception("Duplicate sequence id found when getting logits.");
+ }
+ else
+ {
+ seqIdToLogits.Add(llamaCppRunnerInput.SeqIds[i][j], new LogitsGenerator(i, Context));
+ }
+ }
+ }
+ }
+
+ // Sample the logits to get output tokens.
+ List outputs = new();
+ foreach(var seqGroupMetadata in seqGroupMetadataList)
+ {
+ List sequenceOutputs = new();
+ foreach(var seqId in seqGroupMetadata.SeqData.Keys)
+ {
+ var output = seqGroupMetadata.SamplingMethod.SampleSequence(seqIdToLogits[seqId].GetLogits(), seqId, samplingMetadata);
+ sequenceOutputs.Add(output);
+ }
+ outputs.Add(new SequenceGroupOutput(sequenceOutputs));
+ }
+
+ return new SamplerOutput(outputs);
+ }
+
+ public void Dispose()
+ {
+ // Dispose the context but not the model weights, since the weights may be shared.
+ Context.Dispose();
+ }
+ }
+}
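
A construction sketch for the runner. The model path is a placeholder, and `ModelParams` is used only because it already implements `IContextParams` in LLamaSharp; any `IContextParams` implementation would do.

```csharp
// Construction sketch; "model.gguf" is a placeholder path.
var parameters = new ModelParams("model.gguf")
{
    ContextSize = 4096,
};
using var weights = LLamaWeights.LoadFromFile(parameters);
using var runner = new LLamaCppRunner(weights, parameters);
// The runner disposes only its own context, so the weights can be shared
// between several runners and are disposed separately (here via `using`).
```
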
diff --git a/LLama/Experimental/Runner/ModelRunnerBase.cs b/LLama/Experimental/Runner/ModelRunnerBase.cs
new file mode 100644
index 000000000..5ccd865a9
--- /dev/null
+++ b/LLama/Experimental/Runner/ModelRunnerBase.cs
@@ -0,0 +1,138 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using LLama.Experimental.Abstractions;
+using LLama.Experimental.Common;
+using LLama.Extensions;
+
+namespace LLama.Experimental.Runner
+{
+ ///
+ /// A class that provides some commonly used methods when running the model.
+ ///
+ /// Note that you are not required to use this helper class; you can implement everything from scratch instead.
+ ///
+ public abstract class ModelRunnerBase
+ {
+ protected ModelRunnerInput PrepareInputs(IEnumerable seqGroupMetadataList)
+ {
+ Debug.Assert(seqGroupMetadataList.Count() > 0);
+ if (seqGroupMetadataList.First().IsPrompt)
+ {
+ return PreparePrompt(seqGroupMetadataList);
+ }
+ else
+ {
+ return PrepareDecode(seqGroupMetadataList);
+ }
+ }
+
+ protected SamplingMetadata PrepareSample(IEnumerable seqGroupMetadataList, int[] promptLengths, int[] subqueryLengths)
+ {
+ // TODO: implement it.
+ return null;
+ }
+
+ ///
+ /// Prepare input for sequences at prefill stage.
+ ///
+ ///
+ ///
+ protected ModelRunnerInput PreparePrompt(IEnumerable seqGroupMetadataList)
+ {
+ Debug.Assert(seqGroupMetadataList.Count() > 0);
+ List inputTokenIds = new();
+ List inputPositions = new();
+ List sequenceIdMapping = new(); // sequence id of the corresponding tokens
+ List withLogits = new();
+
+ List promptLengths = new();
+ List contextLengths = new();
+ List subqueryLengths = new();
+
+ foreach(var seqGroupMetadata in seqGroupMetadataList)
+ {
+ Debug.Assert(seqGroupMetadata.IsPrompt);
+ var seqIds = seqGroupMetadata.SeqData.Keys.ToList();
+ Debug.Assert(seqIds.Count == 1);
+ var seqId = seqIds[0];
+
+ var tokenChunkSize = seqGroupMetadata.TokenChunkSize;
+ var seqData = seqGroupMetadata.SeqData[seqId];
+ var computedLength = seqData.NumComputedTokens;
+ // We should use `Length` here because in case of preemption it contains output tokens.
+ var prefillEnd = Math.Min(seqData.Length, computedLength + tokenChunkSize);
+ var promptTokenIds = seqData.TokenIds.Take(prefillEnd).Skip(computedLength);
+ var promptLength = promptTokenIds.Count();
+ // Right now, prefillEnd is always the same as the length of the sequence.
+ // However, once chunked prefill is introduced, this assumption may no longer hold.
+ Debug.Assert(prefillEnd == seqData.Length);
+ promptLengths.Add(promptLength);
+
+ // TODO: check the logic here, related with blocks?
+
+ // actual prompt lens
+ contextLengths.Add(computedLength);
+ subqueryLengths.Add(promptLength - computedLength);
+
+ inputTokenIds.AddRange(promptTokenIds);
+ // NOTE: Here we assume that the first token in the prompt is always the first token in the sequence.
+ inputPositions.AddRange(Enumerable.Range(computedLength, promptLength)); // Range takes (start, count)
+
+ // TODO: deal with sliding window here?
+ sequenceIdMapping.AddRange(Enumerable.Repeat(seqId, promptLength));
+
+ withLogits.AddRange(Enumerable.Repeat(false, promptLength - 1));
+ withLogits.Add(true);
+ }
+
+ int maxSubqueryLength = subqueryLengths.Max();
+ int maxPromptLength = promptLengths.Max();
+ int numPromptTokens = inputTokenIds.Count;
+ Debug.Assert(maxSubqueryLength > 0);
+
+ return new ModelRunnerInput(inputTokenIds.ToArray(), inputPositions.ToArray(), sequenceIdMapping.ToArray(),
+ withLogits.ToArray(), promptLengths.ToArray(), subqueryLengths.ToArray());
+ }
+
+ ///
+ /// Prepare input for sequences at decode stage.
+ ///
+ ///
+ ///
+ protected ModelRunnerInput PrepareDecode(IEnumerable seqGroupMetadataList)
+ {
+ Debug.Assert(seqGroupMetadataList.Count() > 0);
+ List inputTokenIds = new();
+ List inputPositions = new();
+ List sequenceIdMapping = new(); // sequence id of the corresponding tokens
+ List withLogits = new();
+
+ foreach (var seqGroupMetadata in seqGroupMetadataList)
+ {
+ Debug.Assert(!seqGroupMetadata.IsPrompt);
+ Debug.Assert(seqGroupMetadata.TokenChunkSize == 1);
+ var seqIds = seqGroupMetadata.SeqData.Keys.ToList();
+
+ foreach(var seqId in seqIds)
+ {
+ var seqData = seqGroupMetadata.SeqData[seqId];
+ var generationToken = seqData.LastTokenId;
+ inputTokenIds.Add(generationToken);
+
+ var seqLength = seqData.Length;
+ var position = seqLength - 1;
+ inputPositions.Add(position);
+
+ sequenceIdMapping.Add(seqId);
+ withLogits.Add(true);
+ }
+ }
+
+ return new ModelRunnerInput(inputTokenIds.ToArray(), inputPositions.ToArray(),
+ sequenceIdMapping.ToArray(), withLogits.ToArray(), [], []);
+ }
+ }
+}
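
To make the flattened batch layout concrete, this is what `PreparePrompt` produces for a single 4-token prompt (the token ids and the sequence id are illustrative), and what `PrepareDecode` produces for the same sequence one step later.

```csharp
// Prefill: every prompt token is present, positions count up from the number of
// already-computed tokens, and only the last prompt token requests logits.
int[]  tokenIds   = { 15043, 29892, 3186, 29991 };
int[]  positions  = { 0, 1, 2, 3 };
int[]  seqIds     = { 7, 7, 7, 7 };
bool[] withLogits = { false, false, false, true };

// Decode: the same sequence contributes exactly one entry per step:
//   tokenIds = { lastTokenId }, positions = { seqLength - 1 }, withLogits = { true }.
```
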
diff --git a/LLama/Experimental/Scheduler.cs b/LLama/Experimental/Scheduler.cs
new file mode 100644
index 000000000..47180ea6c
--- /dev/null
+++ b/LLama/Experimental/Scheduler.cs
@@ -0,0 +1,601 @@
+using LLama.Experimental.Abstractions;
+using LLama.Experimental.Common;
+using LLama.Experimental.Config;
+using LLama.Experimental.Extensions;
+using Microsoft.Extensions.Logging;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+
+namespace LLama.Experimental
+{
+ ///
+ /// The scheduler to schedule the requests for model inference.
+ ///
+ public sealed class Scheduler
+ // TODO: LORA
+ {
+ private ILogger? _logger;
+
+ ///
+ /// Whether we scheduled a prompt at the previous step.
+ ///
+ private bool _prevIsPrompt;
+
+ ///
+ /// Latency of the last prompt step
+ ///
+ private float _lastPromptLatency;
+
+ ///
+ /// Time at previous scheduling step
+ ///
+ private DateTime _prevTime;
+
+ ///
+ /// Scheduler configuration.
+ ///
+ public SchedulerConfig SchedulerConfig { get; set; }
+
+ ///
+ /// KV cache configuration.
+ ///
+ public KvCacheConfig KvCacheConfig { get; set; }
+
+ ///
+ /// The maximum prompt length that can be used.
+ /// It's deduced from the scheduler configuration.
+ ///
+ public int MaxPromptLength { get; set; }
+
+ ///
+ /// Sequence groups in the WAITING state. It contains new prefill or preempted requests.
+ ///
+ public LinkedList Waiting { get; set; }
+
+ ///
+ /// Sequence groups in the RUNNING state. It contains the requests that are being decoded.
+ ///
+ public LinkedList Running { get; set; }
+
+ ///
+ /// Sequence groups in the SWAPPED state. It contains decode requests that are swapped out.
+ ///
+ public LinkedList Swapped { get; set; }
+
+ public KvCacheManager KvCacheManager { get; }
+
+
+ ///
+ /// Create a scheduler. Note that this is not a high-level API. If you are a user, please
+ /// read the documentation and make sure you know what it does.
+ ///
+ ///
+ ///
+ ///
+ public Scheduler(SchedulerConfig schedulerConfig, KvCacheConfig kvCacheConfig, ILogger? logger = null)
+ {
+ SchedulerConfig = schedulerConfig;
+ KvCacheConfig = kvCacheConfig;
+
+ if (SchedulerConfig.EnableChunkedPrefill)
+ {
+ MaxPromptLength = SchedulerConfig.MaxSequenceLength;
+ }
+ else
+ {
+ MaxPromptLength = Math.Min(SchedulerConfig.MaxSequenceLength, SchedulerConfig.MaxNumBatchedTokens);
+ }
+
+ Waiting = new LinkedList();
+ Running = new LinkedList();
+ Swapped = new LinkedList();
+
+ _logger = logger;
+
+ // TODO: init with config
+ KvCacheManager = new();
+ }
+
+ ///
+ /// Add sequence groups to the waiting queue.
+ ///
+ ///
+ ///
+ public Scheduler AddSeqGroup(SequenceGroup seqGroup)
+ {
+ _logger?.LogDebug($"Added seq group {seqGroup.RequestId}");
+ Waiting.AddBack(seqGroup);
+ return this;
+ }
+
+ ///
+ /// Aborts the sequence groups with the given request IDs.
+ /// For each ID, check whether the corresponding sequence group
+ /// is present in any of the state queues.
+ /// If present, remove the sequence group from that queue.
+ /// Also, if any of the sequences in the sequence group is not finished,
+ /// mark the sequence with status `FINISHED_ABORTED`.
+ /// Otherwise, do nothing.
+ ///
+ ///
+ ///
+ public Scheduler AbortSeqGroup(IEnumerable requestIds)
+ {
+ var requestIdSet = new HashSet(requestIds.Distinct());
+
+ AbortInternal(Waiting, requestIdSet);
+ AbortInternal(Running, requestIdSet);
+ AbortInternal(Swapped, requestIdSet);
+ return this;
+ }
+
+ ///
+ /// Whether there are any unfinished sequences at this moment.
+ ///
+ ///
+ public bool HasUnfinishedSeqs()
+ {
+ return Waiting.Count != 0 || Running.Count != 0 || Swapped.Count != 0;
+ }
+
+ ///
+ /// Get the number of unfinished sequence groups.
+ ///
+ ///
+ public int GetNumUnfinishedSeqGroups()
+ {
+ return Waiting.Count + Running.Count + Swapped.Count;
+ }
+
+ ///
+ /// Free the sequence resources that are managed by the scheduler.
+ /// It's actually an empty method for now and may be implemented in the future if needed.
+ ///
+ ///
+ public void FreeSeq(Sequence seq)
+ {
+ // TODO: implement it if needed.
+ }
+
+ ///
+ /// Schedule sequence groups.
+ /// This function call changes the internal states of the scheduler,
+ /// such as this.Running, this.Swapped, and this.Waiting.
+ ///
+ ///
+ public (List, SchedulerOutputs) Schedule()
+ {
+ var schedulerOutputs = ScheduleInternal();
+ var now = DateTime.Now;
+
+ // Create input data structures.
+ List seqGroupMetadataList = new();
+ int i = 0;
+ foreach(var scheduledSeqGroup in schedulerOutputs.ScheduledSeqGroups)
+ {
+ var seqGroup = scheduledSeqGroup.SeqGroup;
+ var tokenChunkSize = scheduledSeqGroup.TokenChunkSize;
+ seqGroup.MaybeSetFirstScheduledTime(now);
+
+ Dictionary seqData = new();
+
+ foreach(var seq in seqGroup.GetSeqsWithStatus(SequenceStatus.Running))
+ {
+ var seqId = seq.Id;
+ seqData[seqId] = seq.Data;
+ }
+
+ // It assumes that scheduled_seq_groups is ordered so that prefill groups come before decoding groups.
+ bool isPrompt = i < schedulerOutputs.NumPrefillGroups;
+ var seqGroupMetadata = new SequenceGroupMetadata(
+ seqGroup.RequestId,
+ isPrompt,
+ seqData,
+ seqGroup.SamplingMethod,
+ seqGroup.StoppingCriteria,
+ tokenChunkSize
+ );
+ seqGroupMetadataList.Add(seqGroupMetadata);
+
+ i++;
+ }
+
+ return (seqGroupMetadataList, schedulerOutputs);
+ }
+
+ ///
+ /// Free finished sequence groups.
+ ///
+ public void FreeFinishedSeqGroups()
+ {
+ Running = new LinkedList(Running.Where(x => !x.IsFinished));
+ }
+
+ ///
+ /// Schedule queued requests.
+ ///
+ ///
+ ///
+ private SchedulerOutputs ScheduleInternal()
+ {
+ if (SchedulerConfig.EnableChunkedPrefill)
+ {
+ // TODO: allow chunked prefill.
+ throw new NotImplementedException();
+ }
+ else
+ {
+ return ScheduleDefault();
+ }
+ }
+
+ ///
+ /// Schedule queued requests.
+ ///
+ /// The current policy is designed to optimize the throughput. First,
+ /// it batches as many prefill requests as possible. Then it schedules
+ /// decodes. If there is pressure on GPU memory, decode requests can
+ /// be swapped out or preempted.
+ ///
+ ///
+ private SchedulerOutputs ScheduleDefault()
+ {
+ // Include running requests to the budget.
+ var budget = new SchedulingBudget(SchedulerConfig.MaxNumBatchedTokens, SchedulerConfig.MaxNumSequences);
+ // Make sure we include num running seqs before scheduling prefill,
+ // so that we don't schedule beyond max_num_seqs for prefill.
+ foreach(var seqGroup in Running)
+ {
+ budget.AddNumSeqs(seqGroup.RequestId, seqGroup.GetMaxNumRunningSeqs());
+ }
+
+ var remainingWaiting = Waiting;
+ var prefills = SchedulerPrefillOutputs.CreateEmpty();
+ var remainingRunning = Running;
+ var runningScheduled = SchedulerRunningOutputs.CreateEmpty();
+ var remainingSwapped = Swapped;
+ var swappedIn = SchedulerSwappedInOutputs.CreateEmpty();
+
+ if(Swapped.Count == 0)
+ {
+ prefills = SchedulePrefills(Waiting, budget, false);
+ remainingWaiting = prefills.RemainingWaitingQueue;
+ }
+
+ var policy = PolicyFactory.DefaultPolicy;
+ // Don't schedule decodes if prefills are scheduled.
+ // NOTE: If `SchedulePrefills` doesn't enable chunking, this.Running
+ // only contains decode requests, not chunked prefills.
+
+ if(prefills.SeqGroups.Count == 0)
+ {
+ runningScheduled = ScheduleRunning(Running, budget, policy, false);
+ remainingRunning = runningScheduled.RemainingRunningQueue;
+
+ // If any sequence group is preempted, do not swap in any sequence group.
+ // Because it means there's no slot for new running requests.
+ if(runningScheduled.PreemptedSeqGroups.Count + runningScheduled.SwappedOutSeqGroups.Count == 0)
+ {
+ // TODO: implement the swapping.
+ }
+ }
+
+ Debug.Assert(budget.NumBatchedTokens <= SchedulerConfig.MaxNumBatchedTokens);
+ Debug.Assert(budget.NumCurrentSeqs <= SchedulerConfig.MaxNumSequences);
+
+ // Update waiting requests.
+ Waiting = remainingWaiting;
+ Waiting.ExtendFront(runningScheduled.PreemptedSeqGroups);
+ // Update new running requests.
+ Running = remainingRunning;
+ Running.ExtendFront(prefills.SeqGroups.Select(x => x.SeqGroup));
+ Running.ExtendBack(runningScheduled.DecodeSeqGroups.Select(x => x.SeqGroup));
+ Running.ExtendBack(swappedIn.DecodeSeqGroups.Select(x => x.SeqGroup));
+ // Update swapped requests.
+ Swapped = remainingSwapped;
+ Swapped.ExtendBack(runningScheduled.SwappedOutSeqGroups);
+
+ // There should be no prefill from running queue because this policy
+ // doesn't allow chunked prefills.
+ Debug.Assert(runningScheduled.PrefillSeqGroups.Count == 0);
+ Debug.Assert(swappedIn.PrefillSeqGroups.Count == 0);
+ return new SchedulerOutputs(
+ ScheduledSeqGroups: prefills.SeqGroups.Concat(runningScheduled.DecodeSeqGroups).Concat(swappedIn.DecodeSeqGroups),
+ NumPrefillGroups: prefills.SeqGroups.Count,
+ NumBatchedTokens: budget.NumBatchedTokens,
+ IgnoredSeqGroups: prefills.IgnoredSeqGroups
+ );
+ }
+
+ private SchedulerSwappedInOutputs ScheduleSwapped(LinkedList swappedQueue, SchedulingBudget budget, ISchedulingPolicy policy, bool enableChunking)
+ {
+ throw new NotImplementedException();
+ }
+
+ ///
+ /// Schedule sequence groups that are in the prefill stage.
+ ///
+ /// Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
+ /// as a new prefill (that starts from the beginning up to the most recently
+ /// generated tokens).
+ ///
+ /// It schedules waiting requests as long as they fit into `budget`.
+ /// The input argument `budget` is updated in place based on the scheduled seq_groups.
+ ///
+ /// The queue that contains prefill requests. The given arguments are NOT in-place modified.
+ /// The scheduling budget. The argument is in-place updated when any requests are scheduled.
+ ///
+ /// If true, the seq group can be chunked and only a chunked number of tokens are scheduled
+ /// if the budget does not have enough capacity to schedule all tokens.
+ ///
+ ///
+ private SchedulerPrefillOutputs SchedulePrefills(LinkedList waiting, SchedulingBudget budget, bool enableChunking = false)
+ {
+ List ignoredSeqGroups = new();
+ List seqGroups = new();
+ // We don't sort waiting queue because we assume it is sorted.
+ // Copy the queue so that the input queue is not modified.
+ var waitingQueue = new LinkedList(waiting);
+
+ LinkedList leftoverWaitingSequences = new();
+ while(PassedDelay(DateTime.Now) && waitingQueue.Count > 0)
+ {
+ var seqGroup = waitingQueue.PeekFront();
+
+ var waitingSeqs = seqGroup.GetSeqsWithStatus(SequenceStatus.Waiting);
+ Debug.Assert(waitingSeqs.Count() == 1, "Waiting sequence group should have only one prompt sequence.");
+ var numNewTokens = GetNumNewTokens(seqGroup, SequenceStatus.Waiting, enableChunking, budget);
+ if (!enableChunking)
+ {
+ var numPromptTokens = waitingSeqs.First().Length;
+ Debug.Assert(numNewTokens == numPromptTokens);
+ }
+
+ if (numNewTokens > MaxPromptLength)
+ {
+ _logger?.LogWarning($"Input prompt ({numNewTokens} tokens) is too long " +
+ $"and exceeds limit of {MaxPromptLength}.");
+ foreach(var seq in waitingSeqs)
+ {
+ seq.Status = SequenceStatus.FinishIgnored;
+ }
+ ignoredSeqGroups.Add(seqGroup);
+ waitingQueue.RemoveFront();
+ continue;
+ }
+
+ // If the sequence group cannot be allocated, stop.
+ var canAlloc = KvCacheManager.CanAllocate(seqGroup);
+ if(canAlloc == AllocStatus.Later)
+ {
+ break;
+ }
+ else if(canAlloc == AllocStatus.Never)
+ {
+ _logger?.LogWarning($"Input prompt ({numNewTokens} tokens) is too long" +
+ " and exceeds the capacity of block_manager");
+ foreach(var seq in waitingSeqs)
+ {
+ seq.Status = SequenceStatus.FinishIgnored;
+ }
+ ignoredSeqGroups.Add(seqGroup);
+ waitingQueue.RemoveFront();
+ continue;
+ }
+
+ var numNewSeqs = seqGroup.GetMaxNumRunningSeqs();
+ if(numNewTokens == 0 || !budget.CanSchedule(numNewTokens, numNewSeqs))
+ {
+ break;
+ }
+
+ // Can schedule this request.
+ waitingQueue.RemoveFront();
+ AllocateAndSetRunning(seqGroup, numNewTokens);
+ seqGroups.Add(new ScheduledSequenceGroup(seqGroup, numNewTokens));
+ budget.AddNumBatchedTokens(seqGroup.RequestId, numNewTokens);
+ budget.AddNumSeqs(seqGroup.RequestId, numNewSeqs);
+ }
+
+ waitingQueue.ExtendFront(leftoverWaitingSequences);
+ if(seqGroups.Count > 0)
+ {
+ _prevIsPrompt = true;
+ }
+
+ return new SchedulerPrefillOutputs(waitingQueue, seqGroups, ignoredSeqGroups);
+ }
+
+ ///
+ /// Schedule sequence groups that are running.
+ ///
+ /// Running queue should include decode and chunked prefill requests.
+ ///
+ ///
+ /// The queue that contains running requests (i.e., decodes).
+ /// The given arguments are NOT in-place modified.
+ ///
+ ///
+ /// The scheduling budget. The argument is in-place updated
+ /// when any decodes are preempted.
+ ///
+ /// The sorting policy to sort running_queue.
+ ///
+ /// If true, the seq group can be chunked and only a chunked number of tokens are scheduled
+ /// if `budget.num_batched_tokens` does not have enough capacity to schedule all tokens.
+ ///
+ ///
+ private SchedulerRunningOutputs ScheduleRunning(LinkedList runningQueue, SchedulingBudget budget,
+ ISchedulingPolicy policy, bool enableChunking)
+ {
+ List decodeSeqGroups = new();
+ List prefillSeqGroups = new();
+ List preempted = new();
+ List swappedOut = new();
+
+ //NOTE: Preemption happens only when there is no available slot
+ //to keep all the sequence groups in the RUNNING state.
+ //In this case, the policy is responsible for deciding which sequence
+ //groups to preempt.
+ var now = DateTime.Now;
+ runningQueue = policy.SortByPriority(now, runningQueue);
+
+ while(runningQueue.Count > 0)
+ {
+ var seqGroup = runningQueue.PeekFront();
+ var numRunningTokens = GetNumNewTokens(seqGroup, SequenceStatus.Running, enableChunking, budget);
+
+ // We can have up to 1 running prefill at any given time in running
+ // queue, which means we can guarantee chunk size is at least 1.
+ Debug.Assert(numRunningTokens != 0);
+ var numRunningSeqs = seqGroup.GetMaxNumRunningSeqs();
+
+ runningQueue.RemoveFront();
+ bool appended = true;
+ while (!CanAppendSlots(seqGroup))
+ {
+ // TODO: implement the preemption logic
+ Debug.Assert(false);
+ }
+
+ if (appended)
+ {
+ _logger?.LogDebug($"append slot for {seqGroup}");
+ AppendSlots(seqGroup);
+ if (seqGroup.IsPrefill)
+ {
+ prefillSeqGroups.Add(new ScheduledSequenceGroup(seqGroup, numRunningTokens));
+ }
+ else
+ {
+ decodeSeqGroups.Add(new ScheduledSequenceGroup(seqGroup, 1));
+ }
+ budget.AddNumBatchedTokens(seqGroup.RequestId, numRunningTokens);
+ budget.AddNumSeqs(seqGroup.RequestId, numRunningSeqs);
+ }
+ }
+
+ Debug.Assert(runningQueue.Count == 0);
+ return new SchedulerRunningOutputs(runningQueue, decodeSeqGroups, prefillSeqGroups, preempted, swappedOut);
+ }
+
+ private void AllocateAndSetRunning(SequenceGroup seqGroup, int numNewTokens)
+ {
+ KvCacheManager.Allocate(seqGroup);
+ foreach (var seq in seqGroup.GetSeqsWithStatus(SequenceStatus.Waiting))
+ {
+ seq.Status = SequenceStatus.Running;
+ }
+ }
+
+ private bool PassedDelay(DateTime now)
+ {
+ if (_prevIsPrompt)
+ {
+ _lastPromptLatency = (float)(now - _prevTime).TotalMilliseconds; // TotalMilliseconds: full elapsed time, not just the ms component
+ }
+ _prevTime = now;
+ _prevIsPrompt = false;
+ // Delay scheduling prompts to let waiting queue fill up
+ if (SchedulerConfig.DelayFactor > 0 && Waiting.Count > 0)
+ {
+ var earliestArrivalTime = Waiting.Select(x => x.Metrics.ArrivalTime).Min();
+ return (now - earliestArrivalTime).TotalMilliseconds > (SchedulerConfig.DelayFactor * _lastPromptLatency) || Running.Count == 0;
+ }
+ return true;
+ }
+
+ private bool CanAppendSlots(SequenceGroup seqGroup)
+ {
+ return KvCacheManager.CanAppendSlots(seqGroup);
+ }
+
+ private void AppendSlots(SequenceGroup seqGroup)
+ {
+ // TODO: Implement this method
+ }
+
+ ///
+ /// Get the next new tokens to compute for a given sequence group that's in a given `status`.
+ ///
+ /// The API could chunk the number of tokens to compute based on `budget`
+ /// if `enableChunking` is true. If a sequence group has multiple
+ /// sequences (e.g., running beam search), it means it is in the decoding
+ /// phase, so chunking doesn't happen.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ private int GetNumNewTokens(SequenceGroup seqGroup, SequenceStatus status, bool enableChunking, SchedulingBudget budget)
+ {
+ int numNewTokens = 0;
+ var seqs = seqGroup.GetSeqsWithStatus(status);
+ foreach(var seq in seqs)
+ {
+ numNewTokens += seq.NumNewTokens;
+ }
+ // Chunk if a running request cannot fit in.
+ // If number of seq > 1, it means it is doing beam search in a
+ // decode phase. Do not chunk in that case.
+ if(enableChunking && seqs.Count() == 1)
+ {
+ numNewTokens = Math.Min(numNewTokens, budget.RemainingTokenBudget);
+ }
+ return numNewTokens;
+ }
+
+ private void AbortInternal(LinkedList queue, HashSet requestIds)
+ {
+ Queue abortedGroups = new();
+ foreach (var seqGroup in queue)
+ {
+ if (requestIds.Count == 0)
+ {
+ break;
+ }
+ if (requestIds.Contains(seqGroup.RequestId))
+ {
+ _logger?.LogDebug($"Aborted seq group {seqGroup.RequestId}");
+ abortedGroups.Enqueue(seqGroup);
+ requestIds.Remove(seqGroup.RequestId);
+ }
+ }
+ foreach(var abortGroup in abortedGroups)
+ {
+ queue.Remove(abortGroup);
+ foreach(var seq in abortGroup.GetAllSeqs())
+ {
+ if (seq.IsFinished)
+ {
+ continue;
+ }
+ seq.Status = SequenceStatus.FinishAborted;
+ }
+ }
+ }
+ }
+
+ ///
+ /// The mode of preemption.
+ ///
+ public enum PreemptionMode
+ {
+ ///
+ /// Swap out the blocks of the preempted sequences to CPU memory
+ /// and swap them back in when the sequences are resumed.
+ ///
+ Swap,
+
+ ///
+ /// Discard the blocks of the preempted sequences and recompute them
+ /// when the sequences are resumed, treating the sequences as new prompts.
+ ///
+ Recompute
+ }
+}
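
For orientation, the sketch below mirrors what `LLMEngine.Step()` does with the scheduler on a single iteration; `schedulerConfig`, `seqGroup` and `modelRunner` are assumed to be built elsewhere.

```csharp
// One scheduling iteration, driven manually (normally LLMEngine owns this loop).
var scheduler = new Scheduler(schedulerConfig, new KvCacheConfig());
scheduler.AddSeqGroup(seqGroup);

var (seqGroupMetadataList, schedulerOutputs) = scheduler.Schedule();
if (!schedulerOutputs.IsEmpty)
{
    // The model runner consumes exactly the metadata the scheduler produced.
    var samplerOutput = modelRunner.ExecuteModel(seqGroupMetadataList);
    // ...append sampled tokens to the sequences and apply the stopping criteria...
}

// Drop groups whose sequences have all finished.
scheduler.FreeFinishedSeqGroups();
```
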
diff --git a/LLama/Experimental/SchedulingPolicies.cs b/LLama/Experimental/SchedulingPolicies.cs
new file mode 100644
index 000000000..c22ff3b69
--- /dev/null
+++ b/LLama/Experimental/SchedulingPolicies.cs
@@ -0,0 +1,25 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using LLama.Experimental.Abstractions;
+using LLama.Experimental.Common;
+
+namespace LLama.Experimental
+{
+ ///
+ /// First-come, first-served (FCFS) scheduling policy.
+ ///
+ public class FCFS: ISchedulingPolicy
+ {
+ ///
+ public int GetPriority(DateTime now, SequenceGroup seqGroup)
+ {
+ // Older requests have waited longer and therefore get a higher priority value.
+ return (int)(now - seqGroup.Metrics.ArrivalTime).TotalMilliseconds;
+ }
+ }
+
+ public class PolicyFactory
+ {
+ public static ISchedulingPolicy DefaultPolicy { get; set; } = new FCFS();
+ }
+}
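
Since an `ISchedulingPolicy` only has to produce a priority value, swapping in a different policy is a small amount of code. A hypothetical example that prefers the newest requests, assuming (as `FCFS` suggests) that a larger value means a higher priority:

```csharp
using System;
using LLama.Experimental.Abstractions;
using LLama.Experimental.Common;

// Hypothetical "newest first" policy: the opposite ordering of FCFS.
public class NewestFirstPolicy : ISchedulingPolicy
{
    public int GetPriority(DateTime now, SequenceGroup seqGroup)
    {
        // Newer requests have a smaller elapsed time, so negating it ranks them higher.
        return -(int)(now - seqGroup.Metrics.ArrivalTime).TotalMilliseconds;
    }
}

// Swap the default policy used by the scheduler:
// PolicyFactory.DefaultPolicy = new NewestFirstPolicy();
```
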
diff --git a/LLama/Experimental/Utils/ClassStringFormatter.cs b/LLama/Experimental/Utils/ClassStringFormatter.cs
new file mode 100644
index 000000000..6ae2cc117
--- /dev/null
+++ b/LLama/Experimental/Utils/ClassStringFormatter.cs
@@ -0,0 +1,33 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Reflection;
+using System.Text;
+
+namespace LLama.Experimental.Utils
+{
+ internal static class ClassStringFormatter
+ {
+ public static string Format(T obj)
+ {
+ Type type = obj.GetType();
+ PropertyInfo[] properties = type.GetProperties();
+
+ string res = $"{type.Name}(";
+ foreach (var property in properties)
+ {
+ object? value = property.GetValue(obj);
+ res += $"\n {property.Name} = {value},";
+ }
+ if(properties.Length == 0)
+ {
+ res += ")";
+ }
+ else
+ {
+ res += "\n)";
+ }
+ return res;
+ }
+ }
+}
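
A small illustration of the output shape of `ClassStringFormatter.Format`; the config type below is made up for the example and the exact indentation depends on the formatter's whitespace.

```csharp
// Illustrative only: any type with public properties works.
internal class ExampleConfig
{
    public int BlockSize { get; set; } = 16;
    public float Watermark { get; set; } = 0.01f;
}

// ClassStringFormatter.Format(new ExampleConfig()) yields, roughly:
//   ExampleConfig(
//       BlockSize = 16,
//       Watermark = 0.01,
//   )
```
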
diff --git a/LLama/Experimental/Utils/IdCounter.cs b/LLama/Experimental/Utils/IdCounter.cs
new file mode 100644
index 000000000..6ff4fd558
--- /dev/null
+++ b/LLama/Experimental/Utils/IdCounter.cs
@@ -0,0 +1,26 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Experimental.Utils
+{
+ internal class IdCounter
+ {
+ private int _number;
+
+ public IdCounter(int start = 0)
+ {
+ _number = start;
+ }
+
+ public int Next()
+ {
+ return _number++;
+ }
+
+ public void Reset()
+ {
+ _number = 0;
+ }
+ }
+}
diff --git a/LLama/Extensions/IEnumerableExtensions.cs b/LLama/Extensions/IEnumerableExtensions.cs
index 3d1e2e814..c94335e1b 100644
--- a/LLama/Extensions/IEnumerableExtensions.cs
+++ b/LLama/Extensions/IEnumerableExtensions.cs
@@ -10,8 +10,13 @@ public static IEnumerable TakeLast(this IEnumerable source, int count)
{
return TakeLastImpl(source, count);
}
+
+ public static IEnumerable SkipLast(this IEnumerable source, int count)
+ {
+ return SkipLastImpl(source, count);
+ }
#elif !NET6_0_OR_GREATER && !NETSTANDARD2_1_OR_GREATER
- #error Target framework not supported!
+#error Target framework not supported!
#endif
internal static IEnumerable TakeLastImpl(IEnumerable source, int count)
@@ -24,5 +29,10 @@ internal static IEnumerable TakeLastImpl(IEnumerable source, int count)
list.RemoveRange(0, list.Count - count);
return list;
}
+
+ internal static IEnumerable SkipLastImpl(IEnumerable source, int count)
+ {
+ // Note: this enumerates the source twice (once for Count, once for Take).
+ return source.Take(source.Count() - count);
+ }
}
}
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 9517965e6..12ea24740 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -400,6 +400,11 @@ public DecodeResult Decode(LLamaBatch batch)
return (DecodeResult)NativeHandle.Decode(batch);
}
+ public DecodeResult Decode(LLamaNativeBatch batch)
+ {
+ return (DecodeResult)NativeHandle.Decode(batch);
+ }
+
///
///
///
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index 3947b7c31..3b4493ec9 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -3,7 +3,7 @@
netstandard2.0;net6.0;net7.0;net8.0
LLama
enable
- 10
+ 12
AnyCPU;x64;Arm64
True
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 2f881fa5d..630c7bba2 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -219,6 +219,18 @@ public DecodeResult Decode(LLamaBatch batch)
return (DecodeResult)NativeApi.llama_decode(this, nb);
}
+ ///
+ ///
+ ///
+ ///
+ ///
+ public DecodeResult Decode(LLamaNativeBatch batch)
+ {
+ // TODO: is global lock still necessary?
+ lock (GlobalInferenceLock)
+ return (DecodeResult)NativeApi.llama_decode(this, batch);
+ }
+
///
/// Decode a set of tokens in batch-size chunks.
///