diff --git a/changes/24790-admx-policies b/changes/24790-admx-policies new file mode 100644 index 000000000000..515825cb48da --- /dev/null +++ b/changes/24790-admx-policies @@ -0,0 +1 @@ +Fixes issue verifying Windows CSP profiles that contain ADMX policies. diff --git a/server/mdm/microsoft/admx/admx.go b/server/mdm/microsoft/admx/admx.go new file mode 100644 index 000000000000..647425ab1f5e --- /dev/null +++ b/server/mdm/microsoft/admx/admx.go @@ -0,0 +1,104 @@ +// Package admx handles ADMX (Administrative Template File) policies for Microsoft MDM server. +// See: https://learn.microsoft.com/en-us/windows/client-management/understanding-admx-backed-policies +// +// ADMX policy payload example: +// +// +// +// +// +// ]]> +package admx + +import ( + "encoding/xml" + "fmt" + "slices" + "strings" +) + +func IsADMX(text string) bool { + // We try to unmarshal the string to see if it looks like a valid ADMX policy + policy, err := unmarshal(text) + if err != nil { + return false + } + return policy.Enabled.Local == "enabled" || policy.Disabled.Local == "disabled" || len(policy.Data) > 0 +} + +func Equal(a, b string) (bool, error) { + aPolicy, err := unmarshal(a) + if err != nil { + return false, fmt.Errorf("unmarshalling ADMX policy a: %w", err) + } + bPolicy, err := unmarshal(b) + if err != nil { + return false, fmt.Errorf("unmarshalling ADMX policy b: %w", err) + } + return aPolicy.Equal(bPolicy), nil +} + +func unmarshal(a string) (admxPolicy, error) { + // We unmarshal into a string to get the CDATA content and decode XML escape characters. + // We wrap the policy in an tag to ensure it can be unmarshalled by the XML decoder. + var unescaped string + err := xml.Unmarshal([]byte(``+a+``), &unescaped) + if err != nil { + return admxPolicy{}, fmt.Errorf("unmarshalling ADMX policy to string: %w", err) + } + // ADMX policy elements are not case-sensitive. For example: and are equivalent + // For simplicity, we compare everything in lowercase. + var policy admxPolicy + err = xml.Unmarshal([]byte(``+strings.ToLower(unescaped)+``), &policy) + if err != nil { + return admxPolicy{}, fmt.Errorf("unmarshalling ADMX policy: %w", err) + } + return policy, nil +} + +type admxPolicy struct { + Enabled xml.Name `xml:"enabled,omitempty"` + Disabled xml.Name `xml:"disabled,omitempty"` + Data []admxPolicyItem `xml:"data"` +} + +func (a admxPolicy) Equal(b admxPolicy) bool { + if a.Disabled.Local != b.Disabled.Local { + return false + } + if a.Disabled.Local == "disabled" { + // If the ADMX policy is disabled, the data is not relevant + return true + } + if a.Enabled.Local != b.Enabled.Local { + return false + } + if len(a.Data) != len(b.Data) { + return false + } + a.sortData() + b.sortData() + for i := range a.Data { + if !a.Data[i].Equal(b.Data[i]) { + return false + } + } + return true +} + +func (a *admxPolicy) sortData() { + slices.SortFunc(a.Data, func(i, j admxPolicyItem) int { + return strings.Compare(i.ID, j.ID) + }) +} + +type admxPolicyItem struct { + ID string `xml:"id,attr"` + Value string `xml:"value,attr"` +} + +func (a admxPolicyItem) Equal(b admxPolicyItem) bool { + return a.ID == b.ID && a.Value == b.Value +} diff --git a/server/mdm/microsoft/admx/admx_test.go b/server/mdm/microsoft/admx/admx_test.go new file mode 100644 index 000000000000..1751c23ac6fc --- /dev/null +++ b/server/mdm/microsoft/admx/admx_test.go @@ -0,0 +1,174 @@ +package admx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsADMX(t *testing.T) { + t.Parallel() + assert.False(t, IsADMX("")) + assert.False(t, IsADMX("not an ADMX policy")) + assert.False(t, IsADMX(``)) + assert.False(t, IsADMX(`]]>`)) + assert.False(t, IsADMX(``)) + assert.True(t, IsADMX(`]]>`)) + assert.True(t, IsADMX(`]]>`)) + assert.True(t, IsADMX(`]]>`)) + assert.True(t, IsADMX( + ` + ]]>`)) + assert.True(t, + IsADMX("<Enabled/><Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/><Data id=\"ExecutionPolicy\" value=\"AllSigned\"/><Data id=\"Listbox_ModuleNames\" value=\"*\"/><Data id=\"OutputDirectory\" value=\"false\"/><Data id=\"SourcePathForUpdateHelp\" value=\"false\"/>")) + assert.True(t, IsADMX( + `<Enabled/> + ]]> + + + + ]]>`)) +} + +func TestEqual(t *testing.T) { + t.Parallel() + testCases := []struct { + name, a, b, errorContains string + equal bool + }{ + { + name: "empty policies", + a: "", + b: "", + equal: true, + errorContains: "", + }, + { + name: "enabled policies", + a: "]]>", + b: "<Enabled/>", + equal: true, + errorContains: "", + }, + { + name: "disabled policies", + a: "]]>", + b: "<Disabled/>", + equal: true, + errorContains: "", + }, + { + name: "unequal policies", + a: "]]>", + b: "<enabled/>", + equal: false, + errorContains: "", + }, + { + name: "enabled policies with data", + a: ` + + + + + ]]>`, + b: "<Enabled/><Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/><Data id=\"ExecutionPolicy\" value=\"AllSigned\"/><Data id=\"Listbox_ModuleNames\" value=\"*\"/><Data id=\"OutputDirectory\" value=\"false\"/><Data id=\"SourcePathForUpdateHelp\" value=\"false\"/>", + equal: true, + errorContains: "", + }, + { + name: "enabled policies with data and nonstandard format", + a: `<Enabled/> + ]]> + + + + ]]>`, + b: "<Enabled/><Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/><Data id=\"ExecutionPolicy\" value=\"AllSigned\"/><Data id=\"Listbox_ModuleNames\" value=\"*\"/><Data id=\"OutputDirectory\" value=\"false\"/><Data id=\"SourcePathForUpdateHelp\" value=\"false\"/>", + equal: true, + errorContains: "", + }, + { + name: "disabled policies with data", + a: ` + + ]]>`, + b: "<Disabled/><Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/><Data id=\"ExecutionPolicy\" value=\"AllSigned\"/><Data id=\"Listbox_ModuleNames\" value=\"*\"/><Data id=\"OutputDirectory\" value=\"false\"/><Data id=\"SourcePathForUpdateHelp\" value=\"false\"/>", + equal: true, + errorContains: "", + }, + { + name: "unparsable policy a 1", + a: " + + + + + ]]>`, + b: "<Enabled/><Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/><Data id=\"ExecutionPolicy\" value=\"AllSigned\"/><Data id=\"Listbox_ModuleNames\" value=\"*\"/><Data id=\"OutputDirectory\" value=\"false\"/><Data id=\"SourcePathForUpdateHelp\" value=\"false\"/>", + equal: false, + errorContains: "", + }, + { + name: "unequal policies with data 1", + a: ` + + ]]>`, + b: "<Enabled/><Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/><Data id=\"ExecutionPolicy\" value=\"AllSigned\"/><Data id=\"Listbox_ModuleNames\" value=\"*\"/><Data id=\"OutputDirectory\" value=\"false\"/><Data id=\"SourcePathForUpdateHelp\" value=\"false\"/>", + equal: false, + errorContains: "", + }, + { + name: "unequal policies with data 2", + a: ` + + + + + ]]>`, + b: "<Enabled/><Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/><Data id=\"ExecutionPolicy\" value=\"AllSigned\"/><Data id=\"Listbox_ModuleNames\" value=\"*\"/><Data id=\"OutputDirectory\" value=\"false\"/><Data id=\"SourcePathForUpdateHelp\" value=\"false\"/>", + equal: false, + errorContains: "", + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + equal, err := Equal(tt.a, tt.b) + if tt.errorContains == "" { + assert.NoError(t, err) + } else { + assert.ErrorContains(t, err, tt.errorContains) + } + assert.Equal(t, tt.equal, equal) + }) + } +} diff --git a/server/mdm/microsoft/profile_verifier.go b/server/mdm/microsoft/profile_verifier.go index f63017875159..7c8494f9e5e6 100644 --- a/server/mdm/microsoft/profile_verifier.go +++ b/server/mdm/microsoft/profile_verifier.go @@ -14,6 +14,7 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mdm" + "github.com/fleetdm/fleet/v4/server/mdm/microsoft/admx" ) // LoopOverExpectedHostProfiles loops all the values on all the profiles for a @@ -72,20 +73,24 @@ func HashLocURI(profileName, locURI string) string { func VerifyHostMDMProfiles(ctx context.Context, ds fleet.ProfileVerificationStore, host *fleet.Host, rawProfileResultsSyncML []byte) error { profileResults, err := transformProfileResults(rawProfileResultsSyncML) if err != nil { - return err + return ctxerr.Wrap(ctx, err, "transforming policy results") } verified, missing, err := compareResultsToExpectedProfiles(ctx, ds, host, profileResults) if err != nil { - return err + return ctxerr.Wrap(ctx, err, "comparing results to expected profiles") } toFail, toRetry, err := splitMissingProfilesIntoFailAndRetryBuckets(ctx, ds, host, missing, verified) if err != nil { - return err + return ctxerr.Wrap(ctx, err, "splitting missing profiles into fail and retry buckets") } - return ds.UpdateHostMDMProfilesVerification(ctx, host, slices.Collect(maps.Keys(verified)), toFail, toRetry) + err = ds.UpdateHostMDMProfilesVerification(ctx, host, slices.Collect(maps.Keys(verified)), toFail, toRetry) + if err != nil { + return ctxerr.Wrap(ctx, err, "updating host mdm profiles during verification") + } + return nil } func splitMissingProfilesIntoFailAndRetryBuckets(ctx context.Context, ds fleet.ProfileVerificationStore, host *fleet.Host, @@ -127,11 +132,11 @@ func compareResultsToExpectedProfiles(ctx context.Context, ds fleet.ProfileVerif missing = map[string]struct{}{} verified = map[string]struct{}{} err = LoopOverExpectedHostProfiles(ctx, ds, host, func(profile *fleet.ExpectedMDMProfile, ref, locURI, wantData string) { - // if we didn't get a status for a LocURI, mark the profile as - // missing. + // if we didn't get a status for a LocURI, mark the profile as missing. gotStatus, ok := profileResults.cmdRefToStatus[ref] if !ok { missing[profile.Name] = struct{}{} + return } // it's okay if we didn't get a result gotResults := profileResults.cmdRefToResult[ref] @@ -139,7 +144,20 @@ func compareResultsToExpectedProfiles(ctx context.Context, ds fleet.ProfileVerif // TODO: should we be more granular instead? eg: special case // `4xx` responses? I'm sure there are edge cases we're not // accounting for here, but it's unclear at this moment. - if !strings.HasPrefix(gotStatus, "2") || wantData != gotResults { + var equal bool + switch { + case !strings.HasPrefix(gotStatus, "2"): + equal = false + case wantData == gotResults: + equal = true + case admx.IsADMX(wantData): + equal, err = admx.Equal(wantData, gotResults) + if err != nil { + err = fmt.Errorf("comparing ADMX policies: %w", err) + return + } + } + if !equal { withinGracePeriod := profile.IsWithinGracePeriod(host.DetailUpdatedAt) if !withinGracePeriod { missing[profile.Name] = struct{}{} diff --git a/server/mdm/microsoft/profile_verifier_test.go b/server/mdm/microsoft/profile_verifier_test.go index 709c768c8c01..ed16b8a21ab3 100644 --- a/server/mdm/microsoft/profile_verifier_test.go +++ b/server/mdm/microsoft/profile_verifier_test.go @@ -223,6 +223,41 @@ func TestVerifyHostMDMProfilesHappyPaths(t *testing.T) { toFail: []string{"N2", "N4"}, toRetry: []string{"N1"}, }, + { + name: "single profile with CDATA reported and verified", + hostProfiles: []hostProfile{ + {"N1", syncml.ForTestWithData(map[string]string{ + "L1": ` + + + + + ]]>`, + }), 0}, + }, + report: []osqueryReport{{"N1", "200", "L1", + "<Enabled/><Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/><Data id=\"ExecutionPolicy\" value=\"AllSigned\"/><Data id=\"Listbox_ModuleNames\" value=\"*\"/><Data id=\"OutputDirectory\" value=\"false\"/><Data id=\"SourcePathForUpdateHelp\" value=\"false\"/>", + }}, + toVerify: []string{"N1"}, + toFail: []string{}, + toRetry: []string{}, + }, + { + name: "single profile with CDATA to retry", + hostProfiles: []hostProfile{ + {"N1", syncml.ForTestWithData(map[string]string{ + "L1": ` + + ]]>`, + }), 0}, + }, + report: []osqueryReport{{"N1", "200", "L1", + "<Enabled/><Data id=\"EnableScriptBlockInvocationLogging\" value=\"true\"/><Data id=\"ExecutionPolicy\" value=\"AllSigned\"/><Data id=\"Listbox_ModuleNames\" value=\"*\"/><Data id=\"OutputDirectory\" value=\"false\"/><Data id=\"SourcePathForUpdateHelp\" value=\"false\"/>", + }}, + toVerify: []string{}, + toFail: []string{}, + toRetry: []string{"N1"}, + }, } for _, tt := range cases {