Skip to content

Commit

Permalink
Refactor MDM request service dispatching from http to service package
Browse files Browse the repository at this point in the history
  • Loading branch information
jessepeterson committed Jun 8, 2021
1 parent 3c9a161 commit 032aae1
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 41 deletions.
61 changes: 20 additions & 41 deletions http/mdm.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package http

import (
"fmt"
"errors"
"net/http"
"strings"

Expand All @@ -11,79 +11,58 @@ import (
)

// CheckinHandlerFunc decodes an MDM check-in request and adapts it to service.
func CheckinHandlerFunc(service service.Checkin, logger log.Logger) http.HandlerFunc {
func CheckinHandlerFunc(svc service.Checkin, logger log.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
bodyBytes, err := ReadAllAndReplaceBody(r)
if err != nil {
logger.Info("msg", "reading body", "err", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
m, err := mdm.DecodeCheckin(bodyBytes)
if err != nil {
logger.Info("msg", "decoding check-in", "err", err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
mdmReq := &mdm.Request{
Context: r.Context(),
Certificate: GetCert(r.Context()),
}
switch message := m.(type) {
case *mdm.Authenticate:
err = service.Authenticate(mdmReq, message)
if err != nil {
err = fmt.Errorf("authenticate: %w", err)
}
case *mdm.TokenUpdate:
err = service.TokenUpdate(mdmReq, message)
if err != nil {
err = fmt.Errorf("tokenupdate: %w", err)
}
case *mdm.CheckOut:
err = service.CheckOut(mdmReq, message)
if err != nil {
err = fmt.Errorf("checkout: %w", err)
}
default:
logger.Info("err", mdm.ErrUnrecognizedMessageType)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
respBytes, err := service.CheckinRequest(svc, mdmReq, bodyBytes)
if err != nil {
logger.Info("msg", "service error in check-in", "err", err)
logger.Info("msg", "check-in request", "err", err)
var decodeError *service.DecodeError
if errors.Is(err, mdm.ErrUnrecognizedMessageType) || errors.As(err, &decodeError) {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
w.Write(respBytes)
}
}

// CommandAndReportResultsHandlerFunc decodes an MDM command request and adapts it to service.
func CommandAndReportResultsHandlerFunc(service service.CommandAndReportResults, logger log.Logger) http.HandlerFunc {
func CommandAndReportResultsHandlerFunc(svc service.CommandAndReportResults, logger log.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
bodyBytes, err := ReadAllAndReplaceBody(r)
if err != nil {
logger.Info("msg", "reading body", "err", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
report, err := mdm.DecodeCommandResults(bodyBytes)
if err != nil {
logger.Info("msg", "decoding command report", "err", err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
mdmReq := &mdm.Request{
Context: r.Context(),
Certificate: GetCert(r.Context()),
}
cmd, err := service.CommandAndReportResults(mdmReq, report)
respBytes, err := service.CommandAndReportResultsRequest(svc, mdmReq, bodyBytes)
if err != nil {
logger.Info("msg", "command report results", "err", err)
var decodeError *service.DecodeError
if errors.As(err, &decodeError) {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if cmd != nil {
w.Write(cmd.Raw)
}
w.Write(respBytes)
}
}

Expand Down
64 changes: 64 additions & 0 deletions service/request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package service

import (
"fmt"

"github.com/micromdm/nanomdm/mdm"
)

type DecodeError struct {
Err error
}

func (e *DecodeError) Error() string { return "decoding MDM request: " + e.Err.Error() }

func (e *DecodeError) Unwrap() error { return e.Err }

func NewDecodeError(err error) *DecodeError { return &DecodeError{Err: err} }

// CheckinRequest is a simple adapter that takes the raw check-in bodyBytes
// and dispatches to the respective check-in method on svc.
func CheckinRequest(svc Checkin, r *mdm.Request, bodyBytes []byte) ([]byte, error) {
msg, err := mdm.DecodeCheckin(bodyBytes)
if err != nil {
return nil, NewDecodeError(err)
}
switch m := msg.(type) {
case *mdm.Authenticate:
err = svc.Authenticate(r, m)
if err != nil {
err = fmt.Errorf("authenticate service: %w", err)
}
case *mdm.TokenUpdate:
err = svc.TokenUpdate(r, m)
if err != nil {
err = fmt.Errorf("tokenupdate service: %w", err)
}
case *mdm.CheckOut:
err = svc.CheckOut(r, m)
if err != nil {
err = fmt.Errorf("checkout service: %w", err)
}
default:
return nil, mdm.ErrUnrecognizedMessageType
}
return nil, err
}

// CommandAndReportResultsRequest is a simple adapter that takes the raw
// command result report bodyBytes, dispatches to svc, and returns the
// response.
func CommandAndReportResultsRequest(svc CommandAndReportResults, r *mdm.Request, bodyBytes []byte) ([]byte, error) {
report, err := mdm.DecodeCommandResults(bodyBytes)
if err != nil {
return nil, NewDecodeError(err)
}
cmd, err := svc.CommandAndReportResults(r, report)
if err != nil {
return nil, fmt.Errorf("command and report results service: %w", err)
}
if cmd != nil {
return cmd.Raw, nil
}
return nil, nil
}

0 comments on commit 032aae1

Please sign in to comment.