Skip to content

Commit

Permalink
Adding test for auth/unauth quaries, plus minor fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
getvictor committed Nov 6, 2023
1 parent cca3505 commit f4c7df1
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 8 deletions.
112 changes: 110 additions & 2 deletions server/service/integration_live_queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"testing"
"time"

"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/live_query/live_query_mock"
"github.com/fleetdm/fleet/v4/server/ptr"
Expand Down Expand Up @@ -346,6 +347,107 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
}
}

// TestLiveQueriesSomeFailToAuthorize when a user requests to run a mix of authorized and unauthorized queries
func (s *liveQueriesTestSuite) TestLiveQueriesSomeFailToAuthorize() {
t := s.T()

host := s.hosts[0]

// Unauthorized query
q1, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 1 from osquery;",
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
},
)
require.NoError(t, err)

// Authorized query
q2, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 2 from osquery;",
Description: "desc2",
Name: t.Name() + "query2",
Logging: fleet.LoggingSnapshot,
ObserverCanRun: true,
},
)
require.NoError(t, err)

s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 2 from osquery;"}, nil)
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{host.ID}).Return(nil)
s.lq.On("StopQuery", mock.Anything).Return(nil)

// Switch to observer user.
originalToken := s.token
s.token = getTestUserToken(t, s.server, "user2")
defer func() {
s.token = originalToken
}()

liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID, q2.ID},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}

wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()

// Give the above call a couple of seconds to create the campaign
time.Sleep(2 * time.Second)

cid2 := getCIDForQ(s, q2)

distributedReq := SubmitDistributedQueryResultsRequest{
NodeKey: *host.NodeKey,
Results: map[string][]map[string]string{
hostDistributedQueryPrefix + cid2: {{"col3": "c", "col4": "d"}, {"col3": "e", "col4": "f"}},
},
Statuses: map[string]fleet.OsqueryStatus{
hostDistributedQueryPrefix + cid2: 0,
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid2: "some other msg",
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)

wg.Wait()

require.Len(t, liveQueryResp.Results, 2)
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)

sort.Slice(
liveQueryResp.Results, func(i, j int) bool {
return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID
},
)

require.True(t, q1.ID < q2.ID)

assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
assert.Nil(t, liveQueryResp.Results[0].Results)
assert.Equal(t, authz.ForbiddenErrorMessage, *liveQueryResp.Results[0].Error)

assert.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID)
require.Len(t, liveQueryResp.Results[1].Results, 1)
q2Results := liveQueryResp.Results[1].Results[0]
require.Len(t, q2Results.Rows, 2)
assert.Equal(t, "c", q2Results.Rows[0]["col3"])
assert.Equal(t, "d", q2Results.Rows[0]["col4"])
assert.Equal(t, "e", q2Results.Rows[1]["col3"])
assert.Equal(t, "f", q2Results.Rows[1]["col4"])
}

// TestLiveQueriesInvalidInput without query/host IDs
func (s *liveQueriesTestSuite) TestLiveQueriesInvalidInputs() {
t := s.T()
Expand Down Expand Up @@ -376,8 +478,14 @@ func (s *liveQueriesTestSuite) TestLiveQueriesInvalidInputs() {
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)

liveQueryRequest = runLiveQueryRequest{
QueryIDs: []uint{},
HostIDs: []uint{},
QueryIDs: nil,
HostIDs: []uint{host.ID},
}
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)

liveQueryRequest = runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: nil,
}
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
}
Expand Down
17 changes: 11 additions & 6 deletions server/service/live_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sync"
"time"

"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
Expand Down Expand Up @@ -49,6 +50,8 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
logging.WithExtras(ctx, "live_query_rest_period_err", err)
}

// Only allow a host to be specified once in HostIDs
req.HostIDs = server.RemoveDuplicatesFromSlice(req.HostIDs)
res := runLiveQueryResponse{
Summary: summaryPayload{
TargetedHostCount: len(req.HostIDs),
Expand All @@ -60,12 +63,14 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
if err != nil {
return nil, err
}
// Check if all query results were forbidden due to lack of authorization
allResultsForbidden := len(queryResults) > 0
for _, r := range queryResults {
if r.Error == nil || *r.Error != authz.ForbiddenErrorMessage {
allResultsForbidden = false
break
// Check if all query results were forbidden due to lack of authorization.
allResultsForbidden := len(queryResults) > 0 && respondedHostCount == 0
if allResultsForbidden {
for _, r := range queryResults {
if r.Error == nil || *r.Error != authz.ForbiddenErrorMessage {
allResultsForbidden = false
break
}
}
}
if allResultsForbidden {
Expand Down
15 changes: 15 additions & 0 deletions server/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,18 @@ func Base64DecodePaddingAgnostic(s string) ([]byte, error) {
us := strings.TrimRight(s, string(base64.StdPadding))
return base64.RawStdEncoding.DecodeString(us)
}

// RemoveDuplicatesFromSlice returns a slice with all the duplicates removed from the input slice.
func RemoveDuplicatesFromSlice[T comparable](slice []T) []T {
// We are using the allKeys map as a set here
allKeys := make(map[T]struct{})
var list []T

for _, i := range slice {
if _, exists := allKeys[i]; !exists {
allKeys[i] = struct{}{}
list = append(list, i)
}
}
return list
}
27 changes: 27 additions & 0 deletions server/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,30 @@ func TestBase64DecodePaddingAgnostic(t *testing.T) {
require.Equal(t, got, c.want)
}
}

func TestRemoveDuplicatesFromSlice(t *testing.T) {
tests := map[string]struct {
input []interface{}
output []interface{}
}{
"no duplicates": {
input: []interface{}{34, 56, 1},
output: []interface{}{34, 56, 1},
},
"1 duplicate": {
input: []interface{}{"a", "d", "a"},
output: []interface{}{"a", "d"},
},
"all duplicates": {
input: []interface{}{true, true, true},
output: []interface{}{true},
},
}
for name, test := range tests {
t.Run(
name, func(t *testing.T) {
require.Equal(t, test.output, RemoveDuplicatesFromSlice(test.input))
},
)
}
}

0 comments on commit f4c7df1

Please sign in to comment.