Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update random routing section and add support for anonymous user #276

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pkg/plugins/gateway/algorithms/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"

"golang.org/x/exp/rand"
v1 "k8s.io/api/core/v1"
)

Expand All @@ -36,9 +37,13 @@ func (r randomRouter) Route(ctx context.Context, pods map[string]*v1.Pod) (strin
return "", fmt.Errorf("no pods to forward request")
}

k := rand.Intn(len(pods))
for _, pod := range pods {
selectedPod = pod
break
if k == 0 {
selectedPod = pod
break
}
k--
}

return selectedPod.Status.PodIP, nil
Expand Down
155 changes: 85 additions & 70 deletions pkg/plugins/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
}
}

func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, string, string) {
func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, string, string) {
klog.Info("--- In RequestHeaders processing ...")
var username, model, routingStrategy, targetPodIP string
r := req.Request
Expand All @@ -155,52 +155,21 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
}
}

user, err := utils.GetUser(utils.User{Name: username}, s.redisClient)
if err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_Forbidden,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-user-missing", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: username is missing: %v", err.Error())), username, targetPodIP
}

if user.Rpm == 0 {
user.Rpm = int64(defaultRPM)
}
if user.Tpm == 0 {
user.Tpm = user.Rpm * int64(defaultTPMMultiplier)
}

code, err := s.checkRPM(ctx, username, user.Rpm)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-rpm-exceeded", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on checking rpm: %v", err.Error())), username, targetPodIP
}

rpm, code, err := s.incrRPM(ctx, username)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-error-update-rpm", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on updating rpm: %v", err.Error())), username, targetPodIP
}
klog.Infof("RequestStart %s: RPM: %v for user: %v", reqeustID, rpm, user.Name)
if username != "" {
user, err := utils.GetUser(utils.User{Name: username}, s.redisClient)
if err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_Forbidden,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-user-missing", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: username is missing: %v", err.Error())), username, targetPodIP
}

code, err = s.checkTPM(ctx, username, user.Tpm)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-tpm-exceeded", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on checking tpm: %v", err.Error())), username, targetPodIP
errRes := s.checkLimits(ctx, requestID, user)
if errRes != nil {
return errRes, user.Name, targetPodIP
}
}

headers := []*configPb.HeaderValueOption{
Expand All @@ -210,18 +179,19 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
RawValue: []byte("true"),
},
},
{
Header: &configPb.HeaderValue{
Key: "x-updated-rpm",
RawValue: []byte(fmt.Sprintf("%d", rpm)),
},
},
// TODO (varun): refactor this part with model name input from request body
// {
// Header: &configPb.HeaderValue{
// Key: "x-updated-rpm",
// RawValue: []byte(fmt.Sprintf("%d", rpm)),
// },
// },
}
if routingStrategy != "" {
pods, err := s.cache.GetPodsForModel(model)
if len(pods) == 0 || err != nil {
return generateErrorResponse(
code,
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-no-model-deployment", RawValue: []byte("true"),
}}},
Expand All @@ -231,7 +201,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
targetPodIP, err = s.selectTargetPod(ctx, routingStrategy, pods)
if targetPodIP == "" || err != nil {
return generateErrorResponse(
code,
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-select-target-pod", RawValue: []byte("true"),
}}},
Expand All @@ -244,7 +214,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
RawValue: []byte(targetPodIP),
},
})
klog.Infof("RequestStart %s: SelectedTargetPodIP: %s", reqeustID, targetPodIP)
klog.Infof("RequestStart %s: SelectedTargetPodIP: %s", requestID, targetPodIP)
}

resp := &extProcPb.ProcessingResponse{
Expand Down Expand Up @@ -340,29 +310,32 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req *
err.Error())
}

tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens))
if err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-error-update-tpm", RawValue: []byte("true"),
}}},
fmt.Sprintf("post query: error on updating tpm: %v", err.Error()))
if user != "" {
tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens))
if err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-error-update-tpm", RawValue: []byte("true"),
}}},
fmt.Sprintf("post query: error on updating tpm: %v", err.Error()))
}
klog.Infof("RequestEnd %s: TPM: %v for user: %v", reqeustID, tpm, user)
}
klog.Infof("RequestEnd %s: TPM: %v for user: %v", reqeustID, tpm, user)

return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{
Response: &extProcPb.CommonResponse{
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
Key: "x-updated-tpm",
RawValue: []byte(fmt.Sprintf("%d", tpm)),
},
},
// TODO (varun): refactor with read model name from body
// {
// Header: &configPb.HeaderValue{
// Key: "x-updated-tpm",
// RawValue: []byte(fmt.Sprintf("%d", tpm)),
// },
// },
},
},
},
Expand All @@ -371,6 +344,48 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req *
}
}

func (s *Server) checkLimits(ctx context.Context, requestID string, user utils.User) *extProcPb.ProcessingResponse {
if user.Rpm == 0 {
user.Rpm = int64(defaultRPM)
}
if user.Tpm == 0 {
user.Tpm = user.Rpm * int64(defaultTPMMultiplier)
}

code, err := s.checkRPM(ctx, user.Name, user.Rpm)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-rpm-exceeded", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on checking rpm: %v", err.Error()))
}

rpm, code, err := s.incrRPM(ctx, user.Name)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-error-update-rpm", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on updating rpm: %v", err.Error()))
}
klog.Infof("RequestStart %s: RPM: %v for user: %v", requestID, rpm, user.Name)

code, err = s.checkTPM(ctx, user.Name, user.Tpm)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-tpm-exceeded", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on checking tpm: %v", err.Error()))
}

return nil
}

func (s *Server) checkRPM(ctx context.Context, user string, rpmLimit int64) (envoyTypePb.StatusCode, error) {
rpmCurrent, err := s.ratelimiter.Get(ctx, fmt.Sprintf("%v_RPM_CURRENT", user))
if err != nil {
Expand Down
Loading