From b1d46ad07d5acaafee5c70c7105825874263f613 Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Sun, 3 Apr 2022 13:29:12 -0700 Subject: [PATCH] Move id_first and id_count to ctxlog pairs for API endpoints. --- http/api.go | 45 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/http/api.go b/http/api.go index e0fa3f7..75c0f95 100644 --- a/http/api.go +++ b/http/api.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "crypto/x509" "encoding/json" "encoding/pem" @@ -38,6 +39,36 @@ type apiResult struct { RequestType string `json:"request_type,omitempty"` } +type ( + ctxKeyIDFirst struct{} + ctxKeyIDCount struct{} +) + +func setAPIIDs(ctx context.Context, idFirst string, idCount int) context.Context { + ctx = context.WithValue(ctx, ctxKeyIDFirst{}, idFirst) + return context.WithValue(ctx, ctxKeyIDCount{}, idCount) +} + +func ctxKVs(ctx context.Context) (out []interface{}) { + id, ok := ctx.Value(ctxKeyIDFirst{}).(string) + if ok { + out = append(out, "id_first", id) + } + eType, ok := ctx.Value(ctxKeyIDCount{}).(int) + if ok { + out = append(out, "id_count", eType) + } + return +} + +func setupCtxLog(ctx context.Context, ids []string, logger log.Logger) (context.Context, log.Logger) { + if len(ids) > 0 { + ctx = setAPIIDs(ctx, ids[0], len(ids)) + ctx = ctxlog.AddFunc(ctx, ctxKVs) + } + return ctx, ctxlog.Logger(ctx, logger) +} + // PushHandlerFunc sends APNs push notifications to MDM enrollments. // // Note the whole URL path is used as the identifier to push to. This @@ -46,12 +77,12 @@ type apiResult struct { // users. func PushHandlerFunc(pusher push.Pusher, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - logger := ctxlog.Logger(r.Context(), logger) ids := strings.Split(r.URL.Path, ",") + ctx, logger := setupCtxLog(r.Context(), ids, logger) output := apiResult{ Status: make(enrolledAPIResults), } - pushResp, err := pusher.Push(r.Context(), ids) + pushResp, err := pusher.Push(ctx, ids) if err != nil { logger.Info("msg", "push", "err", err) output.PushError = err.Error() @@ -90,7 +121,8 @@ func PushHandlerFunc(pusher push.Pusher, logger log.Logger) http.HandlerFunc { // for "API" users. func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Pusher, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - logger := ctxlog.Logger(r.Context(), logger) + ids := strings.Split(r.URL.Path, ",") + ctx, logger := setupCtxLog(r.Context(), ids, logger) b, err := ReadAllAndReplaceBody(r) if err != nil { logger.Info("msg", "reading body", "err", err) @@ -103,7 +135,6 @@ func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Push http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } - ids := strings.Split(r.URL.Path, ",") nopush := r.URL.Query().Get("nopush") != "" output := apiResult{ Status: make(enrolledAPIResults), @@ -111,14 +142,14 @@ func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Push CommandUUID: command.CommandUUID, RequestType: command.Command.RequestType, } - idErrs, err := enqueuer.EnqueueCommand(r.Context(), ids, command) + idErrs, err := enqueuer.EnqueueCommand(ctx, ids, command) if err != nil { logger.Info("msg", "enqueue command", "err", err) output.CommandError = err.Error() } pushResp := make(map[string]*push.Response) if !nopush { - pushResp, err = pusher.Push(r.Context(), ids) + pushResp, err = pusher.Push(ctx, ids) if err != nil { logger.Info("msg", "push", "err", err) output.PushError = err.Error() @@ -149,8 +180,6 @@ func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Push "msg", "enqueue", "command_uuid", command.CommandUUID, "request_type", command.Command.RequestType, - "id_count", len(ids), - "id_first", ids[0], ) logger.Debug("msg", "push", "count", len(pushResp)) json, err := json.MarshalIndent(output, "", "\t")