Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify checks definition serialisation #1555

Merged
merged 4 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1562,8 +1562,7 @@ func TestAccount_Copy(t *testing.T) {
DNSSettings: DNSSettings{DisabledManagementGroups: []string{}},
PostureChecks: []*posture.Checks{
{
ID: "posture Checks1",
Checks: make([]posture.Check, 0),
ID: "posture Checks1",
},
},
Settings: &Settings{},
Expand Down
48 changes: 23 additions & 25 deletions management/server/http/posture_checks_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,31 +184,30 @@ func (p *PostureChecksHandler) savePostureChecks(
ID: postureChecksID,
Name: req.Name,
Description: req.Description,
Checks: make([]posture.Check, 0),
}

if nbVersionCheck := req.Checks.NbVersionCheck; nbVersionCheck != nil {
postureChecks.Checks = append(postureChecks.Checks, &posture.NBVersionCheck{
postureChecks.Checks.NBVersionCheck = &posture.NBVersionCheck{
MinVersion: nbVersionCheck.MinVersion,
})
}
}

if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
postureChecks.Checks = append(postureChecks.Checks, &posture.OSVersionCheck{
postureChecks.Checks.OSVersionCheck = &posture.OSVersionCheck{
Android: (*posture.MinVersionCheck)(osVersionCheck.Android),
Darwin: (*posture.MinVersionCheck)(osVersionCheck.Darwin),
Ios: (*posture.MinVersionCheck)(osVersionCheck.Ios),
Linux: (*posture.MinKernelVersionCheck)(osVersionCheck.Linux),
Windows: (*posture.MinKernelVersionCheck)(osVersionCheck.Windows),
})
}
}

if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return
}
postureChecks.Checks = append(postureChecks.Checks, toPostureGeoLocationCheck(geoLocationCheck))
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
}

if err := p.accountManager.SavePostureChecks(account.Id, user.Id, &postureChecks); err != nil {
Expand Down Expand Up @@ -271,28 +270,27 @@ func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {

func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
var checks api.Checks
for _, check := range postureChecks.Checks {
switch check.Name() {
case posture.NBVersionCheckName:
versionCheck := check.(*posture.NBVersionCheck)
checks.NbVersionCheck = &api.NBVersionCheck{
MinVersion: versionCheck.MinVersion,
}
case posture.OSVersionCheckName:
osCheck := check.(*posture.OSVersionCheck)
checks.OsVersionCheck = &api.OSVersionCheck{
Android: (*api.MinVersionCheck)(osCheck.Android),
Darwin: (*api.MinVersionCheck)(osCheck.Darwin),
Ios: (*api.MinVersionCheck)(osCheck.Ios),
Linux: (*api.MinKernelVersionCheck)(osCheck.Linux),
Windows: (*api.MinKernelVersionCheck)(osCheck.Windows),
}
case posture.GeoLocationCheckName:
geoLocationCheck := check.(*posture.GeoLocationCheck)
checks.GeoLocationCheck = toGeoLocationCheckResponse(geoLocationCheck)

if postureChecks.Checks.NBVersionCheck != nil {
checks.NbVersionCheck = &api.NBVersionCheck{
MinVersion: postureChecks.Checks.NBVersionCheck.MinVersion,
}
}

if postureChecks.Checks.OSVersionCheck != nil {
checks.OsVersionCheck = &api.OSVersionCheck{
Android: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Android),
Darwin: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Darwin),
Ios: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Ios),
Linux: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Linux),
Windows: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Windows),
}
}

if postureChecks.Checks.GeoLocationCheck != nil {
checks.GeoLocationCheck = toGeoLocationCheckResponse(postureChecks.Checks.GeoLocationCheck)
}

return &api.PostureCheck{
Id: postureChecks.ID,
Name: postureChecks.Name,
Expand Down
24 changes: 12 additions & 12 deletions management/server/http/posture_checks_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,17 @@ func TestGetPostureCheck(t *testing.T) {
postureCheck := &posture.Checks{
ID: "postureCheck",
Name: "nbVersion",
Checks: []posture.Check{
&posture.NBVersionCheck{
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "1.0.0",
},
},
}
osPostureCheck := &posture.Checks{
ID: "osPostureCheck",
Name: "osVersion",
Checks: []posture.Check{
&posture.OSVersionCheck{
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "6.0.0",
},
Expand All @@ -111,8 +111,8 @@ func TestGetPostureCheck(t *testing.T) {
geoPostureCheck := &posture.Checks{
ID: "geoPostureCheck",
Name: "geoLocation",
Checks: []posture.Check{
&posture.GeoLocationCheck{
Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{
Locations: []posture.Location{
{
CountryCode: "DE",
Expand Down Expand Up @@ -638,17 +638,17 @@ func TestPostureCheckUpdate(t *testing.T) {
p := initPostureChecksTestData(&posture.Checks{
ID: "postureCheck",
Name: "postureCheck",
Checks: []posture.Check{
&posture.NBVersionCheck{
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "1.0.0",
},
},
},
&posture.Checks{
ID: "osPostureCheck",
Name: "osPostureCheck",
Checks: []posture.Check{
&posture.OSVersionCheck{
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "5.0.0",
},
Expand All @@ -658,8 +658,8 @@ func TestPostureCheckUpdate(t *testing.T) {
&posture.Checks{
ID: "geoPostureCheck",
Name: "geoLocation",
Checks: []posture.Check{
&posture.GeoLocationCheck{
Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{
Locations: []posture.Location{
{
CountryCode: "DE",
Expand Down
2 changes: 1 addition & 1 deletion management/server/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, pe
continue
}

for _, check := range postureChecks.Checks {
for _, check := range postureChecks.GetChecks() {
isValid, err := check.Check(*peer)
if err != nil {
log.Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
Expand Down
6 changes: 3 additions & 3 deletions management/server/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,11 +572,11 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
ID: "PostureChecksDefault",
Name: "Default",
Description: "This is a posture checks that check if peer is running required versions",
Checks: []posture.Check{
&posture.NBVersionCheck{
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.25",
},
&posture.OSVersionCheck{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "6.6.0",
},
Expand Down
136 changes: 61 additions & 75 deletions management/server/posture/checks.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package posture

import (
"encoding/json"

nbpeer "github.com/netbirdio/netbird/management/server/peer"
)

const (
NBVersionCheckName = "NBVersionCheck"
OSVersionCheckName = "OSVersionCheck"
NBVersionCheckName = "NBVersionCheck"
OSVersionCheckName = "OSVersionCheck"
GeoLocationCheckName = "GeoLocationCheck"
)

Expand All @@ -31,25 +29,69 @@ type Checks struct {
// AccountID is a reference to the Account that this object belongs
AccountID string `json:"-" gorm:"index"`

// Checks is a list of objects that perform the actual checks
Checks []Check `gorm:"serializer:json"`
// Checks is a set of objects that perform the actual checks
Checks ChecksDefinition `gorm:"serializer:json"`
}

// ChecksDefinition contains definition of actual check
type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"`
OSVersionCheck *OSVersionCheck `json:",omitempty"`
GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
}

// Copy returns a copy of a checks definition.
func (cd ChecksDefinition) Copy() ChecksDefinition {
var cdCopy ChecksDefinition
if cd.NBVersionCheck != nil {
cdCopy.NBVersionCheck = &NBVersionCheck{
MinVersion: cd.NBVersionCheck.MinVersion,
}
}
if cd.OSVersionCheck != nil {
cdCopy.OSVersionCheck = &OSVersionCheck{}
osCheck := cdCopy.OSVersionCheck
if osCheck.Android != nil {
cdCopy.OSVersionCheck.Android = &MinVersionCheck{MinVersion: osCheck.Android.MinVersion}
}
if osCheck.Darwin != nil {
cdCopy.OSVersionCheck.Darwin = &MinVersionCheck{MinVersion: osCheck.Darwin.MinVersion}
}
if osCheck.Ios != nil {
cdCopy.OSVersionCheck.Ios = &MinVersionCheck{MinVersion: osCheck.Ios.MinVersion}
}
if osCheck.Linux != nil {
cdCopy.OSVersionCheck.Linux = &MinKernelVersionCheck{MinKernelVersion: osCheck.Linux.MinKernelVersion}
}
if osCheck.Windows != nil {
cdCopy.OSVersionCheck.Windows = &MinKernelVersionCheck{MinKernelVersion: osCheck.Windows.MinKernelVersion}
}
}
if cd.GeoLocationCheck != nil {
geoCheck := cd.GeoLocationCheck
cdCopy.GeoLocationCheck = &GeoLocationCheck{
Action: geoCheck.Action,
Locations: make([]Location, len(geoCheck.Locations)),
}
copy(cd.GeoLocationCheck.Locations, geoCheck.Locations)
}
return cdCopy
}

// TableName returns the name of the table for the Checks model in the database.
func (*Checks) TableName() string {
return "posture_checks"
}

// Copy returns a copy of a policy rule.
// Copy returns a copy of a posture checks.
func (pc *Checks) Copy() *Checks {
checks := &Checks{
ID: pc.ID,
Name: pc.Name,
Description: pc.Description,
AccountID: pc.AccountID,
Checks: make([]Check, len(pc.Checks)),
Checks: pc.Checks.Copy(),
}
copy(checks.Checks, pc.Checks)
return checks
}

Expand All @@ -58,73 +100,17 @@ func (pc *Checks) EventMeta() map[string]any {
return map[string]any{"name": pc.Name}
}

// MarshalJSON returns the JSON encoding of the Checks object.
// The Checks object is marshaled as a map[string]json.RawMessage,
// where the key is the name of the check and the value is the JSON
// representation of the Check object.
func (pc *Checks) MarshalJSON() ([]byte, error) {
type Alias Checks
return json.Marshal(&struct {
Checks map[string]json.RawMessage
*Alias
}{
Checks: pc.marshalChecks(),
Alias: (*Alias)(pc),
})
}

// UnmarshalJSON unmarshal the JSON data into the Checks object.
func (pc *Checks) UnmarshalJSON(data []byte) error {
type Alias Checks
aux := &struct {
Checks map[string]json.RawMessage
*Alias
}{
Alias: (*Alias)(pc),
// GetChecks returns list of all initialized checks definitions
func (pc *Checks) GetChecks() []Check {
var checks []Check
if pc.Checks.NBVersionCheck != nil {
checks = append(checks, pc.Checks.NBVersionCheck)
}

if err := json.Unmarshal(data, &aux); err != nil {
return err
if pc.Checks.OSVersionCheck != nil {
checks = append(checks, pc.Checks.OSVersionCheck)
}
return pc.unmarshalChecks(aux.Checks)
}

func (pc *Checks) marshalChecks() map[string]json.RawMessage {
result := make(map[string]json.RawMessage)
for _, check := range pc.Checks {
data, err := json.Marshal(check)
if err != nil {
return result
}
result[check.Name()] = data
if pc.Checks.GeoLocationCheck != nil {
checks = append(checks, pc.Checks.GeoLocationCheck)
}
return result
}

func (pc *Checks) unmarshalChecks(rawChecks map[string]json.RawMessage) error {
pc.Checks = make([]Check, 0, len(rawChecks))

for name, rawCheck := range rawChecks {
switch name {
case NBVersionCheckName:
check := &NBVersionCheck{}
if err := json.Unmarshal(rawCheck, check); err != nil {
return err
}
pc.Checks = append(pc.Checks, check)
case OSVersionCheckName:
check := &OSVersionCheck{}
if err := json.Unmarshal(rawCheck, check); err != nil {
return err
}
pc.Checks = append(pc.Checks, check)
case GeoLocationCheckName:
check := &GeoLocationCheck{}
if err := json.Unmarshal(rawCheck, check); err != nil {
return err
}
pc.Checks = append(pc.Checks, check)
}
}
return nil
return checks
}
Loading
Loading