Skip to content

Commit

Permalink
Refactoring endpoint_utils (#26342)
Browse files Browse the repository at this point in the history
For #26218 

Refactoring service/android endpoint_utils to remove duplication.
No functional changes.

- [x] Manual QA for all new/changed functionality
  • Loading branch information
getvictor authored Feb 18, 2025
1 parent f200bb3 commit 386ce37
Show file tree
Hide file tree
Showing 19 changed files with 639 additions and 706 deletions.
254 changes: 44 additions & 210 deletions server/mdm/android/service/endpoint_utils.go
Original file line number Diff line number Diff line change
@@ -1,244 +1,78 @@
package service

// TODO(26218): Refactor this to remove duplication.

import (
"bufio"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"reflect"
"strings"

"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm/android"
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
eu "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
"github.com/go-json-experiment/json"
"github.com/go-json-experiment/json/jsontext"
"github.com/go-kit/kit/endpoint"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/gorilla/mux"
)

type handlerFunc func(ctx context.Context, request interface{}, svc android.Service) fleet.Errorer

func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error {
if e, ok := response.(fleet.Errorer); ok && e.Error() != nil {
endpoint_utils.EncodeError(ctx, e.Error(), w)
return nil
}

if e, ok := response.(statuser); ok {
w.WriteHeader(e.Status())
if e.Status() == http.StatusNoContent {
return nil
}
}

return json.MarshalWrite(w, response, jsontext.WithIndent(" "))
return eu.EncodeCommonResponse(ctx, w, response,
func(w http.ResponseWriter, response interface{}) error {
return json.MarshalWrite(w, response, jsontext.WithIndent(" "))
},
)
}

// statuser allows response types to implement a custom
// http success status - default is 200 OK
type statuser interface {
Status() int
}

// makeDecoder creates a decoder for the type for the struct passed on. If the
// struct has at least 1 json tag it'll unmarshall the body. If the struct has
// a `url` tag with value list_options it'll gather fleet.ListOptions from the
// URL (similarly for host_options, carve_options, user_options that derive
// from the common list_options). Note that these behaviors do not work for embedded structs.
//
// Finally, any other `url` tag will be treated as a path variable (of the form
// /path/{name} in the route's path) from the URL path pattern, and it'll be
// decoded and set accordingly. Variables can be optional by setting the tag as
// follows: `url:"some-id,optional"`.
// The "list_options" are optional by default and it'll ignore the optional
// portion of the tag.
//
// If iface implements the requestDecoder interface, it returns a function that
// calls iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its
// own decoding.
//
// If iface implements the bodyDecoder interface, it calls iface.DecodeBody
// after having decoded any non-body fields (such as url and query parameters)
// into the struct.
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
if iface == nil {
return func(ctx context.Context, r *http.Request) (interface{}, error) {
return nil, nil
}
}

t := reflect.TypeOf(iface)
if t.Kind() != reflect.Struct {
panic(fmt.Sprintf("makeDecoder only understands structs, not %T", iface))
}

return func(ctx context.Context, r *http.Request) (interface{}, error) {
v := reflect.New(t)
nilBody := false
buf := bufio.NewReader(r.Body)
var body io.Reader = buf
if _, err := buf.Peek(1); err == io.EOF {
nilBody = true
} else {
if r.Header.Get("content-encoding") == "gzip" {
gzr, err := gzip.NewReader(buf)
if err != nil {
return nil, endpoint_utils.BadRequestErr("gzip decoder error", err)
}
defer gzr.Close()
body = gzr
}

req := v.Interface()
err = json.UnmarshalRead(body, req)
if err != nil {
return nil, endpoint_utils.BadRequestErr("json decoder error", err)
}
v = reflect.ValueOf(req)
}

fields := endpoint_utils.AllFields(v)
for _, fp := range fields {
field := fp.V

urlTagValue, ok := fp.Sf.Tag.Lookup("url")

var err error
if ok {
optional := false
urlTagValue, optional, err = endpoint_utils.ParseTag(urlTagValue)
if err != nil {
return nil, err
}
err = endpoint_utils.DecodeURLTagValue(r, field, urlTagValue, optional)
if err != nil {
return nil, err
}
continue
}

_, jsonExpected := fp.Sf.Tag.Lookup("json")
if jsonExpected && nilBody {
return nil, badRequest("Expected JSON Body")
}
return eu.MakeDecoder(iface, func(body io.Reader, req any) error {
return json.UnmarshalRead(body, req)
}, nil, nil, nil)
}

err = endpoint_utils.DecodeQueryTagValue(r, fp)
if err != nil {
return nil, err
}
}
// Compile-time check to ensure that endpointer implements Endpointer.
var _ eu.Endpointer[eu.AndroidFunc] = &endpointer{}

return v.Interface(), nil
}
type endpointer struct {
svc android.Service
}

func badRequest(msg string) error {
return &fleet.BadRequestError{Message: msg}
func (e *endpointer) CallHandlerFunc(f eu.AndroidFunc, ctx context.Context, request interface{},
svc interface{}) (fleet.Errorer, error) {
return f(ctx, request, svc.(android.Service)), nil
}

type authEndpointer struct {
fleetSvc fleet.Service
svc android.Service
opts []kithttp.ServerOption
r *mux.Router
authFunc func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint
versions []string
customMiddleware []endpoint.Middleware
func (e *endpointer) Service() interface{} {
return e.svc
}

func newUserAuthenticatedEndpointer(fleetSvc fleet.Service, svc android.Service, opts []kithttp.ServerOption, r *mux.Router,
versions ...string) *authEndpointer {
return &authEndpointer{
fleetSvc: fleetSvc,
svc: svc,
opts: opts,
r: r,
authFunc: auth.AuthenticatedUser,
versions: versions,
versions ...string) *eu.CommonEndpointer[eu.AndroidFunc] {
return &eu.CommonEndpointer[eu.AndroidFunc]{
EP: &endpointer{
svc: svc,
},
MakeDecoderFn: makeDecoder,
EncodeFn: encodeResponse,
Opts: opts,
AuthFunc: auth.AuthenticatedUser,
FleetService: fleetSvc,
Router: r,
Versions: versions,
}
}

func newNoAuthEndpointer(svc android.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
return &authEndpointer{
fleetSvc: nil,
svc: svc,
opts: opts,
r: r,
authFunc: auth.UnauthenticatedRequest,
versions: versions,
func newNoAuthEndpointer(fleetSvc fleet.Service, svc android.Service, opts []kithttp.ServerOption, r *mux.Router,
versions ...string) *eu.CommonEndpointer[eu.AndroidFunc] {
return &eu.CommonEndpointer[eu.AndroidFunc]{
EP: &endpointer{
svc: svc,
},
MakeDecoderFn: makeDecoder,
EncodeFn: encodeResponse,
Opts: opts,
AuthFunc: auth.UnauthenticatedRequest,
FleetService: fleetSvc,
Router: r,
Versions: versions,
}
}

var pathReplacer = strings.NewReplacer(
"/", "_",
"{", "_",
"}", "_",
)

func getNameFromPathAndVerb(verb, path string) string {
prefix := strings.ToLower(verb) + "_"
return prefix + pathReplacer.Replace(strings.TrimPrefix(strings.TrimRight(path, "/"), "/api/_version_/fleet/"))
}

func (e *authEndpointer) POST(path string, f handlerFunc, v interface{}) {
e.handleEndpoint(path, f, v, "POST")
}

func (e *authEndpointer) GET(path string, f handlerFunc, v interface{}) {
e.handleEndpoint(path, f, v, "GET")
}

func (e *authEndpointer) PUT(path string, f handlerFunc, v interface{}) {
e.handleEndpoint(path, f, v, "PUT")
}

func (e *authEndpointer) PATCH(path string, f handlerFunc, v interface{}) {
e.handleEndpoint(path, f, v, "PATCH")
}

func (e *authEndpointer) DELETE(path string, f handlerFunc, v interface{}) {
e.handleEndpoint(path, f, v, "DELETE")
}

func (e *authEndpointer) HEAD(path string, f handlerFunc, v interface{}) {
e.handleEndpoint(path, f, v, "HEAD")
}

func (e *authEndpointer) handlePathHandler(path string, pathHandler func(path string) http.Handler, verb string) {
versions := e.versions
versionedPath := strings.Replace(path, "/_version_/", fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(versions, "|")), 1)
nameAndVerb := getNameFromPathAndVerb(verb, path)
e.r.Handle(versionedPath, pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb)
}

func (e *authEndpointer) handleHTTPHandler(path string, h http.Handler, verb string) {
self := func(_ string) http.Handler { return h }
e.handlePathHandler(path, self, verb)
}

func (e *authEndpointer) handleEndpoint(path string, f handlerFunc, v interface{}, verb string) {
e.handleHTTPHandler(path, e.makeEndpoint(f, v), verb)
}

func (e *authEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler {
next := func(ctx context.Context, request interface{}) (interface{}, error) {
return f(ctx, request, e.svc), nil
}
endPt := e.authFunc(e.fleetSvc, next)

// apply middleware in reverse order so that the first wraps the second
// wraps the third etc.
for i := len(e.customMiddleware) - 1; i >= 0; i-- {
mw := e.customMiddleware[i]
endPt = mw(endPt)
}

return newServer(endPt, makeDecoder(v), e.opts)
}
17 changes: 3 additions & 14 deletions server/mdm/android/service/handler.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
package service

import (
"net/http"

"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm/android"
"github.com/fleetdm/fleet/v4/server/service/middleware/authzcheck"
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
"github.com/go-kit/kit/endpoint"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/gorilla/mux"
)
Expand All @@ -29,11 +25,9 @@ func attachFleetAPIRoutes(r *mux.Router, fleetSvc fleet.Service, svc android.Ser
ue.GET("/api/_version_/fleet/android_enterprise/{id:[0-9]+}/enrollment_token", androidEnrollmentTokenEndpoint,
androidEnrollmentTokenRequest{})

// unauthenticated endpoints - most of those are either login-related,
// invite-related or host-enrolling. So they typically do some kind of
// one-time authentication by verifying that a valid secret token is provided
// with the request.
ne := newNoAuthEndpointer(svc, opts, r, apiVersions()...)
// unauthenticated endpoints
// They typically do one-time authentication by verifying that a valid secret token is provided with the request.
ne := newNoAuthEndpointer(fleetSvc, svc, opts, r, apiVersions()...)

// Android management
ne.GET("/api/_version_/fleet/android_enterprise/{id:[0-9]+}/connect", androidEnterpriseSignupCallbackEndpoint,
Expand All @@ -44,8 +38,3 @@ func attachFleetAPIRoutes(r *mux.Router, fleetSvc fleet.Service, svc android.Ser
func apiVersions() []string {
return []string{"v1"}
}

func newServer(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc, opts []kithttp.ServerOption) http.Handler {
e = authzcheck.NewMiddleware().AuthzCheck()(e)
return kithttp.NewServer(e, decodeFn, encodeResponse, opts...)
}
14 changes: 7 additions & 7 deletions server/service/apple_mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ type getMDMAppleConfigProfileResponse struct {

func (r getMDMAppleConfigProfileResponse) Error() error { return r.Err }

func (r getMDMAppleConfigProfileResponse) hijackRender(ctx context.Context, w http.ResponseWriter) {
func (r getMDMAppleConfigProfileResponse) HijackRender(ctx context.Context, w http.ResponseWriter) {
w.Header().Set("Content-Length", strconv.FormatInt(r.fileLength, 10))
w.Header().Set("Content-Type", "application/x-apple-aspen-config")
w.Header().Set("X-Content-Type-Options", "nosniff")
Expand Down Expand Up @@ -1389,13 +1389,13 @@ func (r mdmAppleEnrollResponse) Error() error { return r.Err }
type mdmAppleEnrollResponse struct {
Err error `json:"error,omitempty"`

// Profile field is used in hijackRender for the response.
// Profile field is used in HijackRender for the response.
Profile []byte

SoftwareUpdateRequired *fleet.MDMAppleSoftwareUpdateRequired
}

func (r mdmAppleEnrollResponse) hijackRender(ctx context.Context, w http.ResponseWriter) {
func (r mdmAppleEnrollResponse) HijackRender(ctx context.Context, w http.ResponseWriter) {
if r.SoftwareUpdateRequired != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
Expand Down Expand Up @@ -1718,7 +1718,7 @@ type mdmAppleGetInstallerResponse struct {
installer []byte
}

func (r mdmAppleGetInstallerResponse) hijackRender(ctx context.Context, w http.ResponseWriter) {
func (r mdmAppleGetInstallerResponse) HijackRender(ctx context.Context, w http.ResponseWriter) {
w.Header().Set("Content-Length", strconv.FormatInt(r.size, 10))
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment;filename="%s"`, r.name))
Expand Down Expand Up @@ -2322,7 +2322,7 @@ type downloadBootstrapPackageResponse struct {

func (r downloadBootstrapPackageResponse) Error() error { return r.Err }

func (r downloadBootstrapPackageResponse) hijackRender(ctx context.Context, w http.ResponseWriter) {
func (r downloadBootstrapPackageResponse) HijackRender(ctx context.Context, w http.ResponseWriter) {
w.Header().Set("Content-Length", strconv.Itoa(len(r.pkg.Bytes)))
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment;filename="%s"`, r.pkg.Name))
Expand Down Expand Up @@ -2659,7 +2659,7 @@ type callbackMDMAppleSSOResponse struct {
redirectURL string
}

func (r callbackMDMAppleSSOResponse) hijackRender(ctx context.Context, w http.ResponseWriter) {
func (r callbackMDMAppleSSOResponse) HijackRender(ctx context.Context, w http.ResponseWriter) {
w.Header().Set("Location", r.redirectURL)
w.WriteHeader(http.StatusSeeOther)
}
Expand Down Expand Up @@ -4832,7 +4832,7 @@ type mdmAppleOTAResponse struct {

func (r mdmAppleOTAResponse) Error() error { return r.Err }

func (r mdmAppleOTAResponse) hijackRender(ctx context.Context, w http.ResponseWriter) {
func (r mdmAppleOTAResponse) HijackRender(ctx context.Context, w http.ResponseWriter) {
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(r.xml)))
w.Header().Set("Content-Type", "application/x-apple-aspen-config")
w.Header().Set("X-Content-Type-Options", "nosniff")
Expand Down
Loading

0 comments on commit 386ce37

Please sign in to comment.