diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 228e8671..d96bfc79 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -403,12 +403,20 @@ func (c *Cache) addPodAndModelMapping(podName, modelName string) { func (c *Cache) deletePodAndModelMapping(podName, modelName string) { if models, ok := c.PodToModelMapping[podName]; ok { delete(models, modelName) - c.PodToModelMapping[podName] = models + if len(models) != 0 { + c.PodToModelMapping[podName] = models + } else { + delete(c.PodToModelMapping, podName) + } } if pods, ok := c.ModelToPodMapping[modelName]; ok { delete(pods, podName) - c.ModelToPodMapping[modelName] = pods + if len(pods) != 0 { + c.ModelToPodMapping[modelName] = pods + } else { + delete(c.ModelToPodMapping, modelName) + } } } diff --git a/pkg/plugins/gateway/gateway.go b/pkg/plugins/gateway/gateway.go index 1a2cbad0..d1a27b7d 100644 --- a/pkg/plugins/gateway/gateway.go +++ b/pkg/plugins/gateway/gateway.go @@ -223,7 +223,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e "error processing request body"), targetPodIP, stream } - if model, ok = jsonMap["model"].(string); !ok || model == "" { // || !s.cache.CheckModelExists(model) # enable when dynamic lora is enabled + if model, ok = jsonMap["model"].(string); !ok || model == "" { klog.ErrorS(nil, "model error in request", "requestID", requestID, "jsonMap", jsonMap) return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ @@ -231,6 +231,15 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e fmt.Sprintf("no model in request body or model %s does not exist", model)), targetPodIP, stream } + // Early-reject the request if the requested model is not present in the cache.
+ if !s.cache.CheckModelExists(model) { + klog.ErrorS(nil, "model doesn't exist in cache, probably wrong model name", "requestID", requestID, "jsonMap", jsonMap) + return generateErrorResponse(envoyTypePb.StatusCode_BadRequest, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-no-model", RawValue: []byte(model)}}}, + fmt.Sprintf("model %s does not exist", model)), targetPodIP, stream + } + stream, ok = jsonMap["stream"].(bool) if stream && ok { streamOptions, ok := jsonMap["stream_options"].(map[string]interface{}) @@ -250,8 +259,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e } headers := []*configPb.HeaderValueOption{} - switch { - case routingStrategy == "": + if routingStrategy == "" { headers = append(headers, &configPb.HeaderValueOption{ Header: &configPb.HeaderValue{ Key: "model", @@ -259,7 +267,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e }, }) klog.InfoS("request start", "requestID", requestID, "model", model) - case routingStrategy != "": + } else { pods, err := s.cache.GetPodsForModel(model) if len(pods) == 0 || err != nil { return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, @@ -348,8 +356,7 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req * var usage openai.CompletionUsage headers := []*configPb.HeaderValueOption{} - switch stream { - case true: + if stream { t := &http.Response{ Body: io.NopCloser(bytes.NewReader(b.ResponseBody.GetBody())), } @@ -370,7 +377,7 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req * }}}, err.Error()) } - case false: + } else { if err := json.Unmarshal(b.ResponseBody.Body, &res); err != nil { klog.ErrorS(err, "error to unmarshal response", "requestID", requestID, "responseBody", string(b.ResponseBody.GetBody())) return generateErrorResponse(