OpenAI v1 Chat Completions API #171

Merged (8 commits) on Jan 10, 2024
5 changes: 5 additions & 0 deletions docs/reference/openapi.json
@@ -607,6 +607,11 @@
"api_token": {
"type": "string",
"nullable": true
},
"apply_chat_template": {
"type": "boolean",
"default": "false",
"example": true
}
}
},
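For context, a minimal sketch of how a client could set the new flag on the existing /generate route (Python with the requests library; the host, port, and prompt are placeholders, not part of this change):

import requests

# Hypothetical call; when apply_chat_template is true, inputs is expected to be
# a JSON-encoded list of {"role", "content"} messages (see the router changes below).
response = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": '[{"role": "user", "content": "Hello!"}]',
        "parameters": {
            "apply_chat_template": True,  # new field; defaults to false
            "max_new_tokens": 64,
        },
    },
)
print(response.json())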
2 changes: 2 additions & 0 deletions proto/generate.proto
@@ -102,6 +102,8 @@ message Request {
bool prefill_logprobs = 6;
/// Adapter index
uint32 adapter_index = 7;
/// Apply chat template to inputs
bool apply_chat_template = 8;
}

message Batch {
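As a rough illustration of the wire format (using the generated Python bindings in lorax_server.pb.generate_pb2; only a subset of fields is shown and the values are placeholders, so treat this as a sketch rather than a real scheduler request):

from lorax_server.pb import generate_pb2

# parameters and stopping_parameters are omitted for brevity.
request = generate_pb2.Request(
    id=0,
    inputs='[{"role": "user", "content": "Hello!"}]',
    truncate=1024,
    prefill_logprobs=False,
    adapter_index=0,
    apply_chat_template=True,  # new field 8 introduced by this change
)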
1 change: 1 addition & 0 deletions router/client/src/client.rs
@@ -134,6 +134,7 @@ impl Client {
}),
adapter_index: 0,
prefill_logprobs: true,
apply_chat_template: false,
});
n_tokens += max_input_length;
}
1 change: 1 addition & 0 deletions router/src/health.rs
@@ -52,6 +52,7 @@ impl Health {
ignore_eos_token: false,
}),
adapter_index: 0,
apply_chat_template: false,
};
let batch = Batch {
id: BATCH_ID,
40 changes: 39 additions & 1 deletion router/src/lib.rs
@@ -145,6 +145,9 @@ pub(crate) struct GenerateParameters {
#[schema(default = "true")]
pub decoder_input_details: bool,
#[serde(default)]
#[schema(default = "false")]
pub apply_chat_template: bool,
#[serde(default)]
#[schema(
exclusive_minimum = 0,
nullable = true,
@@ -177,6 +180,7 @@ fn default_parameters() -> GenerateParameters {
watermark: false,
details: false,
decoder_input_details: false,
apply_chat_template: false,
seed: None,
}
}
@@ -320,7 +324,7 @@ struct UsageInfo {
#[derive(Clone, Debug, Deserialize, ToSchema)]
struct ChatCompletionRequest {
model: String,
messages: Vec<String>,
messages: Vec<std::collections::HashMap<String, String>>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<i32>,
@@ -451,6 +455,40 @@ impl From<CompletionRequest> for CompatGenerateRequest {
watermark: false,
details: true,
decoder_input_details: req.logprobs.is_some(),
apply_chat_template: false,
seed: None,
},
stream: req.stream.unwrap_or(false),
}
}
}

impl From<ChatCompletionRequest> for CompatGenerateRequest {
fn from(req: ChatCompletionRequest) -> Self {
CompatGenerateRequest {
inputs: serde_json::to_string(&req.messages).unwrap(),
parameters: GenerateParameters {
adapter_id: req.model.parse().ok(),
adapter_source: None,
api_token: None,
best_of: req.n.map(|x| x as usize),
temperature: req.temperature,
repetition_penalty: None,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample: !req.n.is_none(),
max_new_tokens: req
.max_tokens
.map(|x| x as u32)
.unwrap_or(default_max_new_tokens()),
return_full_text: None,
stop: req.stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: false,
apply_chat_template: true,
seed: None,
},
stream: req.stream.unwrap_or(false),
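The conversion above JSON-serializes the OpenAI-style messages into the plain inputs string and forces apply_chat_template, so the Python server can rebuild the message list and render it through the tokenizer's chat template; model maps to adapter_id and n to best_of. A small Python illustration of the same transformation (assumed equivalence with serde_json::to_string, shown only to make the data flow concrete):

import json

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

inputs = json.dumps(messages)   # becomes CompatGenerateRequest.inputs
apply_chat_template = True      # always set for /v1/chat/completions requests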
1 change: 1 addition & 0 deletions router/src/scheduler.rs
@@ -334,6 +334,7 @@ impl AdapterSchedulerState {
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
adapter_index: adapter.index(),
apply_chat_template: entry.request.apply_chat_template,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
71 changes: 66 additions & 5 deletions router/src/server.rs
@@ -3,10 +3,10 @@ use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, CompatGenerateRequest, CompletionRequest, CompletionResponse,
CompletionStreamResponse, Details, ErrorResponse, FinishReason, GenerateParameters,
GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, StreamDetails,
StreamResponse, Token, Validation,
BestOfSequence, ChatCompletionRequest, CompatGenerateRequest, CompletionRequest,
CompletionResponse, CompletionStreamResponse, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -78,7 +78,7 @@ async fn compat_generate(
}
}

/// Generate tokens if `stream == false` or a stream of token if `stream == true`
/// OpenAI compatible completions endpoint
#[utoipa::path(
post,
tag = "LoRAX",
@@ -138,6 +138,66 @@ async fn completions_v1(
}
}

/// OpenAI compatible chat completions endpoint
#[utoipa::path(
post,
tag = "LoRAX",
path = "/v1/chat/completions",
request_body = ChatCompletionRequest,
responses(
(status = 200, description = "Generated Text",
content(
("application/json" = ChatCompletionResponse),
("text/event-stream" = ChatCompletionStreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn chat_completions_v1(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
req: Json<ChatCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let req = req.0;
let mut gen_req = CompatGenerateRequest::from(req);

// default return_full_text given the pipeline_tag
if gen_req.parameters.return_full_text.is_none() {
gen_req.parameters.return_full_text = Some(default_return_full_text.0)
}

// switch on stream
if gen_req.stream {
let callback = move |resp: StreamResponse| {
Event::default()
.json_data(CompletionStreamResponse::from(resp))
.map_or_else(
|err| {
tracing::error!("Failed to serialize CompletionStreamResponse: {err}");
Event::default()
},
|data| data,
)
};

let (headers, stream) =
generate_stream_with_callback(infer, Json(gen_req.into()), callback).await;
Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
} else {
let (headers, generation) = generate(infer, Json(gen_req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![CompletionResponse::from(generation.0)])).into_response())
}
}

/// LoRAX endpoint info
#[utoipa::path(
get,
@@ -771,6 +831,7 @@ pub async fn run(
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/v1/completions", post(completions_v1))
.route("/v1/chat/completions", post(chat_completions_v1))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
// Base Health route
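With the route registered, the new endpoint can be exercised like any OpenAI-compatible chat API. A sketch with the requests library (host, port, adapter name, and prompt are placeholders; set "stream": true to receive a text/event-stream response instead):

import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "my-adapter",  # mapped to adapter_id by the router
        "messages": [
            {"role": "user", "content": "What is LoRAX?"},
        ],
        "max_tokens": 128,
        "stream": False,
    },
)
print(response.json())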
3 changes: 3 additions & 0 deletions router/src/validation.rs
@@ -145,6 +145,7 @@ impl Validation {
watermark,
adapter_id,
decoder_input_details,
apply_chat_template,
..
} = request.parameters;

@@ -270,6 +271,7 @@ impl Validation {
parameters,
stopping_parameters,
adapter,
apply_chat_template,
})
}

@@ -344,6 +346,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
pub adapter: Adapter,
pub apply_chat_template: bool,
}

#[derive(Error, Debug)]
4 changes: 3 additions & 1 deletion server/lorax_server/models/bloom.py
@@ -20,6 +20,7 @@
weight_files,
Weights,
)
from lorax_server.utils.tokenizer import TokenizerManager


class BloomCausalLMBatch(CausalLMBatch):
@@ -28,10 +29,11 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch = super().from_pb(pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch

6 changes: 5 additions & 1 deletion server/lorax_server/models/causal_lm.py
@@ -1,3 +1,4 @@
import json
import torch
import inspect

@@ -15,6 +16,7 @@
)
from lorax_server.pb import generate_pb2
from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from lorax_server.utils.tokenizer import TokenizerManager

tracer = trace.get_tracer(__name__)

@@ -69,6 +71,7 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
@@ -86,7 +89,8 @@ def from_pb(
adapter_indices_list = []
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
req_inputs = tokenizers.get_inputs(r, tokenizer)
inputs.append(req_inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
14 changes: 10 additions & 4 deletions server/lorax_server/models/flash_causal_lm.py
@@ -1,4 +1,5 @@
from collections import defaultdict
import json
import math
import itertools
from loguru import logger
@@ -29,11 +30,11 @@
from lorax_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map
from lorax_server.utils.dist import MEMORY_FRACTION
from lorax_server.utils.lora import LM_HEAD, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
from lorax_server.utils.weights import shard_on_dim
from lorax_server.utils.graph import GraphCache
from lorax_server.utils.sgmv import get_tmp_tensor
from lorax_server.utils.tokenizer import TokenizerManager

tracer = trace.get_tracer(__name__)

@@ -114,13 +115,15 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
batch_inputs = []
max_truncation = 0
for r in pb.requests:
batch_inputs.append(r.inputs)
inputs = tokenizers.get_inputs(r, tokenizer)
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

batch_tokenized_inputs = tokenizer(
@@ -746,7 +749,7 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index):
elif adapter_id != BASE_MODEL_ADAPTER_ID:
logger.info(f"Loading adapter weights into model: {adapter_id}")
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
module_map, adapter_config, adapter_weight_names = load_module_map(
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map(
self.model_id, adapter_id, adapter_source, weight_names
)

@@ -758,6 +761,9 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index):

if len(unused_weight_names) > 0:
logger.warning(f"{adapter_id} unused adapter weights: {unused_weight_names}")

if adapter_tokenizer is not None:
self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

self.adapter_id = adapter_id

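The TokenizerManager itself is not part of this diff; the sketch below is a hypothetical reading of the contract implied by the call sites above (get_inputs returning the request text, optionally rendered through a per-adapter tokenizer's chat template, and add_tokenizer registering adapter tokenizers), not the actual lorax_server.utils.tokenizer implementation:

import json


class TokenizerManager:
    """Hypothetical sketch of the interface used in this PR."""

    def __init__(self):
        self.tokenizers = {}

    def add_tokenizer(self, adapter_index, tokenizer):
        # Register a tokenizer that was loaded alongside an adapter.
        self.tokenizers[adapter_index] = tokenizer

    def get_inputs(self, r, base_tokenizer):
        # Prefer an adapter-specific tokenizer when one was registered.
        tokenizer = self.tokenizers.get(r.adapter_index, base_tokenizer)
        if r.apply_chat_template:
            # inputs is a JSON-encoded list of {"role", "content"} messages.
            messages = json.loads(r.inputs)
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        return r.inputs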
6 changes: 5 additions & 1 deletion server/lorax_server/models/flash_mistral.py
@@ -1,3 +1,4 @@
import json
import math
import torch
import torch.distributed
@@ -32,6 +33,7 @@
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID
from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData, AdapterBatchMetadata
from lorax_server.utils.segments import find_segments
from lorax_server.utils.tokenizer import TokenizerManager

tracer = trace.get_tracer(__name__)

@@ -55,6 +57,7 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
@@ -64,7 +67,8 @@ def from_pb(
batch_inputs = []
max_truncation = 0
for r in pb.requests:
batch_inputs.append(r.inputs)
inputs = tokenizers.get_inputs(r, tokenizer)
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

batch_tokenized_inputs = tokenizer(