From be2ad1089ea35d3dd08cf3b71c23a6e0d61bf687 Mon Sep 17 00:00:00 2001 From: Yury Gargay Date: Fri, 9 Feb 2024 17:45:50 +0100 Subject: [PATCH 1/4] Simplify checks definition serialisation --- management/server/account_test.go | 3 +- .../server/http/posture_checks_handler.go | 48 +++++----- .../http/posture_checks_handler_test.go | 24 ++--- management/server/policy.go | 2 +- management/server/policy_test.go | 6 +- management/server/posture/checks.go | 95 ++++--------------- management/server/posture/checks_test.go | 30 +++--- management/server/sqlite_store.go | 1 + 8 files changed, 76 insertions(+), 133 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index 2d144df72ba..6527644d520 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1562,8 +1562,7 @@ func TestAccount_Copy(t *testing.T) { DNSSettings: DNSSettings{DisabledManagementGroups: []string{}}, PostureChecks: []*posture.Checks{ { - ID: "posture Checks1", - Checks: make([]posture.Check, 0), + ID: "posture Checks1", }, }, Settings: &Settings{}, diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index bc50cfed7e1..1a9d4cda6f6 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -184,23 +184,22 @@ func (p *PostureChecksHandler) savePostureChecks( ID: postureChecksID, Name: req.Name, Description: req.Description, - Checks: make([]posture.Check, 0), } if nbVersionCheck := req.Checks.NbVersionCheck; nbVersionCheck != nil { - postureChecks.Checks = append(postureChecks.Checks, &posture.NBVersionCheck{ + postureChecks.Checks.NBVersionCheck = &posture.NBVersionCheck{ MinVersion: nbVersionCheck.MinVersion, - }) + } } if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil { - postureChecks.Checks = append(postureChecks.Checks, &posture.OSVersionCheck{ + postureChecks.Checks.OSVersionCheck = &posture.OSVersionCheck{ Android: (*posture.MinVersionCheck)(osVersionCheck.Android), Darwin: (*posture.MinVersionCheck)(osVersionCheck.Darwin), Ios: (*posture.MinVersionCheck)(osVersionCheck.Ios), Linux: (*posture.MinKernelVersionCheck)(osVersionCheck.Linux), Windows: (*posture.MinKernelVersionCheck)(osVersionCheck.Windows), - }) + } } if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil { @@ -208,7 +207,7 @@ func (p *PostureChecksHandler) savePostureChecks( util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) return } - postureChecks.Checks = append(postureChecks.Checks, toPostureGeoLocationCheck(geoLocationCheck)) + postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck) } if err := p.accountManager.SavePostureChecks(account.Id, user.Id, &postureChecks); err != nil { @@ -271,28 +270,27 @@ func validatePostureChecksUpdate(req api.PostureCheckUpdate) error { func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck { var checks api.Checks - for _, check := range postureChecks.Checks { - switch check.Name() { - case posture.NBVersionCheckName: - versionCheck := check.(*posture.NBVersionCheck) - checks.NbVersionCheck = &api.NBVersionCheck{ - MinVersion: versionCheck.MinVersion, - } - case posture.OSVersionCheckName: - osCheck := check.(*posture.OSVersionCheck) - checks.OsVersionCheck = &api.OSVersionCheck{ - Android: (*api.MinVersionCheck)(osCheck.Android), - Darwin: (*api.MinVersionCheck)(osCheck.Darwin), - Ios: (*api.MinVersionCheck)(osCheck.Ios), - Linux: (*api.MinKernelVersionCheck)(osCheck.Linux), - Windows: (*api.MinKernelVersionCheck)(osCheck.Windows), - } - case posture.GeoLocationCheckName: - geoLocationCheck := check.(*posture.GeoLocationCheck) - checks.GeoLocationCheck = toGeoLocationCheckResponse(geoLocationCheck) + + if postureChecks.Checks.NBVersionCheck != nil { + checks.NbVersionCheck = &api.NBVersionCheck{ + MinVersion: postureChecks.Checks.NBVersionCheck.MinVersion, + } + } + + if postureChecks.Checks.OSVersionCheck != nil { + checks.OsVersionCheck = &api.OSVersionCheck{ + Android: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Android), + Darwin: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Darwin), + Ios: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Ios), + Linux: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Linux), + Windows: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Windows), } } + if postureChecks.Checks.GeoLocationCheck != nil { + checks.GeoLocationCheck = toGeoLocationCheckResponse(postureChecks.Checks.GeoLocationCheck) + } + return &api.PostureCheck{ Id: postureChecks.ID, Name: postureChecks.Name, diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 195812cc135..98ca0f99676 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -85,8 +85,8 @@ func TestGetPostureCheck(t *testing.T) { postureCheck := &posture.Checks{ ID: "postureCheck", Name: "nbVersion", - Checks: []posture.Check{ - &posture.NBVersionCheck{ + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "1.0.0", }, }, @@ -94,8 +94,8 @@ func TestGetPostureCheck(t *testing.T) { osPostureCheck := &posture.Checks{ ID: "osPostureCheck", Name: "osVersion", - Checks: []posture.Check{ - &posture.OSVersionCheck{ + Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{ Linux: &posture.MinKernelVersionCheck{ MinKernelVersion: "6.0.0", }, @@ -111,8 +111,8 @@ func TestGetPostureCheck(t *testing.T) { geoPostureCheck := &posture.Checks{ ID: "geoPostureCheck", Name: "geoLocation", - Checks: []posture.Check{ - &posture.GeoLocationCheck{ + Checks: posture.ChecksDefinition{ + GeoLocationCheck: &posture.GeoLocationCheck{ Locations: []posture.Location{ { CountryCode: "DE", @@ -638,8 +638,8 @@ func TestPostureCheckUpdate(t *testing.T) { p := initPostureChecksTestData(&posture.Checks{ ID: "postureCheck", Name: "postureCheck", - Checks: []posture.Check{ - &posture.NBVersionCheck{ + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "1.0.0", }, }, @@ -647,8 +647,8 @@ func TestPostureCheckUpdate(t *testing.T) { &posture.Checks{ ID: "osPostureCheck", Name: "osPostureCheck", - Checks: []posture.Check{ - &posture.OSVersionCheck{ + Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{ Linux: &posture.MinKernelVersionCheck{ MinKernelVersion: "5.0.0", }, @@ -658,8 +658,8 @@ func TestPostureCheckUpdate(t *testing.T) { &posture.Checks{ ID: "geoPostureCheck", Name: "geoLocation", - Checks: []posture.Check{ - &posture.GeoLocationCheck{ + Checks: posture.ChecksDefinition{ + GeoLocationCheck: &posture.GeoLocationCheck{ Locations: []posture.Location{ { CountryCode: "DE", diff --git a/management/server/policy.go b/management/server/policy.go index 04f6b4c76b1..291a4f1f766 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -538,7 +538,7 @@ func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, pe continue } - for _, check := range postureChecks.Checks { + for _, check := range postureChecks.GetChecks() { isValid, err := check.Check(*peer) if err != nil { log.Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 57c33962e94..b78f52faef6 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -572,11 +572,11 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { ID: "PostureChecksDefault", Name: "Default", Description: "This is a posture checks that check if peer is running required versions", - Checks: []posture.Check{ - &posture.NBVersionCheck{ + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.25", }, - &posture.OSVersionCheck{ + OSVersionCheck: &posture.OSVersionCheck{ Linux: &posture.MinKernelVersionCheck{ MinKernelVersion: "6.6.0", }, diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index 1613cf43f82..ddd7cd265da 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -1,14 +1,12 @@ package posture import ( - "encoding/json" - nbpeer "github.com/netbirdio/netbird/management/server/peer" ) const ( - NBVersionCheckName = "NBVersionCheck" - OSVersionCheckName = "OSVersionCheck" + NBVersionCheckName = "NBVersionCheck" + OSVersionCheckName = "OSVersionCheck" GeoLocationCheckName = "GeoLocationCheck" ) @@ -31,8 +29,14 @@ type Checks struct { // AccountID is a reference to the Account that this object belongs AccountID string `json:"-" gorm:"index"` - // Checks is a list of objects that perform the actual checks - Checks []Check `gorm:"serializer:json"` + // Checks is a set of objects that perform the actual checks + Checks ChecksDefinition `gorm:"serializer:json"` +} + +type ChecksDefinition struct { + NBVersionCheck *NBVersionCheck `json:",omitempty"` + OSVersionCheck *OSVersionCheck `json:",omitempty"` + GeoLocationCheck *GeoLocationCheck `json:",omitempty"` } // TableName returns the name of the table for the Checks model in the database. @@ -47,9 +51,8 @@ func (pc *Checks) Copy() *Checks { Name: pc.Name, Description: pc.Description, AccountID: pc.AccountID, - Checks: make([]Check, len(pc.Checks)), + Checks: pc.Checks, // TODO: copy by value } - copy(checks.Checks, pc.Checks) return checks } @@ -58,73 +61,17 @@ func (pc *Checks) EventMeta() map[string]any { return map[string]any{"name": pc.Name} } -// MarshalJSON returns the JSON encoding of the Checks object. -// The Checks object is marshaled as a map[string]json.RawMessage, -// where the key is the name of the check and the value is the JSON -// representation of the Check object. -func (pc *Checks) MarshalJSON() ([]byte, error) { - type Alias Checks - return json.Marshal(&struct { - Checks map[string]json.RawMessage - *Alias - }{ - Checks: pc.marshalChecks(), - Alias: (*Alias)(pc), - }) -} - -// UnmarshalJSON unmarshal the JSON data into the Checks object. -func (pc *Checks) UnmarshalJSON(data []byte) error { - type Alias Checks - aux := &struct { - Checks map[string]json.RawMessage - *Alias - }{ - Alias: (*Alias)(pc), +// GetChecks returns list of all initialized checks definitions +func (pc *Checks) GetChecks() []Check { + var checks []Check + if pc.Checks.NBVersionCheck != nil { + checks = append(checks, pc.Checks.NBVersionCheck) } - - if err := json.Unmarshal(data, &aux); err != nil { - return err + if pc.Checks.OSVersionCheck != nil { + checks = append(checks, pc.Checks.OSVersionCheck) } - return pc.unmarshalChecks(aux.Checks) -} - -func (pc *Checks) marshalChecks() map[string]json.RawMessage { - result := make(map[string]json.RawMessage) - for _, check := range pc.Checks { - data, err := json.Marshal(check) - if err != nil { - return result - } - result[check.Name()] = data - } - return result -} - -func (pc *Checks) unmarshalChecks(rawChecks map[string]json.RawMessage) error { - pc.Checks = make([]Check, 0, len(rawChecks)) - - for name, rawCheck := range rawChecks { - switch name { - case NBVersionCheckName: - check := &NBVersionCheck{} - if err := json.Unmarshal(rawCheck, check); err != nil { - return err - } - pc.Checks = append(pc.Checks, check) - case OSVersionCheckName: - check := &OSVersionCheck{} - if err := json.Unmarshal(rawCheck, check); err != nil { - return err - } - pc.Checks = append(pc.Checks, check) - case GeoLocationCheckName: - check := &GeoLocationCheck{} - if err := json.Unmarshal(rawCheck, check); err != nil { - return err - } - pc.Checks = append(pc.Checks, check) - } + if pc.Checks.GeoLocationCheck != nil { + checks = append(checks, pc.Checks.GeoLocationCheck) } - return nil + return checks } diff --git a/management/server/posture/checks_test.go b/management/server/posture/checks_test.go index cae82caab6a..27c3c479ca5 100644 --- a/management/server/posture/checks_test.go +++ b/management/server/posture/checks_test.go @@ -1,6 +1,7 @@ package posture import ( + "encoding/json" "testing" "github.com/stretchr/testify/assert" @@ -20,8 +21,8 @@ func TestChecks_MarshalJSON(t *testing.T) { Name: "name1", Description: "desc1", AccountID: "acc1", - Checks: []Check{ - &NBVersionCheck{ + Checks: ChecksDefinition{ + NBVersionCheck: &NBVersionCheck{ MinVersion: "1.0.0", }, }, @@ -47,8 +48,8 @@ func TestChecks_MarshalJSON(t *testing.T) { Name: "", Description: "", AccountID: "", - Checks: []Check{ - &NBVersionCheck{}, + Checks: ChecksDefinition{ + NBVersionCheck: &NBVersionCheck{}, }, }, want: []byte(` @@ -69,7 +70,7 @@ func TestChecks_MarshalJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.checks.MarshalJSON() + got, err := json.Marshal(tt.checks) if (err != nil) != tt.wantErr { t.Errorf("Checks.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return @@ -97,7 +98,6 @@ func TestChecks_UnmarshalJSON(t *testing.T) { "Description": "desc1", "Checks": { "NBVersionCheck": { - "Enabled": true, "MinVersion": "1.0.0" } } @@ -107,8 +107,8 @@ func TestChecks_UnmarshalJSON(t *testing.T) { ID: "id1", Name: "name1", Description: "desc1", - Checks: []Check{ - &NBVersionCheck{ + Checks: ChecksDefinition{ + NBVersionCheck: &NBVersionCheck{ MinVersion: "1.0.0", }, }, @@ -121,25 +121,23 @@ func TestChecks_UnmarshalJSON(t *testing.T) { expectedError: true, }, { - name: "Empty JSON Posture Check Unmarshal", - in: []byte(`{}`), - expected: &Checks{ - Checks: make([]Check, 0), - }, + name: "Empty JSON Posture Check Unmarshal", + in: []byte(`{}`), + expected: &Checks{}, expectedError: false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - checks := &Checks{} - err := checks.UnmarshalJSON(tc.in) + var checks Checks + err := json.Unmarshal(tc.in, &checks) if tc.expectedError { assert.Error(t, err) } else { assert.NoError(t, err) - assert.Equal(t, tc.expected, checks) + assert.Equal(t, tc.expected, &checks) } }) } diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index 5788e52d45d..78d420df4b9 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -375,6 +375,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { Preload(clause.Associations). First(&account, "id = ?", accountID) if result.Error != nil { + log.Error(result.Error) return nil, status.Errorf(status.NotFound, "account not found") } From a16cebfc960478f57c0fe6f4710872c539c0cf59 Mon Sep 17 00:00:00 2001 From: Yury Gargay Date: Fri, 9 Feb 2024 17:47:36 +0100 Subject: [PATCH 2/4] Improve error --- management/server/sqlite_store.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index 78d420df4b9..7a890b67466 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -375,7 +375,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { Preload(clause.Associations). First(&account, "id = ?", accountID) if result.Error != nil { - log.Error(result.Error) + log.Errorf("when getting account from the store: %s", result.Error) return nil, status.Errorf(status.NotFound, "account not found") } From 1f2fd87afdfec729aec46acdb4df656f1ea2bcfd Mon Sep 17 00:00:00 2001 From: Yury Gargay Date: Mon, 12 Feb 2024 09:52:57 +0100 Subject: [PATCH 3/4] Implement ChecksDefinition Copy() --- management/server/posture/checks.go | 43 +++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index ddd7cd265da..d5590bc80b2 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -33,25 +33,64 @@ type Checks struct { Checks ChecksDefinition `gorm:"serializer:json"` } +// ChecksDefinition contains definition of actual check type ChecksDefinition struct { NBVersionCheck *NBVersionCheck `json:",omitempty"` OSVersionCheck *OSVersionCheck `json:",omitempty"` GeoLocationCheck *GeoLocationCheck `json:",omitempty"` } +// Copy returns a copy of a checks definition. +func (cd ChecksDefinition) Copy() ChecksDefinition { + var cdCopy ChecksDefinition + if cd.NBVersionCheck != nil { + cdCopy.NBVersionCheck = &NBVersionCheck{ + MinVersion: cdCopy.NBVersionCheck.MinVersion, + } + } + if cd.OSVersionCheck != nil { + cdCopy.OSVersionCheck = &OSVersionCheck{} + osCheck := cdCopy.OSVersionCheck + if osCheck.Android != nil { + cdCopy.OSVersionCheck.Android = &MinVersionCheck{MinVersion: osCheck.Android.MinVersion} + } + if osCheck.Darwin != nil { + cdCopy.OSVersionCheck.Darwin = &MinVersionCheck{MinVersion: osCheck.Darwin.MinVersion} + } + if osCheck.Ios != nil { + cdCopy.OSVersionCheck.Ios = &MinVersionCheck{MinVersion: osCheck.Ios.MinVersion} + } + if osCheck.Linux != nil { + cdCopy.OSVersionCheck.Linux = &MinKernelVersionCheck{MinKernelVersion: osCheck.Linux.MinKernelVersion} + } + if osCheck.Windows != nil { + cdCopy.OSVersionCheck.Windows = &MinKernelVersionCheck{MinKernelVersion: osCheck.Windows.MinKernelVersion} + } + } + if cd.GeoLocationCheck != nil { + geoCheck := cd.GeoLocationCheck + cdCopy.GeoLocationCheck = &GeoLocationCheck{ + Action: geoCheck.Action, + Locations: make([]Location, len(geoCheck.Locations)), + } + copy(cd.GeoLocationCheck.Locations, geoCheck.Locations) + } + return cdCopy +} + // TableName returns the name of the table for the Checks model in the database. func (*Checks) TableName() string { return "posture_checks" } -// Copy returns a copy of a policy rule. +// Copy returns a copy of a posture checks. func (pc *Checks) Copy() *Checks { checks := &Checks{ ID: pc.ID, Name: pc.Name, Description: pc.Description, AccountID: pc.AccountID, - Checks: pc.Checks, // TODO: copy by value + Checks: pc.Checks.Copy(), } return checks } From 36571ca06671ed6e1e87f51e5218d7ccdd707bad Mon Sep 17 00:00:00 2001 From: Yury Gargay Date: Mon, 12 Feb 2024 10:47:31 +0100 Subject: [PATCH 4/4] Fix test --- management/server/posture/checks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index d5590bc80b2..5c2356b0097 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -45,7 +45,7 @@ func (cd ChecksDefinition) Copy() ChecksDefinition { var cdCopy ChecksDefinition if cd.NBVersionCheck != nil { cdCopy.NBVersionCheck = &NBVersionCheck{ - MinVersion: cdCopy.NBVersionCheck.MinVersion, + MinVersion: cd.NBVersionCheck.MinVersion, } } if cd.OSVersionCheck != nil {