Skip to content

Commit

Permalink
Fixes to /fleet/queries/run endpoint (#14909)
Browse files Browse the repository at this point in the history
Fixes to /fleet/queries/run endpoint:
- now returns 403 for an unauthorized user
- now returns 400 when query_ids or host_ids are not specified

#11446 and #11901

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

API clarifications are in a separate PR
#14956

- [x] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- [x] Added/updated tests
- [x] Manual QA for all new/changed functionality
  • Loading branch information
getvictor authored Nov 6, 2023
1 parent 5391b68 commit f38524a
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 9 deletions.
3 changes: 3 additions & 0 deletions changes/11446-queries-run-when-forbidden
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixes to /fleet/queries/run endpoint:
- now returns 403 for an unauthorized user
- now returns 400 when query_ids or host_ids are not specified
7 changes: 6 additions & 1 deletion server/authz/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,13 @@ func (e *Forbidden) Internal() string {

// LogFields allows this error to be logged with subject, object, and action.
func (e *Forbidden) LogFields() []interface{} {
// Only logging User's email, and not other details such as Password and Salt.
email := "nil"
if e.subject != nil {
email = e.subject.Email
}
return []interface{}{
"subject", e.subject,
"subject", email,
"object", e.object,
"action", e.action,
}
Expand Down
2 changes: 1 addition & 1 deletion server/fleet/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ type Service interface {

GetCampaignReader(ctx context.Context, campaign *DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error)
CompleteCampaign(ctx context.Context, campaign *DistributedQueryCampaign) error
RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]QueryCampaignResult, int)
RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]QueryCampaignResult, int, error)

// /////////////////////////////////////////////////////////////////////////////
// AgentOptionsService
Expand Down
6 changes: 5 additions & 1 deletion server/service/http_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,11 @@ func setupAuthTest(t *testing.T) (fleet.Datastore, map[string]fleet.User, *httpt
}

func getTestAdminToken(t *testing.T, server *httptest.Server) string {
testUser := testUsers["admin1"]
return getTestUserToken(t, server, "admin1")
}

func getTestUserToken(t *testing.T, server *httptest.Server, testUserId string) string {
testUser := testUsers[testUserId]

params := loginRequest{
Email: testUser.Email,
Expand Down
5 changes: 2 additions & 3 deletions server/service/integration_enterprise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3186,12 +3186,11 @@ func (s *integrationEnterpriseTestSuite) TestGitOpsUserActions() {
require.Equal(t, "https://foobar.example.com", acr.AppConfig.WebhookSettings.VulnerabilitiesWebhook.DestinationURL)

// Attempt to run live queries synchronously, should fail.
// TODO(lucas): This is a bug, the synchronous live query API should return 403 but currently returns 200.
// It doesn't run the query but incorrectly returns a 200.
s.DoJSON("GET", "/api/latest/fleet/queries/run", runLiveQueryRequest{
HostIDs: []uint{h1.ID},
QueryIDs: []uint{q1.ID},
}, http.StatusOK, &runLiveQueryResponse{})
}, http.StatusForbidden, &runLiveQueryResponse{},
)

// Attempt to run live queries asynchronously (new unsaved query), should fail.
s.DoJSON("POST", "/api/latest/fleet/queries/run", createDistributedQueryCampaignRequest{
Expand Down
175 changes: 175 additions & 0 deletions server/service/integration_live_queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"testing"
"time"

"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/live_query/live_query_mock"
"github.com/fleetdm/fleet/v4/server/ptr"
Expand Down Expand Up @@ -346,6 +347,180 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
}
}

// TestLiveQueriesSomeFailToAuthorize when a user requests to run a mix of authorized and unauthorized queries
func (s *liveQueriesTestSuite) TestLiveQueriesSomeFailToAuthorize() {
t := s.T()

host := s.hosts[0]

// Unauthorized query
q1, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 1 from osquery;",
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
},
)
require.NoError(t, err)

// Authorized query
q2, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 2 from osquery;",
Description: "desc2",
Name: t.Name() + "query2",
Logging: fleet.LoggingSnapshot,
ObserverCanRun: true,
},
)
require.NoError(t, err)

s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 2 from osquery;"}, nil)
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{host.ID}).Return(nil)
s.lq.On("StopQuery", mock.Anything).Return(nil)

// Switch to observer user.
originalToken := s.token
s.token = getTestUserToken(t, s.server, "user2")
defer func() {
s.token = originalToken
}()

liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID, q2.ID},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}

wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()

// Give the above call a couple of seconds to create the campaign
time.Sleep(2 * time.Second)

cid2 := getCIDForQ(s, q2)

distributedReq := SubmitDistributedQueryResultsRequest{
NodeKey: *host.NodeKey,
Results: map[string][]map[string]string{
hostDistributedQueryPrefix + cid2: {{"col3": "c", "col4": "d"}, {"col3": "e", "col4": "f"}},
},
Statuses: map[string]fleet.OsqueryStatus{
hostDistributedQueryPrefix + cid2: 0,
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid2: "some other msg",
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)

wg.Wait()

require.Len(t, liveQueryResp.Results, 2)
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)

sort.Slice(
liveQueryResp.Results, func(i, j int) bool {
return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID
},
)

require.True(t, q1.ID < q2.ID)

assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
assert.Nil(t, liveQueryResp.Results[0].Results)
assert.Equal(t, authz.ForbiddenErrorMessage, *liveQueryResp.Results[0].Error)

assert.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID)
require.Len(t, liveQueryResp.Results[1].Results, 1)
q2Results := liveQueryResp.Results[1].Results[0]
require.Len(t, q2Results.Rows, 2)
assert.Equal(t, "c", q2Results.Rows[0]["col3"])
assert.Equal(t, "d", q2Results.Rows[0]["col4"])
assert.Equal(t, "e", q2Results.Rows[1]["col3"])
assert.Equal(t, "f", q2Results.Rows[1]["col4"])
}

// TestLiveQueriesInvalidInput without query/host IDs
func (s *liveQueriesTestSuite) TestLiveQueriesInvalidInputs() {
t := s.T()

host := s.hosts[0]

q1, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 1 from osquery;",
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
},
)
require.NoError(t, err)

liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)

liveQueryRequest = runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: []uint{},
}
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)

liveQueryRequest = runLiveQueryRequest{
QueryIDs: nil,
HostIDs: []uint{host.ID},
}
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)

liveQueryRequest = runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: nil,
}
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
}

// TestLiveQueriesFailsToAuthorize when an observer tries to run a live query
func (s *liveQueriesTestSuite) TestLiveQueriesFailsToAuthorize() {
t := s.T()

host := s.hosts[0]

q1, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 1 from osquery;",
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
},
)
require.NoError(t, err)

liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}

// Switch to observer user.
originalToken := s.token
s.token = getTestUserToken(t, s.server, "user2")
defer func() {
s.token = originalToken
}()
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusForbidden, &liveQueryResp)
}

func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() {
t := s.T()

Expand Down
32 changes: 29 additions & 3 deletions server/service/live_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"sync"
"time"

"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/fleet"
Expand Down Expand Up @@ -48,21 +50,45 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
logging.WithExtras(ctx, "live_query_rest_period_err", err)
}

// Only allow a host to be specified once in HostIDs
req.HostIDs = server.RemoveDuplicatesFromSlice(req.HostIDs)
res := runLiveQueryResponse{
Summary: summaryPayload{
TargetedHostCount: len(req.HostIDs),
RespondedHostCount: 0,
},
}

queryResults, respondedHostCount := svc.RunLiveQueryDeadline(ctx, req.QueryIDs, req.HostIDs, duration)
queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, req.QueryIDs, req.HostIDs, duration)
if err != nil {
return nil, err
}
// Check if all query results were forbidden due to lack of authorization.
allResultsForbidden := len(queryResults) > 0 && respondedHostCount == 0
if allResultsForbidden {
for _, r := range queryResults {
if r.Error == nil || *r.Error != authz.ForbiddenErrorMessage {
allResultsForbidden = false
break
}
}
}
if allResultsForbidden {
return nil, authz.ForbiddenWithInternal("All Live Query results were forbidden.", authz.UserFromContext(ctx), nil, nil)
}
res.Results = queryResults
res.Summary.RespondedHostCount = respondedHostCount

return res, nil
}

func (svc *Service) RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]fleet.QueryCampaignResult, int) {
func (svc *Service) RunLiveQueryDeadline(
ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration,
) ([]fleet.QueryCampaignResult, int, error) {
if len(queryIDs) == 0 || len(hostIDs) == 0 {
svc.authz.SkipAuthorization(ctx)
return nil, 0, ctxerr.Wrap(ctx, badRequest("query_ids and host_ids are required"))
}
wg := sync.WaitGroup{}

resultsCh := make(chan fleet.QueryCampaignResult)
Expand Down Expand Up @@ -132,7 +158,7 @@ func (svc *Service) RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, h
results = append(results, result)
}

return results, len(respondedHostIDs)
return results, len(respondedHostIDs), nil
}

func (svc *Service) GetCampaignReader(ctx context.Context, campaign *fleet.DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) {
Expand Down
15 changes: 15 additions & 0 deletions server/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,18 @@ func Base64DecodePaddingAgnostic(s string) ([]byte, error) {
us := strings.TrimRight(s, string(base64.StdPadding))
return base64.RawStdEncoding.DecodeString(us)
}

// RemoveDuplicatesFromSlice returns a slice with all the duplicates removed from the input slice.
func RemoveDuplicatesFromSlice[T comparable](slice []T) []T {
// We are using the allKeys map as a set here
allKeys := make(map[T]struct{})
var list []T

for _, i := range slice {
if _, exists := allKeys[i]; !exists {
allKeys[i] = struct{}{}
list = append(list, i)
}
}
return list
}
27 changes: 27 additions & 0 deletions server/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,30 @@ func TestBase64DecodePaddingAgnostic(t *testing.T) {
require.Equal(t, got, c.want)
}
}

func TestRemoveDuplicatesFromSlice(t *testing.T) {
tests := map[string]struct {
input []interface{}
output []interface{}
}{
"no duplicates": {
input: []interface{}{34, 56, 1},
output: []interface{}{34, 56, 1},
},
"1 duplicate": {
input: []interface{}{"a", "d", "a"},
output: []interface{}{"a", "d"},
},
"all duplicates": {
input: []interface{}{true, true, true},
output: []interface{}{true},
},
}
for name, test := range tests {
t.Run(
name, func(t *testing.T) {
require.Equal(t, test.output, RemoveDuplicatesFromSlice(test.input))
},
)
}
}

0 comments on commit f38524a

Please sign in to comment.