From 032aae186c11b7a27094cefc5f699ebe344885f3 Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Tue, 8 Jun 2021 13:59:18 -0700 Subject: [PATCH] Refactor MDM request service dispatching from http to service package --- http/mdm.go | 61 +++++++++++++++---------------------------- service/request.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 41 deletions(-) create mode 100644 service/request.go diff --git a/http/mdm.go b/http/mdm.go index 1d8ea8f..f9e5574 100644 --- a/http/mdm.go +++ b/http/mdm.go @@ -1,7 +1,7 @@ package http import ( - "fmt" + "errors" "net/http" "strings" @@ -11,7 +11,7 @@ import ( ) // CheckinHandlerFunc decodes an MDM check-in request and adapts it to service. -func CheckinHandlerFunc(service service.Checkin, logger log.Logger) http.HandlerFunc { +func CheckinHandlerFunc(svc service.Checkin, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { bodyBytes, err := ReadAllAndReplaceBody(r) if err != nil { @@ -19,46 +19,27 @@ func CheckinHandlerFunc(service service.Checkin, logger log.Logger) http.Handler http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - m, err := mdm.DecodeCheckin(bodyBytes) - if err != nil { - logger.Info("msg", "decoding check-in", "err", err) - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } mdmReq := &mdm.Request{ Context: r.Context(), Certificate: GetCert(r.Context()), } - switch message := m.(type) { - case *mdm.Authenticate: - err = service.Authenticate(mdmReq, message) - if err != nil { - err = fmt.Errorf("authenticate: %w", err) - } - case *mdm.TokenUpdate: - err = service.TokenUpdate(mdmReq, message) - if err != nil { - err = fmt.Errorf("tokenupdate: %w", err) - } - case *mdm.CheckOut: - err = service.CheckOut(mdmReq, message) - if err != nil { - err = fmt.Errorf("checkout: %w", err) - } - default: - logger.Info("err", mdm.ErrUnrecognizedMessageType) - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } + respBytes, err := service.CheckinRequest(svc, mdmReq, bodyBytes) if err != nil { - logger.Info("msg", "service error in check-in", "err", err) + logger.Info("msg", "check-in request", "err", err) + var decodeError *service.DecodeError + if errors.Is(err, mdm.ErrUnrecognizedMessageType) || errors.As(err, &decodeError) { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return } + w.Write(respBytes) } } // CommandAndReportResultsHandlerFunc decodes an MDM command request and adapts it to service. -func CommandAndReportResultsHandlerFunc(service service.CommandAndReportResults, logger log.Logger) http.HandlerFunc { +func CommandAndReportResultsHandlerFunc(svc service.CommandAndReportResults, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { bodyBytes, err := ReadAllAndReplaceBody(r) if err != nil { @@ -66,24 +47,22 @@ func CommandAndReportResultsHandlerFunc(service service.CommandAndReportResults, http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - report, err := mdm.DecodeCommandResults(bodyBytes) - if err != nil { - logger.Info("msg", "decoding command report", "err", err) - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } mdmReq := &mdm.Request{ Context: r.Context(), Certificate: GetCert(r.Context()), } - cmd, err := service.CommandAndReportResults(mdmReq, report) + respBytes, err := service.CommandAndReportResultsRequest(svc, mdmReq, bodyBytes) if err != nil { logger.Info("msg", "command report results", "err", err) + var decodeError *service.DecodeError + if errors.As(err, &decodeError) { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return } - if cmd != nil { - w.Write(cmd.Raw) - } + w.Write(respBytes) } } diff --git a/service/request.go b/service/request.go new file mode 100644 index 0000000..86d8d81 --- /dev/null +++ b/service/request.go @@ -0,0 +1,64 @@ +package service + +import ( + "fmt" + + "github.com/micromdm/nanomdm/mdm" +) + +type DecodeError struct { + Err error +} + +func (e *DecodeError) Error() string { return "decoding MDM request: " + e.Err.Error() } + +func (e *DecodeError) Unwrap() error { return e.Err } + +func NewDecodeError(err error) *DecodeError { return &DecodeError{Err: err} } + +// CheckinRequest is a simple adapter that takes the raw check-in bodyBytes +// and dispatches to the respective check-in method on svc. +func CheckinRequest(svc Checkin, r *mdm.Request, bodyBytes []byte) ([]byte, error) { + msg, err := mdm.DecodeCheckin(bodyBytes) + if err != nil { + return nil, NewDecodeError(err) + } + switch m := msg.(type) { + case *mdm.Authenticate: + err = svc.Authenticate(r, m) + if err != nil { + err = fmt.Errorf("authenticate service: %w", err) + } + case *mdm.TokenUpdate: + err = svc.TokenUpdate(r, m) + if err != nil { + err = fmt.Errorf("tokenupdate service: %w", err) + } + case *mdm.CheckOut: + err = svc.CheckOut(r, m) + if err != nil { + err = fmt.Errorf("checkout service: %w", err) + } + default: + return nil, mdm.ErrUnrecognizedMessageType + } + return nil, err +} + +// CommandAndReportResultsRequest is a simple adapter that takes the raw +// command result report bodyBytes, dispatches to svc, and returns the +// response. +func CommandAndReportResultsRequest(svc CommandAndReportResults, r *mdm.Request, bodyBytes []byte) ([]byte, error) { + report, err := mdm.DecodeCommandResults(bodyBytes) + if err != nil { + return nil, NewDecodeError(err) + } + cmd, err := svc.CommandAndReportResults(r, report) + if err != nil { + return nil, fmt.Errorf("command and report results service: %w", err) + } + if cmd != nil { + return cmd.Raw, nil + } + return nil, nil +}