From 386ce37168f8fea8bd662052f05f805202d871b1 Mon Sep 17 00:00:00 2001 From: Victor Lyuboslavsky Date: Tue, 18 Feb 2025 11:09:43 -0600 Subject: [PATCH] Refactoring endpoint_utils (#26342) For #26218 Refactoring service/android endpoint_utils to remove duplication. No functional changes. - [x] Manual QA for all new/changed functionality --- server/mdm/android/service/endpoint_utils.go | 254 ++------- server/mdm/android/service/handler.go | 17 +- server/service/apple_mdm.go | 14 +- server/service/devices.go | 10 +- server/service/endpoint_utils.go | 517 +++++------------- server/service/endpoint_utils_test.go | 2 +- server/service/frontend.go | 5 +- server/service/handler.go | 14 +- server/service/hosts.go | 4 +- server/service/installer.go | 2 +- server/service/mdm.go | 2 +- server/service/microsoft_mdm.go | 16 +- .../endpoint_utils/endpoint_utils.go | 421 ++++++++++++++ server/service/orbit.go | 6 +- server/service/osquery.go | 2 +- server/service/scripts.go | 2 +- server/service/sessions.go | 4 +- server/service/software_installers.go | 2 +- server/service/transport.go | 51 +- 19 files changed, 639 insertions(+), 706 deletions(-) diff --git a/server/mdm/android/service/endpoint_utils.go b/server/mdm/android/service/endpoint_utils.go index 0e0cd7cd333f..a4644388fb6c 100644 --- a/server/mdm/android/service/endpoint_utils.go +++ b/server/mdm/android/service/endpoint_utils.go @@ -1,244 +1,78 @@ package service -// TODO(26218): Refactor this to remove duplication. - import ( - "bufio" - "compress/gzip" "context" - "fmt" "io" "net/http" - "reflect" - "strings" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mdm/android" "github.com/fleetdm/fleet/v4/server/service/middleware/auth" - "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" + eu "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" "github.com/go-json-experiment/json" "github.com/go-json-experiment/json/jsontext" - "github.com/go-kit/kit/endpoint" kithttp "github.com/go-kit/kit/transport/http" "github.com/gorilla/mux" ) -type handlerFunc func(ctx context.Context, request interface{}, svc android.Service) fleet.Errorer - func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error { - if e, ok := response.(fleet.Errorer); ok && e.Error() != nil { - endpoint_utils.EncodeError(ctx, e.Error(), w) - return nil - } - - if e, ok := response.(statuser); ok { - w.WriteHeader(e.Status()) - if e.Status() == http.StatusNoContent { - return nil - } - } - - return json.MarshalWrite(w, response, jsontext.WithIndent(" ")) + return eu.EncodeCommonResponse(ctx, w, response, + func(w http.ResponseWriter, response interface{}) error { + return json.MarshalWrite(w, response, jsontext.WithIndent(" ")) + }, + ) } -// statuser allows response types to implement a custom -// http success status - default is 200 OK -type statuser interface { - Status() int -} - -// makeDecoder creates a decoder for the type for the struct passed on. If the -// struct has at least 1 json tag it'll unmarshall the body. If the struct has -// a `url` tag with value list_options it'll gather fleet.ListOptions from the -// URL (similarly for host_options, carve_options, user_options that derive -// from the common list_options). Note that these behaviors do not work for embedded structs. -// -// Finally, any other `url` tag will be treated as a path variable (of the form -// /path/{name} in the route's path) from the URL path pattern, and it'll be -// decoded and set accordingly. Variables can be optional by setting the tag as -// follows: `url:"some-id,optional"`. -// The "list_options" are optional by default and it'll ignore the optional -// portion of the tag. -// -// If iface implements the requestDecoder interface, it returns a function that -// calls iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its -// own decoding. -// -// If iface implements the bodyDecoder interface, it calls iface.DecodeBody -// after having decoded any non-body fields (such as url and query parameters) -// into the struct. func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { - if iface == nil { - return func(ctx context.Context, r *http.Request) (interface{}, error) { - return nil, nil - } - } - - t := reflect.TypeOf(iface) - if t.Kind() != reflect.Struct { - panic(fmt.Sprintf("makeDecoder only understands structs, not %T", iface)) - } - - return func(ctx context.Context, r *http.Request) (interface{}, error) { - v := reflect.New(t) - nilBody := false - buf := bufio.NewReader(r.Body) - var body io.Reader = buf - if _, err := buf.Peek(1); err == io.EOF { - nilBody = true - } else { - if r.Header.Get("content-encoding") == "gzip" { - gzr, err := gzip.NewReader(buf) - if err != nil { - return nil, endpoint_utils.BadRequestErr("gzip decoder error", err) - } - defer gzr.Close() - body = gzr - } - - req := v.Interface() - err = json.UnmarshalRead(body, req) - if err != nil { - return nil, endpoint_utils.BadRequestErr("json decoder error", err) - } - v = reflect.ValueOf(req) - } - - fields := endpoint_utils.AllFields(v) - for _, fp := range fields { - field := fp.V - - urlTagValue, ok := fp.Sf.Tag.Lookup("url") - - var err error - if ok { - optional := false - urlTagValue, optional, err = endpoint_utils.ParseTag(urlTagValue) - if err != nil { - return nil, err - } - err = endpoint_utils.DecodeURLTagValue(r, field, urlTagValue, optional) - if err != nil { - return nil, err - } - continue - } - - _, jsonExpected := fp.Sf.Tag.Lookup("json") - if jsonExpected && nilBody { - return nil, badRequest("Expected JSON Body") - } + return eu.MakeDecoder(iface, func(body io.Reader, req any) error { + return json.UnmarshalRead(body, req) + }, nil, nil, nil) +} - err = endpoint_utils.DecodeQueryTagValue(r, fp) - if err != nil { - return nil, err - } - } +// Compile-time check to ensure that endpointer implements Endpointer. +var _ eu.Endpointer[eu.AndroidFunc] = &endpointer{} - return v.Interface(), nil - } +type endpointer struct { + svc android.Service } -func badRequest(msg string) error { - return &fleet.BadRequestError{Message: msg} +func (e *endpointer) CallHandlerFunc(f eu.AndroidFunc, ctx context.Context, request interface{}, + svc interface{}) (fleet.Errorer, error) { + return f(ctx, request, svc.(android.Service)), nil } -type authEndpointer struct { - fleetSvc fleet.Service - svc android.Service - opts []kithttp.ServerOption - r *mux.Router - authFunc func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint - versions []string - customMiddleware []endpoint.Middleware +func (e *endpointer) Service() interface{} { + return e.svc } func newUserAuthenticatedEndpointer(fleetSvc fleet.Service, svc android.Service, opts []kithttp.ServerOption, r *mux.Router, - versions ...string) *authEndpointer { - return &authEndpointer{ - fleetSvc: fleetSvc, - svc: svc, - opts: opts, - r: r, - authFunc: auth.AuthenticatedUser, - versions: versions, + versions ...string) *eu.CommonEndpointer[eu.AndroidFunc] { + return &eu.CommonEndpointer[eu.AndroidFunc]{ + EP: &endpointer{ + svc: svc, + }, + MakeDecoderFn: makeDecoder, + EncodeFn: encodeResponse, + Opts: opts, + AuthFunc: auth.AuthenticatedUser, + FleetService: fleetSvc, + Router: r, + Versions: versions, } } -func newNoAuthEndpointer(svc android.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { - return &authEndpointer{ - fleetSvc: nil, - svc: svc, - opts: opts, - r: r, - authFunc: auth.UnauthenticatedRequest, - versions: versions, +func newNoAuthEndpointer(fleetSvc fleet.Service, svc android.Service, opts []kithttp.ServerOption, r *mux.Router, + versions ...string) *eu.CommonEndpointer[eu.AndroidFunc] { + return &eu.CommonEndpointer[eu.AndroidFunc]{ + EP: &endpointer{ + svc: svc, + }, + MakeDecoderFn: makeDecoder, + EncodeFn: encodeResponse, + Opts: opts, + AuthFunc: auth.UnauthenticatedRequest, + FleetService: fleetSvc, + Router: r, + Versions: versions, } } - -var pathReplacer = strings.NewReplacer( - "/", "_", - "{", "_", - "}", "_", -) - -func getNameFromPathAndVerb(verb, path string) string { - prefix := strings.ToLower(verb) + "_" - return prefix + pathReplacer.Replace(strings.TrimPrefix(strings.TrimRight(path, "/"), "/api/_version_/fleet/")) -} - -func (e *authEndpointer) POST(path string, f handlerFunc, v interface{}) { - e.handleEndpoint(path, f, v, "POST") -} - -func (e *authEndpointer) GET(path string, f handlerFunc, v interface{}) { - e.handleEndpoint(path, f, v, "GET") -} - -func (e *authEndpointer) PUT(path string, f handlerFunc, v interface{}) { - e.handleEndpoint(path, f, v, "PUT") -} - -func (e *authEndpointer) PATCH(path string, f handlerFunc, v interface{}) { - e.handleEndpoint(path, f, v, "PATCH") -} - -func (e *authEndpointer) DELETE(path string, f handlerFunc, v interface{}) { - e.handleEndpoint(path, f, v, "DELETE") -} - -func (e *authEndpointer) HEAD(path string, f handlerFunc, v interface{}) { - e.handleEndpoint(path, f, v, "HEAD") -} - -func (e *authEndpointer) handlePathHandler(path string, pathHandler func(path string) http.Handler, verb string) { - versions := e.versions - versionedPath := strings.Replace(path, "/_version_/", fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(versions, "|")), 1) - nameAndVerb := getNameFromPathAndVerb(verb, path) - e.r.Handle(versionedPath, pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb) -} - -func (e *authEndpointer) handleHTTPHandler(path string, h http.Handler, verb string) { - self := func(_ string) http.Handler { return h } - e.handlePathHandler(path, self, verb) -} - -func (e *authEndpointer) handleEndpoint(path string, f handlerFunc, v interface{}, verb string) { - e.handleHTTPHandler(path, e.makeEndpoint(f, v), verb) -} - -func (e *authEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler { - next := func(ctx context.Context, request interface{}) (interface{}, error) { - return f(ctx, request, e.svc), nil - } - endPt := e.authFunc(e.fleetSvc, next) - - // apply middleware in reverse order so that the first wraps the second - // wraps the third etc. - for i := len(e.customMiddleware) - 1; i >= 0; i-- { - mw := e.customMiddleware[i] - endPt = mw(endPt) - } - - return newServer(endPt, makeDecoder(v), e.opts) -} diff --git a/server/mdm/android/service/handler.go b/server/mdm/android/service/handler.go index 66d2b5f9692c..37c15ccb7379 100644 --- a/server/mdm/android/service/handler.go +++ b/server/mdm/android/service/handler.go @@ -1,13 +1,9 @@ package service import ( - "net/http" - "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mdm/android" - "github.com/fleetdm/fleet/v4/server/service/middleware/authzcheck" "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" - "github.com/go-kit/kit/endpoint" kithttp "github.com/go-kit/kit/transport/http" "github.com/gorilla/mux" ) @@ -29,11 +25,9 @@ func attachFleetAPIRoutes(r *mux.Router, fleetSvc fleet.Service, svc android.Ser ue.GET("/api/_version_/fleet/android_enterprise/{id:[0-9]+}/enrollment_token", androidEnrollmentTokenEndpoint, androidEnrollmentTokenRequest{}) - // unauthenticated endpoints - most of those are either login-related, - // invite-related or host-enrolling. So they typically do some kind of - // one-time authentication by verifying that a valid secret token is provided - // with the request. - ne := newNoAuthEndpointer(svc, opts, r, apiVersions()...) + // unauthenticated endpoints + // They typically do one-time authentication by verifying that a valid secret token is provided with the request. + ne := newNoAuthEndpointer(fleetSvc, svc, opts, r, apiVersions()...) // Android management ne.GET("/api/_version_/fleet/android_enterprise/{id:[0-9]+}/connect", androidEnterpriseSignupCallbackEndpoint, @@ -44,8 +38,3 @@ func attachFleetAPIRoutes(r *mux.Router, fleetSvc fleet.Service, svc android.Ser func apiVersions() []string { return []string{"v1"} } - -func newServer(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc, opts []kithttp.ServerOption) http.Handler { - e = authzcheck.NewMiddleware().AuthzCheck()(e) - return kithttp.NewServer(e, decodeFn, encodeResponse, opts...) -} diff --git a/server/service/apple_mdm.go b/server/service/apple_mdm.go index 2d830e60083b..72e92216d8f9 100644 --- a/server/service/apple_mdm.go +++ b/server/service/apple_mdm.go @@ -696,7 +696,7 @@ type getMDMAppleConfigProfileResponse struct { func (r getMDMAppleConfigProfileResponse) Error() error { return r.Err } -func (r getMDMAppleConfigProfileResponse) hijackRender(ctx context.Context, w http.ResponseWriter) { +func (r getMDMAppleConfigProfileResponse) HijackRender(ctx context.Context, w http.ResponseWriter) { w.Header().Set("Content-Length", strconv.FormatInt(r.fileLength, 10)) w.Header().Set("Content-Type", "application/x-apple-aspen-config") w.Header().Set("X-Content-Type-Options", "nosniff") @@ -1389,13 +1389,13 @@ func (r mdmAppleEnrollResponse) Error() error { return r.Err } type mdmAppleEnrollResponse struct { Err error `json:"error,omitempty"` - // Profile field is used in hijackRender for the response. + // Profile field is used in HijackRender for the response. Profile []byte SoftwareUpdateRequired *fleet.MDMAppleSoftwareUpdateRequired } -func (r mdmAppleEnrollResponse) hijackRender(ctx context.Context, w http.ResponseWriter) { +func (r mdmAppleEnrollResponse) HijackRender(ctx context.Context, w http.ResponseWriter) { if r.SoftwareUpdateRequired != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusForbidden) @@ -1718,7 +1718,7 @@ type mdmAppleGetInstallerResponse struct { installer []byte } -func (r mdmAppleGetInstallerResponse) hijackRender(ctx context.Context, w http.ResponseWriter) { +func (r mdmAppleGetInstallerResponse) HijackRender(ctx context.Context, w http.ResponseWriter) { w.Header().Set("Content-Length", strconv.FormatInt(r.size, 10)) w.Header().Set("Content-Type", "application/octet-stream") w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment;filename="%s"`, r.name)) @@ -2322,7 +2322,7 @@ type downloadBootstrapPackageResponse struct { func (r downloadBootstrapPackageResponse) Error() error { return r.Err } -func (r downloadBootstrapPackageResponse) hijackRender(ctx context.Context, w http.ResponseWriter) { +func (r downloadBootstrapPackageResponse) HijackRender(ctx context.Context, w http.ResponseWriter) { w.Header().Set("Content-Length", strconv.Itoa(len(r.pkg.Bytes))) w.Header().Set("Content-Type", "application/octet-stream") w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment;filename="%s"`, r.pkg.Name)) @@ -2659,7 +2659,7 @@ type callbackMDMAppleSSOResponse struct { redirectURL string } -func (r callbackMDMAppleSSOResponse) hijackRender(ctx context.Context, w http.ResponseWriter) { +func (r callbackMDMAppleSSOResponse) HijackRender(ctx context.Context, w http.ResponseWriter) { w.Header().Set("Location", r.redirectURL) w.WriteHeader(http.StatusSeeOther) } @@ -4832,7 +4832,7 @@ type mdmAppleOTAResponse struct { func (r mdmAppleOTAResponse) Error() error { return r.Err } -func (r mdmAppleOTAResponse) hijackRender(ctx context.Context, w http.ResponseWriter) { +func (r mdmAppleOTAResponse) HijackRender(ctx context.Context, w http.ResponseWriter) { w.Header().Set("Content-Length", fmt.Sprintf("%d", len(r.xml))) w.Header().Set("Content-Type", "application/x-apple-aspen-config") w.Header().Set("X-Content-Type-Options", "nosniff") diff --git a/server/service/devices.go b/server/service/devices.go index 990dc6718f30..2de1136a7762 100644 --- a/server/service/devices.go +++ b/server/service/devices.go @@ -41,7 +41,7 @@ type devicePingResponse struct{} func (r devicePingResponse) Error() error { return nil } -func (r devicePingResponse) hijackRender(ctx context.Context, w http.ResponseWriter) { +func (r devicePingResponse) HijackRender(ctx context.Context, w http.ResponseWriter) { writeCapabilitiesHeader(w, fleet.GetServerDeviceCapabilities()) } @@ -379,11 +379,11 @@ func (r *transparencyURLRequest) deviceAuthToken() string { } type transparencyURLResponse struct { - RedirectURL string `json:"-"` // used to control the redirect, see hijackRender method + RedirectURL string `json:"-"` // used to control the redirect, see HijackRender method Err error `json:"error,omitempty"` } -func (r transparencyURLResponse) hijackRender(ctx context.Context, w http.ResponseWriter) { +func (r transparencyURLResponse) HijackRender(ctx context.Context, w http.ResponseWriter) { w.Header().Set("Location", r.RedirectURL) w.WriteHeader(http.StatusTemporaryRedirect) } @@ -489,13 +489,13 @@ func (r *getDeviceMDMManualEnrollProfileRequest) deviceAuthToken() string { } type getDeviceMDMManualEnrollProfileResponse struct { - // Profile field is used in hijackRender for the response. + // Profile field is used in HijackRender for the response. Profile []byte Err error `json:"error,omitempty"` } -func (r getDeviceMDMManualEnrollProfileResponse) hijackRender(ctx context.Context, w http.ResponseWriter) { +func (r getDeviceMDMManualEnrollProfileResponse) HijackRender(ctx context.Context, w http.ResponseWriter) { // make the browser download the content to a file w.Header().Add("Content-Disposition", `attachment; filename="fleet-mdm-enrollment-profile.mobileconfig"`) // explicitly set the content length before the write, so the caller can diff --git a/server/service/endpoint_utils.go b/server/service/endpoint_utils.go index 5c00105bc79c..690b44b68e4e 100644 --- a/server/service/endpoint_utils.go +++ b/server/service/endpoint_utils.go @@ -1,207 +1,137 @@ package service import ( - "bufio" - "compress/gzip" "context" "crypto/x509" "encoding/json" - "fmt" "io" "net/http" "net/url" "reflect" - "strconv" - "strings" "github.com/fleetdm/fleet/v4/server/contexts/capabilities" - "github.com/fleetdm/fleet/v4/server/contexts/license" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/service/middleware/auth" - "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" + eu "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" "github.com/go-kit/kit/endpoint" kithttp "github.com/go-kit/kit/transport/http" "github.com/go-kit/log" "github.com/gorilla/mux" ) -type handlerFunc func(ctx context.Context, request interface{}, svc fleet.Service) (fleet.Errorer, error) - -// A value that implements requestDecoder takes control of decoding the request -// as a whole - that is, it is responsible for decoding the body and any url -// or query argument itself. -type requestDecoder interface { - DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error) +func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { + return eu.MakeDecoder(iface, jsonDecode, parseCustomTags, isBodyDecoder, decodeBody) } -// A value that implements bodyDecoder takes control of decoding the request -// body. +// A value that implements bodyDecoder takes control of decoding the request body. type bodyDecoder interface { DecodeBody(ctx context.Context, r io.Reader, u url.Values, c []*x509.Certificate) error } -// makeDecoder creates a decoder for the type for the struct passed on. If the -// struct has at least 1 json tag it'll unmarshall the body. If the struct has -// a `url` tag with value list_options it'll gather fleet.ListOptions from the -// URL (similarly for host_options, carve_options, user_options that derive -// from the common list_options). Note that these behaviors do not work for embedded structs. -// -// Finally, any other `url` tag will be treated as a path variable (of the form -// /path/{name} in the route's path) from the URL path pattern, and it'll be -// decoded and set accordingly. Variables can be optional by setting the tag as -// follows: `url:"some-id,optional"`. -// The "list_options" are optional by default and it'll ignore the optional -// portion of the tag. -// -// If iface implements the requestDecoder interface, it returns a function that -// calls iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its -// own decoding. -// -// If iface implements the bodyDecoder interface, it calls iface.DecodeBody -// after having decoded any non-body fields (such as url and query parameters) -// into the struct. -func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { - if iface == nil { - return func(ctx context.Context, r *http.Request) (interface{}, error) { - return nil, nil - } - } - if rd, ok := iface.(requestDecoder); ok { - return func(ctx context.Context, r *http.Request) (interface{}, error) { - return rd.DecodeRequest(ctx, r) - } +func decodeBody(ctx context.Context, r *http.Request, v reflect.Value, body io.Reader) error { + bd := v.Interface().(bodyDecoder) + var certs []*x509.Certificate + if (r.TLS != nil) && (r.TLS.PeerCertificates != nil) { + certs = r.TLS.PeerCertificates } - t := reflect.TypeOf(iface) - if t.Kind() != reflect.Struct { - panic(fmt.Sprintf("makeDecoder only understands structs, not %T", iface)) + if err := bd.DecodeBody(ctx, body, r.URL.Query(), certs); err != nil { + return err } + return nil +} - return func(ctx context.Context, r *http.Request) (interface{}, error) { - v := reflect.New(t) - nilBody := false - - var isBodyDecoder bool - if _, ok := v.Interface().(bodyDecoder); ok { - isBodyDecoder = true +func parseCustomTags(urlTagValue string, r *http.Request, field reflect.Value) (bool, error) { + switch urlTagValue { + case "list_options": + opts, err := listOptionsFromRequest(r) + if err != nil { + return false, err } + field.Set(reflect.ValueOf(opts)) + return true, nil - buf := bufio.NewReader(r.Body) - var body io.Reader = buf - if _, err := buf.Peek(1); err == io.EOF { - nilBody = true - } else { - if r.Header.Get("content-encoding") == "gzip" { - gzr, err := gzip.NewReader(buf) - if err != nil { - return nil, endpoint_utils.BadRequestErr("gzip decoder error", err) - } - defer gzr.Close() - body = gzr - } - - if !isBodyDecoder { - req := v.Interface() - if err := json.NewDecoder(body).Decode(req); err != nil { - return nil, endpoint_utils.BadRequestErr("json decoder error", err) - } - v = reflect.ValueOf(req) - } + case "user_options": + opts, err := userListOptionsFromRequest(r) + if err != nil { + return false, err } + field.Set(reflect.ValueOf(opts)) + return true, nil - fields := endpoint_utils.AllFields(v) - for _, fp := range fields { - field := fp.V - - urlTagValue, ok := fp.Sf.Tag.Lookup("url") - - var err error - if ok { - optional := false - urlTagValue, optional, err = endpoint_utils.ParseTag(urlTagValue) - if err != nil { - return nil, err - } - switch urlTagValue { - case "list_options": - opts, err := listOptionsFromRequest(r) - if err != nil { - return nil, err - } - field.Set(reflect.ValueOf(opts)) - - case "user_options": - opts, err := userListOptionsFromRequest(r) - if err != nil { - return nil, err - } - field.Set(reflect.ValueOf(opts)) + case "host_options": + opts, err := hostListOptionsFromRequest(r) + if err != nil { + return false, err + } + field.Set(reflect.ValueOf(opts)) + return true, nil - case "host_options": - opts, err := hostListOptionsFromRequest(r) - if err != nil { - return nil, err - } - field.Set(reflect.ValueOf(opts)) + case "carve_options": + opts, err := carveListOptionsFromRequest(r) + if err != nil { + return false, err + } + field.Set(reflect.ValueOf(opts)) + return true, nil + } + return false, nil +} - case "carve_options": - opts, err := carveListOptionsFromRequest(r) - if err != nil { - return nil, err - } - field.Set(reflect.ValueOf(opts)) +func jsonDecode(body io.Reader, req any) error { + return json.NewDecoder(body).Decode(req) +} - default: - err := endpoint_utils.DecodeURLTagValue(r, field, urlTagValue, optional) - if err != nil { - return nil, err - } - continue - } - } +func isBodyDecoder(v reflect.Value) bool { + _, ok := v.Interface().(bodyDecoder) + return ok +} - _, jsonExpected := fp.Sf.Tag.Lookup("json") - if jsonExpected && nilBody { - return nil, badRequest("Expected JSON Body") - } +// Compile-time check to ensure that endpointer implements Endpointer. +var _ eu.Endpointer[eu.HandlerFunc] = &endpointer{} - err = endpoint_utils.DecodeQueryTagValue(r, fp) - if err != nil { - return nil, err - } - } +type endpointer struct { + svc fleet.Service +} - if isBodyDecoder { - bd := v.Interface().(bodyDecoder) - var certs []*x509.Certificate - if (r.TLS != nil) && (r.TLS.PeerCertificates != nil) { - certs = r.TLS.PeerCertificates - } +func (e *endpointer) CallHandlerFunc(f eu.HandlerFunc, ctx context.Context, request interface{}, + svc interface{}) (fleet.Errorer, error) { + return f(ctx, request, svc.(fleet.Service)) +} - if err := bd.DecodeBody(ctx, body, r.URL.Query(), certs); err != nil { - return nil, err - } - } +func (e *endpointer) Service() interface{} { + return e.svc +} - if !license.IsPremium(ctx) { - for _, fp := range fields { - if prem, ok := fp.Sf.Tag.Lookup("premium"); ok { - val, err := strconv.ParseBool(prem) - if err != nil { - return nil, err - } - if val && !fp.V.IsZero() { - return nil, &fleet.BadRequestError{Message: fmt.Sprintf( - "option %s requires a premium license", - fp.Sf.Name, - )} - } - continue - } - } - } +func newUserAuthenticatedEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, + versions ...string) *eu.CommonEndpointer[eu.HandlerFunc] { + return &eu.CommonEndpointer[eu.HandlerFunc]{ + EP: &endpointer{ + svc: svc, + }, + MakeDecoderFn: makeDecoder, + EncodeFn: encodeResponse, + Opts: opts, + AuthFunc: auth.AuthenticatedUser, + FleetService: svc, + Router: r, + Versions: versions, + } +} - return v.Interface(), nil +func newNoAuthEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, + versions ...string) *eu.CommonEndpointer[eu.HandlerFunc] { + return &eu.CommonEndpointer[eu.HandlerFunc]{ + EP: &endpointer{ + svc: svc, + }, + MakeDecoderFn: makeDecoder, + EncodeFn: encodeResponse, + Opts: opts, + AuthFunc: auth.UnauthenticatedRequest, + FleetService: svc, + Router: r, + Versions: versions, } } @@ -209,20 +139,8 @@ func badRequest(msg string) error { return &fleet.BadRequestError{Message: msg} } -type authEndpointer struct { - svc fleet.Service - opts []kithttp.ServerOption - r *mux.Router - authFunc func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint - versions []string - startingAtVersion string - endingAtVersion string - alternativePaths []string - customMiddleware []endpoint.Middleware - usePathPrefix bool -} - -func newDeviceAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { +func newDeviceAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, + versions ...string) *eu.CommonEndpointer[eu.HandlerFunc] { authFunc := func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint { return authenticatedDevice(svc, logger, next) } @@ -232,39 +150,42 @@ func newDeviceAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts // Add the capabilities reported by the device to the request context opts = append(opts, capabilitiesContextFunc()) - return &authEndpointer{ - svc: svc, - opts: opts, - r: r, - authFunc: authFunc, - versions: versions, + return &eu.CommonEndpointer[eu.HandlerFunc]{ + EP: &endpointer{ + svc: svc, + }, + MakeDecoderFn: makeDecoder, + EncodeFn: encodeResponse, + Opts: opts, + AuthFunc: authFunc, + FleetService: svc, + Router: r, + Versions: versions, } -} -func newUserAuthenticatedEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { - return &authEndpointer{ - svc: svc, - opts: opts, - r: r, - authFunc: auth.AuthenticatedUser, - versions: versions, - } } -func newHostAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { +func newHostAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, + versions ...string) *eu.CommonEndpointer[eu.HandlerFunc] { authFunc := func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint { return authenticatedHost(svc, logger, next) } - return &authEndpointer{ - svc: svc, - opts: opts, - r: r, - authFunc: authFunc, - versions: versions, + return &eu.CommonEndpointer[eu.HandlerFunc]{ + EP: &endpointer{ + svc: svc, + }, + MakeDecoderFn: makeDecoder, + EncodeFn: encodeResponse, + Opts: opts, + AuthFunc: authFunc, + FleetService: svc, + Router: r, + Versions: versions, } } -func newOrbitAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { +func newOrbitAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, + versions ...string) *eu.CommonEndpointer[eu.HandlerFunc] { authFunc := func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint { return authenticatedOrbitHost(svc, logger, next) } @@ -274,39 +195,20 @@ func newOrbitAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts // Add the capabilities reported by Orbit to the request context opts = append(opts, capabilitiesContextFunc()) - return &authEndpointer{ - svc: svc, - opts: opts, - r: r, - authFunc: authFunc, - versions: versions, - } -} - -func newNoAuthEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { - return &authEndpointer{ - svc: svc, - opts: opts, - r: r, - authFunc: auth.UnauthenticatedRequest, - versions: versions, + return &eu.CommonEndpointer[eu.HandlerFunc]{ + EP: &endpointer{ + svc: svc, + }, + MakeDecoderFn: makeDecoder, + EncodeFn: encodeResponse, + Opts: opts, + AuthFunc: authFunc, + FleetService: svc, + Router: r, + Versions: versions, } } -var pathReplacer = strings.NewReplacer( - "/", "_", - "{", "_", - "}", "_", -) - -func getNameFromPathAndVerb(verb, path, startAt string) string { - prefix := strings.ToLower(verb) + "_" - if startAt != "" { - prefix += pathReplacer.Replace(startAt) + "_" - } - return prefix + pathReplacer.Replace(strings.TrimPrefix(strings.TrimRight(path, "/"), "/api/_version_/fleet/")) -} - func capabilitiesResponseFunc(capabilities fleet.CapabilityMap) kithttp.ServerOption { return kithttp.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { writeCapabilitiesHeader(w, capabilities) @@ -325,162 +227,3 @@ func writeCapabilitiesHeader(w http.ResponseWriter, capabilities fleet.Capabilit w.Header().Set(fleet.CapabilitiesHeader, capabilities.String()) } - -func writeBrowserSecurityHeaders(w http.ResponseWriter) { - // Strict-Transport-Security informs browsers that the site should only be - // accessed using HTTPS, and that any future attempts to access it using - // HTTP should automatically be converted to HTTPS. - w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains;") - // X-Frames-Options disallows embedding the UI in other sites via , - //