Skip to content

Commit

Permalink
Disallow Geo check creation/update without configured Geo DB (#1548)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcmmbaga authored Feb 8, 2024
1 parent 74d6918 commit b284c4d
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 10 deletions.
4 changes: 2 additions & 2 deletions management/server/http/geolocations_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ type GeolocationsHandler struct {
claimsExtractor *jwtclaims.ClaimsExtractor
}

// NewLocationsHandlerHandler creates a new Location handler
func NewLocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler {
// NewGeolocationsHandlerHandler creates a new Geolocations handler
func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler {
return &GeolocationsHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,
Expand Down
4 changes: 2 additions & 2 deletions management/server/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (apiHandler *apiHandler) addEventsEndpoint() {
}

func (apiHandler *apiHandler) addPostureCheckEndpoint() {
postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.GetAllPostureChecks).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.CreatePostureCheck).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.UpdatePostureCheck).Methods("PUT", "OPTIONS")
Expand All @@ -218,7 +218,7 @@ func (apiHandler *apiHandler) addPostureCheckEndpoint() {
func (apiHandler *apiHandler) addLocationsEndpoint() {
// enable location endpoints if location manager is enabled
if apiHandler.geolocationManager != nil {
locationHandler := NewLocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
locationHandler := NewGeolocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/locations/countries", locationHandler.GetAllCountries).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/locations/countries/{country}/cities", locationHandler.GetCitiesByCountry).Methods("GET", "OPTIONS")
}
Expand Down
15 changes: 11 additions & 4 deletions management/server/http/posture_checks_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/rs/xid"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
Expand All @@ -22,14 +23,16 @@ var (

// PostureChecksHandler is a handler that returns posture checks of the account.
type PostureChecksHandler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
geolocationManager *geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor
}

// NewPostureChecksHandler creates a new PostureChecks handler
func NewPostureChecksHandler(accountManager server.AccountManager, authCfg AuthCfg) *PostureChecksHandler {
func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler {
return &PostureChecksHandler{
accountManager: accountManager,
accountManager: accountManager,
geolocationManager: geolocationManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
Expand Down Expand Up @@ -201,6 +204,10 @@ func (p *PostureChecksHandler) savePostureChecks(
}

if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return
}
postureChecks.Checks = append(postureChecks.Checks, toPostureGeoLocationCheck(geoLocationCheck))
}

Expand Down
126 changes: 124 additions & 2 deletions management/server/http/posture_checks_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/stretchr/testify/assert"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
Expand Down Expand Up @@ -67,6 +68,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}, user, nil
},
},
geolocationManager: &geolocation.Geolocation{},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
Expand Down Expand Up @@ -208,6 +210,7 @@ func TestPostureCheckUpdate(t *testing.T) {
requestType string
requestPath string
requestBody io.Reader
setupHandlerFunc func(handler *PostureChecksHandler)
}{
{
name: "Create Posture Checks NB version",
Expand Down Expand Up @@ -236,6 +239,36 @@ func TestPostureCheckUpdate(t *testing.T) {
},
},
},
{
name: "Create Posture Checks NB version with No geolocation DB",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"description": "default",
"checks": {
"nb_version_check": {
"min_version": "1.2.3"
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str("default"),
Checks: api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: "1.2.3",
},
},
},
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Create Posture Checks OS version",
requestType: http.MethodPost,
Expand Down Expand Up @@ -318,6 +351,32 @@ func TestPostureCheckUpdate(t *testing.T) {
},
},
},
{
name: "Create Posture Checks Geo Location with No geolocation DB",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"description": "default",
"checks": {
"geo_location_check": {
"locations": [
{
"city_name": "Berlin",
"country_code": "DE"
}
],
"action": "allow"
}
}
}`)),
expectedStatus: http.StatusPreconditionFailed,
expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Create Posture Checks Invalid Check",
requestType: http.MethodPost,
Expand Down Expand Up @@ -433,6 +492,39 @@ func TestPostureCheckUpdate(t *testing.T) {
},
},
},
{
name: "Update Posture Checks OS Version with No geolocation DB",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/osPostureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"os_version_check": {
"linux": {
"min_kernel_version": "6.9.0"
}
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str(""),
Checks: api.Checks{
OsVersionCheck: &api.OSVersionCheck{
Linux: &api.MinKernelVersionCheck{
MinKernelVersion: "6.9.0",
},
},
},
},
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Update Posture Checks Geo Location",
requestType: http.MethodPut,
Expand Down Expand Up @@ -471,6 +563,31 @@ func TestPostureCheckUpdate(t *testing.T) {
},
},
},
{
name: "Update Posture Checks Geo Location with No geolocation DB",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/geoPostureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"geo_location_check": {
"locations": [
{
"city_name": "Los Angeles",
"country_code": "US"
}
],
"action": "allow"
}
}
}`)),
expectedStatus: http.StatusPreconditionFailed,
expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Update Posture Checks Invalid Check",
requestType: http.MethodPut,
Expand Down Expand Up @@ -560,9 +677,14 @@ func TestPostureCheckUpdate(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)

defaultHandler := *p
if tc.setupHandlerFunc != nil {
tc.setupHandlerFunc(&defaultHandler)
}

router := mux.NewRouter()
router.HandleFunc("/api/posture-checks", p.CreatePostureCheck).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.UpdatePostureCheck).Methods("PUT")
router.HandleFunc("/api/posture-checks", defaultHandler.CreatePostureCheck).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.UpdatePostureCheck).Methods("PUT")
router.ServeHTTP(recorder, req)

res := recorder.Result()
Expand Down

0 comments on commit b284c4d

Please sign in to comment.