Skip to content

Commit

Permalink
Update random routing section and add support for anonymous user (#276)
Browse files Browse the repository at this point in the history
  • Loading branch information
varungup90 authored Oct 8, 2024
1 parent c6e1c2b commit 436be0b
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 72 deletions.
9 changes: 7 additions & 2 deletions pkg/plugins/gateway/algorithms/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"

"golang.org/x/exp/rand"
v1 "k8s.io/api/core/v1"
)

Expand All @@ -36,9 +37,13 @@ func (r randomRouter) Route(ctx context.Context, pods map[string]*v1.Pod) (strin
return "", fmt.Errorf("no pods to forward request")
}

k := rand.Intn(len(pods))
for _, pod := range pods {
selectedPod = pod
break
if k == 0 {
selectedPod = pod
break
}
k--
}

return selectedPod.Status.PodIP, nil
Expand Down
155 changes: 85 additions & 70 deletions pkg/plugins/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
}
}

func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, string, string) {
func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, string, string) {
klog.Info("--- In RequestHeaders processing ...")
var username, model, routingStrategy, targetPodIP string
r := req.Request
Expand All @@ -155,52 +155,21 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
}
}

user, err := utils.GetUser(utils.User{Name: username}, s.redisClient)
if err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_Forbidden,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-user-missing", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: username is missing: %v", err.Error())), username, targetPodIP
}

if user.Rpm == 0 {
user.Rpm = int64(defaultRPM)
}
if user.Tpm == 0 {
user.Tpm = user.Rpm * int64(defaultTPMMultiplier)
}

code, err := s.checkRPM(ctx, username, user.Rpm)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-rpm-exceeded", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on checking rpm: %v", err.Error())), username, targetPodIP
}

rpm, code, err := s.incrRPM(ctx, username)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-error-update-rpm", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on updating rpm: %v", err.Error())), username, targetPodIP
}
klog.Infof("RequestStart %s: RPM: %v for user: %v", reqeustID, rpm, user.Name)
if username != "" {
user, err := utils.GetUser(utils.User{Name: username}, s.redisClient)
if err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_Forbidden,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-user-missing", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: username is missing: %v", err.Error())), username, targetPodIP
}

code, err = s.checkTPM(ctx, username, user.Tpm)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-tpm-exceeded", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on checking tpm: %v", err.Error())), username, targetPodIP
errRes := s.checkLimits(ctx, requestID, user)
if errRes != nil {
return errRes, user.Name, targetPodIP
}
}

headers := []*configPb.HeaderValueOption{
Expand All @@ -210,18 +179,19 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
RawValue: []byte("true"),
},
},
{
Header: &configPb.HeaderValue{
Key: "x-updated-rpm",
RawValue: []byte(fmt.Sprintf("%d", rpm)),
},
},
// TODO (varun): refactor this part with model name input from request body
// {
// Header: &configPb.HeaderValue{
// Key: "x-updated-rpm",
// RawValue: []byte(fmt.Sprintf("%d", rpm)),
// },
// },
}
if routingStrategy != "" {
pods, err := s.cache.GetPodsForModel(model)
if len(pods) == 0 || err != nil {
return generateErrorResponse(
code,
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-no-model-deployment", RawValue: []byte("true"),
}}},
Expand All @@ -231,7 +201,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
targetPodIP, err = s.selectTargetPod(ctx, routingStrategy, pods)
if targetPodIP == "" || err != nil {
return generateErrorResponse(
code,
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-select-target-pod", RawValue: []byte("true"),
}}},
Expand All @@ -244,7 +214,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
RawValue: []byte(targetPodIP),
},
})
klog.Infof("RequestStart %s: SelectedTargetPodIP: %s", reqeustID, targetPodIP)
klog.Infof("RequestStart %s: SelectedTargetPodIP: %s", requestID, targetPodIP)
}

resp := &extProcPb.ProcessingResponse{
Expand Down Expand Up @@ -340,29 +310,32 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req *
err.Error())
}

tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens))
if err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-error-update-tpm", RawValue: []byte("true"),
}}},
fmt.Sprintf("post query: error on updating tpm: %v", err.Error()))
if user != "" {
tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens))
if err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-error-update-tpm", RawValue: []byte("true"),
}}},
fmt.Sprintf("post query: error on updating tpm: %v", err.Error()))
}
klog.Infof("RequestEnd %s: TPM: %v for user: %v", reqeustID, tpm, user)
}
klog.Infof("RequestEnd %s: TPM: %v for user: %v", reqeustID, tpm, user)

return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{
Response: &extProcPb.CommonResponse{
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
Key: "x-updated-tpm",
RawValue: []byte(fmt.Sprintf("%d", tpm)),
},
},
// TODO (varun): refactor with read model name from body
// {
// Header: &configPb.HeaderValue{
// Key: "x-updated-tpm",
// RawValue: []byte(fmt.Sprintf("%d", tpm)),
// },
// },
},
},
},
Expand All @@ -371,6 +344,48 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req *
}
}

func (s *Server) checkLimits(ctx context.Context, requestID string, user utils.User) *extProcPb.ProcessingResponse {
if user.Rpm == 0 {
user.Rpm = int64(defaultRPM)
}
if user.Tpm == 0 {
user.Tpm = user.Rpm * int64(defaultTPMMultiplier)
}

code, err := s.checkRPM(ctx, user.Name, user.Rpm)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-rpm-exceeded", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on checking rpm: %v", err.Error()))
}

rpm, code, err := s.incrRPM(ctx, user.Name)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-error-update-rpm", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on updating rpm: %v", err.Error()))
}
klog.Infof("RequestStart %s: RPM: %v for user: %v", requestID, rpm, user.Name)

code, err = s.checkTPM(ctx, user.Name, user.Tpm)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-tpm-exceeded", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on checking tpm: %v", err.Error()))
}

return nil
}

func (s *Server) checkRPM(ctx context.Context, user string, rpmLimit int64) (envoyTypePb.StatusCode, error) {
rpmCurrent, err := s.ratelimiter.Get(ctx, fmt.Sprintf("%v_RPM_CURRENT", user))
if err != nil {
Expand Down

0 comments on commit 436be0b

Please sign in to comment.