From 436be0b3bdb01d49ba3d7d3e1f8358ad55362745 Mon Sep 17 00:00:00 2001 From: Varun Gupta Date: Tue, 8 Oct 2024 14:13:54 -0700 Subject: [PATCH] Update random routing section and add support for anonymous user (#276) --- pkg/plugins/gateway/algorithms/random.go | 9 +- pkg/plugins/gateway/gateway.go | 155 +++++++++++++---------- 2 files changed, 92 insertions(+), 72 deletions(-) diff --git a/pkg/plugins/gateway/algorithms/random.go b/pkg/plugins/gateway/algorithms/random.go index b6f0820b..abb9c067 100644 --- a/pkg/plugins/gateway/algorithms/random.go +++ b/pkg/plugins/gateway/algorithms/random.go @@ -20,6 +20,7 @@ import ( "context" "fmt" + "golang.org/x/exp/rand" v1 "k8s.io/api/core/v1" ) @@ -36,9 +37,13 @@ func (r randomRouter) Route(ctx context.Context, pods map[string]*v1.Pod) (strin return "", fmt.Errorf("no pods to forward request") } + k := rand.Intn(len(pods)) for _, pod := range pods { - selectedPod = pod - break + if k == 0 { + selectedPod = pod + break + } + k-- } return selectedPod.Status.PodIP, nil diff --git a/pkg/plugins/gateway/gateway.go b/pkg/plugins/gateway/gateway.go index 3057600f..f9e0e94c 100644 --- a/pkg/plugins/gateway/gateway.go +++ b/pkg/plugins/gateway/gateway.go @@ -134,7 +134,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { } } -func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, string, string) { +func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, string, string) { klog.Info("--- In RequestHeaders processing ...") var username, model, routingStrategy, targetPodIP string r := req.Request @@ -155,52 +155,21 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req } } - user, err := utils.GetUser(utils.User{Name: username}, s.redisClient) - if err != nil { - return generateErrorResponse( - 
envoyTypePb.StatusCode_Forbidden, - []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: "x-user-missing", RawValue: []byte("true"), - }}}, - fmt.Sprintf("pre query: username is missing: %v", err.Error())), username, targetPodIP - } - - if user.Rpm == 0 { - user.Rpm = int64(defaultRPM) - } - if user.Tpm == 0 { - user.Tpm = user.Rpm * int64(defaultTPMMultiplier) - } - - code, err := s.checkRPM(ctx, username, user.Rpm) - if err != nil { - return generateErrorResponse( - code, - []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: "x-rpm-exceeded", RawValue: []byte("true"), - }}}, - fmt.Sprintf("pre query: error on checking rpm: %v", err.Error())), username, targetPodIP - } - - rpm, code, err := s.incrRPM(ctx, username) - if err != nil { - return generateErrorResponse( - code, - []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: "x-error-update-rpm", RawValue: []byte("true"), - }}}, - fmt.Sprintf("pre query: error on updating rpm: %v", err.Error())), username, targetPodIP - } - klog.Infof("RequestStart %s: RPM: %v for user: %v", reqeustID, rpm, user.Name) + if username != "" { + user, err := utils.GetUser(utils.User{Name: username}, s.redisClient) + if err != nil { + return generateErrorResponse( + envoyTypePb.StatusCode_Forbidden, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-user-missing", RawValue: []byte("true"), + }}}, + fmt.Sprintf("pre query: username is missing: %v", err.Error())), username, targetPodIP + } - code, err = s.checkTPM(ctx, username, user.Tpm) - if err != nil { - return generateErrorResponse( - code, - []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: "x-tpm-exceeded", RawValue: []byte("true"), - }}}, - fmt.Sprintf("pre query: error on checking tpm: %v", err.Error())), username, targetPodIP + errRes := s.checkLimits(ctx, requestID, user) + if errRes != nil { + return errRes, user.Name, targetPodIP + } } headers := []*configPb.HeaderValueOption{ @@ 
-210,18 +179,19 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req RawValue: []byte("true"), }, }, - { - Header: &configPb.HeaderValue{ - Key: "x-updated-rpm", - RawValue: []byte(fmt.Sprintf("%d", rpm)), - }, - }, + // TODO (varun): refactor this part with model name input from request body + // { + // Header: &configPb.HeaderValue{ + // Key: "x-updated-rpm", + // RawValue: []byte(fmt.Sprintf("%d", rpm)), + // }, + // }, } if routingStrategy != "" { pods, err := s.cache.GetPodsForModel(model) if len(pods) == 0 || err != nil { return generateErrorResponse( - code, + envoyTypePb.StatusCode_InternalServerError, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ Key: "x-no-model-deployment", RawValue: []byte("true"), }}}, @@ -231,7 +201,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req targetPodIP, err = s.selectTargetPod(ctx, routingStrategy, pods) if targetPodIP == "" || err != nil { return generateErrorResponse( - code, + envoyTypePb.StatusCode_InternalServerError, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ Key: "x-select-target-pod", RawValue: []byte("true"), }}}, @@ -244,7 +214,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req RawValue: []byte(targetPodIP), }, }) - klog.Infof("RequestStart %s: SelectedTargetPodIP: %s", reqeustID, targetPodIP) + klog.Infof("RequestStart %s: SelectedTargetPodIP: %s", requestID, targetPodIP) } resp := &extProcPb.ProcessingResponse{ @@ -340,16 +310,18 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req * err.Error()) } - tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens)) - if err != nil { - return generateErrorResponse( - envoyTypePb.StatusCode_InternalServerError, - []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: "x-error-update-tpm", RawValue: []byte("true"), - }}}, - fmt.Sprintf("post query: error 
on updating tpm: %v", err.Error())) + if user != "" { + tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens)) + if err != nil { + return generateErrorResponse( + envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-error-update-tpm", RawValue: []byte("true"), + }}}, + fmt.Sprintf("post query: error on updating tpm: %v", err.Error())) + } + klog.Infof("RequestEnd %s: TPM: %v for user: %v", reqeustID, tpm, user) } - klog.Infof("RequestEnd %s: TPM: %v for user: %v", reqeustID, tpm, user) return &extProcPb.ProcessingResponse{ Response: &extProcPb.ProcessingResponse_ResponseBody{ @@ -357,12 +329,13 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req * Response: &extProcPb.CommonResponse{ HeaderMutation: &extProcPb.HeaderMutation{ SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: "x-updated-tpm", - RawValue: []byte(fmt.Sprintf("%d", tpm)), - }, - }, + // TODO (varun): refactor with read model name from body + // { + // Header: &configPb.HeaderValue{ + // Key: "x-updated-tpm", + // RawValue: []byte(fmt.Sprintf("%d", tpm)), + // }, + // }, }, }, }, @@ -371,6 +344,48 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req * } } +// checkLimits applies the per-user rate limits to one incoming request. +// A zero user.Rpm falls back to defaultRPM and a zero user.Tpm falls back to user.Rpm * defaultTPMMultiplier; +// user is received by value, so these defaults are not written back to the caller's copy. +// It then calls checkRPM, incrRPM and checkTPM in that order and returns the error response built by +// generateErrorResponse for the first call that fails, or nil when all three succeed. +// Because incrRPM runs before checkTPM, a request rejected by checkTPM has already been counted against RPM. +func (s *Server) checkLimits(ctx context.Context, requestID string, user utils.User) *extProcPb.ProcessingResponse { + if user.Rpm == 0 { + user.Rpm = int64(defaultRPM) + } + if user.Tpm == 0 { + user.Tpm = user.Rpm * int64(defaultTPMMultiplier) + } + + code, err := s.checkRPM(ctx, user.Name, user.Rpm) + if err != nil { + return generateErrorResponse( + code, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-rpm-exceeded", RawValue: []byte("true"), + }}}, + fmt.Sprintf("pre query: error on checking rpm: %v", err.Error())) + } + + rpm, code, err := s.incrRPM(ctx, user.Name) + if err != nil { + return generateErrorResponse( 
+ code, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-error-update-rpm", RawValue: []byte("true"), + }}}, + fmt.Sprintf("pre query: error on updating rpm: %v", err.Error())) + } + klog.Infof("RequestStart %s: RPM: %v for user: %v", requestID, rpm, user.Name) + + code, err = s.checkTPM(ctx, user.Name, user.Tpm) + if err != nil { + return generateErrorResponse( + code, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-tpm-exceeded", RawValue: []byte("true"), + }}}, + fmt.Sprintf("pre query: error on checking tpm: %v", err.Error())) + } + + return nil +} + func (s *Server) checkRPM(ctx context.Context, user string, rpmLimit int64) (envoyTypePb.StatusCode, error) { rpmCurrent, err := s.ratelimiter.Get(ctx, fmt.Sprintf("%v_RPM_CURRENT", user)) if err != nil {