Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to /fleet/queries/run endpoint #14909

Merged
merged 4 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changes/11446-queries-run-when-forbidden
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixes to /fleet/queries/run endpoint:
- now returns 403 for an unauthorized user
- now returns 400 when query_ids or host_ids are not specified
7 changes: 6 additions & 1 deletion server/authz/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,13 @@ func (e *Forbidden) Internal() string {

// LogFields allows this error to be logged with subject, object, and action.
func (e *Forbidden) LogFields() []interface{} {
// Only logging User's email, and not other details such as Password and Salt.
email := "nil"
if e.subject != nil {
email = e.subject.Email
}
return []interface{}{
"subject", e.subject,
"subject", email,
lucasmrod marked this conversation as resolved.
Show resolved Hide resolved
"object", e.object,
"action", e.action,
}
Expand Down
2 changes: 1 addition & 1 deletion server/fleet/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ type Service interface {

GetCampaignReader(ctx context.Context, campaign *DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error)
CompleteCampaign(ctx context.Context, campaign *DistributedQueryCampaign) error
RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]QueryCampaignResult, int)
RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]QueryCampaignResult, int, error)

// /////////////////////////////////////////////////////////////////////////////
// AgentOptionsService
Expand Down
6 changes: 5 additions & 1 deletion server/service/http_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,11 @@ func setupAuthTest(t *testing.T) (fleet.Datastore, map[string]fleet.User, *httpt
}

func getTestAdminToken(t *testing.T, server *httptest.Server) string {
testUser := testUsers["admin1"]
return getTestUserToken(t, server, "admin1")
}

func getTestUserToken(t *testing.T, server *httptest.Server, testUserId string) string {
testUser := testUsers[testUserId]

params := loginRequest{
Email: testUser.Email,
Expand Down
5 changes: 2 additions & 3 deletions server/service/integration_enterprise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3186,12 +3186,11 @@ func (s *integrationEnterpriseTestSuite) TestGitOpsUserActions() {
require.Equal(t, "https://foobar.example.com", acr.AppConfig.WebhookSettings.VulnerabilitiesWebhook.DestinationURL)

// Attempt to run live queries synchronously, should fail.
// TODO(lucas): This is a bug, the synchronous live query API should return 403 but currently returns 200.
// It doesn't run the query but incorrectly returns a 200.
s.DoJSON("GET", "/api/latest/fleet/queries/run", runLiveQueryRequest{
HostIDs: []uint{h1.ID},
QueryIDs: []uint{q1.ID},
}, http.StatusOK, &runLiveQueryResponse{})
}, http.StatusForbidden, &runLiveQueryResponse{},
)

// Attempt to run live queries asynchronously (new unsaved query), should fail.
s.DoJSON("POST", "/api/latest/fleet/queries/run", createDistributedQueryCampaignRequest{
Expand Down
175 changes: 175 additions & 0 deletions server/service/integration_live_queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"testing"
"time"

"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/live_query/live_query_mock"
"github.com/fleetdm/fleet/v4/server/ptr"
Expand Down Expand Up @@ -346,6 +347,180 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
}
}

// TestLiveQueriesSomeFailToAuthorize when a user requests to run a mix of authorized and unauthorized queries
func (s *liveQueriesTestSuite) TestLiveQueriesSomeFailToAuthorize() {
t := s.T()

host := s.hosts[0]

// Unauthorized query
q1, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 1 from osquery;",
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
},
)
require.NoError(t, err)

// Authorized query
q2, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 2 from osquery;",
Description: "desc2",
Name: t.Name() + "query2",
Logging: fleet.LoggingSnapshot,
ObserverCanRun: true,
},
)
require.NoError(t, err)

s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 2 from osquery;"}, nil)
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{host.ID}).Return(nil)
s.lq.On("StopQuery", mock.Anything).Return(nil)

// Switch to observer user.
originalToken := s.token
s.token = getTestUserToken(t, s.server, "user2")
defer func() {
s.token = originalToken
}()

liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID, q2.ID},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}

wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()

// Give the above call a couple of seconds to create the campaign
time.Sleep(2 * time.Second)

cid2 := getCIDForQ(s, q2)

distributedReq := SubmitDistributedQueryResultsRequest{
NodeKey: *host.NodeKey,
Results: map[string][]map[string]string{
hostDistributedQueryPrefix + cid2: {{"col3": "c", "col4": "d"}, {"col3": "e", "col4": "f"}},
},
Statuses: map[string]fleet.OsqueryStatus{
hostDistributedQueryPrefix + cid2: 0,
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid2: "some other msg",
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)

wg.Wait()

require.Len(t, liveQueryResp.Results, 2)
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)

sort.Slice(
liveQueryResp.Results, func(i, j int) bool {
return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID
},
)

require.True(t, q1.ID < q2.ID)

assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
assert.Nil(t, liveQueryResp.Results[0].Results)
assert.Equal(t, authz.ForbiddenErrorMessage, *liveQueryResp.Results[0].Error)

assert.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID)
require.Len(t, liveQueryResp.Results[1].Results, 1)
q2Results := liveQueryResp.Results[1].Results[0]
require.Len(t, q2Results.Rows, 2)
assert.Equal(t, "c", q2Results.Rows[0]["col3"])
assert.Equal(t, "d", q2Results.Rows[0]["col4"])
assert.Equal(t, "e", q2Results.Rows[1]["col3"])
assert.Equal(t, "f", q2Results.Rows[1]["col4"])
}

// TestLiveQueriesInvalidInput without query/host IDs
func (s *liveQueriesTestSuite) TestLiveQueriesInvalidInputs() {
t := s.T()

host := s.hosts[0]

q1, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 1 from osquery;",
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
},
)
require.NoError(t, err)

liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)

liveQueryRequest = runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: []uint{},
}
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)

liveQueryRequest = runLiveQueryRequest{
QueryIDs: nil,
HostIDs: []uint{host.ID},
}
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)

liveQueryRequest = runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: nil,
}
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
}

// TestLiveQueriesFailsToAuthorize when an observer tries to run a live query
func (s *liveQueriesTestSuite) TestLiveQueriesFailsToAuthorize() {
t := s.T()

host := s.hosts[0]

q1, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 1 from osquery;",
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
},
)
require.NoError(t, err)

liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}

// Switch to observer user.
originalToken := s.token
s.token = getTestUserToken(t, s.server, "user2")
defer func() {
s.token = originalToken
}()
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusForbidden, &liveQueryResp)
}

func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() {
t := s.T()

Expand Down
32 changes: 29 additions & 3 deletions server/service/live_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"sync"
"time"

"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/fleet"
Expand Down Expand Up @@ -48,21 +50,45 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
logging.WithExtras(ctx, "live_query_rest_period_err", err)
}

// Only allow a host to be specified once in HostIDs
req.HostIDs = server.RemoveDuplicatesFromSlice(req.HostIDs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity: Is there a bug or just a sanity check to not cause unnecessary load?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small bug. User can specify the same host twice, but service will only return 1 result. So, TargetedHostCount(2) will never match RespondedHostCount(1).

res := runLiveQueryResponse{
Summary: summaryPayload{
TargetedHostCount: len(req.HostIDs),
RespondedHostCount: 0,
},
}

queryResults, respondedHostCount := svc.RunLiveQueryDeadline(ctx, req.QueryIDs, req.HostIDs, duration)
queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, req.QueryIDs, req.HostIDs, duration)
if err != nil {
return nil, err
}
// Check if all query results were forbidden due to lack of authorization.
allResultsForbidden := len(queryResults) > 0 && respondedHostCount == 0
if allResultsForbidden {
for _, r := range queryResults {
if r.Error == nil || *r.Error != authz.ForbiddenErrorMessage {
allResultsForbidden = false
break
}
}
}
if allResultsForbidden {
getvictor marked this conversation as resolved.
Show resolved Hide resolved
return nil, authz.ForbiddenWithInternal("All Live Query results were forbidden.", authz.UserFromContext(ctx), nil, nil)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's check with the product team what the expected behavior is when sending multiple query_ids and the user is not authorized to run some of them:

  • Should the request fail and no queries be executed?
  • Should the request not fail and only run the queries that the user is authorized to run?

API: https://fleetdm.com/docs/rest-api/rest-api#parameters97

@noahtalerman @rachaelshaw

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marko-lisica did we define a similar behavior for running MDM commands that we can borrow here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Current behavior for mix of authorized/unauthorized live queries is that user will get back an array of results. Good results will be valid, and unauthorized results will have "error":"forbidden"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Could we take the chance to document this behavior in the rest-api.md?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Can be done later on another PR.)

Copy link
Member

@rachaelshaw rachaelshaw Nov 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After chatting with Victor about it earlier, this behavior of mixed results makes sense to me, but I definitely agree we should document the behavior. @getvictor if you don't mind adding that to this PR, that'd be awesome. Or I'd be happy to take a stab at it after this is merged, just let me know

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@noahtalerman In the CLI we have error message for this use case - figma link. Regarding API, seems there's 403: forbidden error, but not sure when do we return this one.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rachaelshaw I added PR #14956 for rest-api.md updates.

res.Results = queryResults
res.Summary.RespondedHostCount = respondedHostCount

return res, nil
}

func (svc *Service) RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]fleet.QueryCampaignResult, int) {
func (svc *Service) RunLiveQueryDeadline(
ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration,
) ([]fleet.QueryCampaignResult, int, error) {
if len(queryIDs) == 0 || len(hostIDs) == 0 {
svc.authz.SkipAuthorization(ctx)
return nil, 0, ctxerr.Wrap(ctx, badRequest("query_ids and host_ids are required"))
}
wg := sync.WaitGroup{}

resultsCh := make(chan fleet.QueryCampaignResult)
Expand Down Expand Up @@ -132,7 +158,7 @@ func (svc *Service) RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, h
results = append(results, result)
}

return results, len(respondedHostIDs)
return results, len(respondedHostIDs), nil
}

func (svc *Service) GetCampaignReader(ctx context.Context, campaign *fleet.DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) {
Expand Down
15 changes: 15 additions & 0 deletions server/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,18 @@ func Base64DecodePaddingAgnostic(s string) ([]byte, error) {
us := strings.TrimRight(s, string(base64.StdPadding))
return base64.RawStdEncoding.DecodeString(us)
}

// RemoveDuplicatesFromSlice returns a slice with all the duplicates removed from the input slice.
func RemoveDuplicatesFromSlice[T comparable](slice []T) []T {
// We are using the allKeys map as a set here
allKeys := make(map[T]struct{})
var list []T

for _, i := range slice {
if _, exists := allKeys[i]; !exists {
allKeys[i] = struct{}{}
list = append(list, i)
}
}
return list
}
27 changes: 27 additions & 0 deletions server/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,30 @@ func TestBase64DecodePaddingAgnostic(t *testing.T) {
require.Equal(t, got, c.want)
}
}

func TestRemoveDuplicatesFromSlice(t *testing.T) {
tests := map[string]struct {
input []interface{}
output []interface{}
}{
"no duplicates": {
input: []interface{}{34, 56, 1},
output: []interface{}{34, 56, 1},
},
"1 duplicate": {
input: []interface{}{"a", "d", "a"},
output: []interface{}{"a", "d"},
},
"all duplicates": {
input: []interface{}{true, true, true},
output: []interface{}{true},
},
}
for name, test := range tests {
t.Run(
name, func(t *testing.T) {
require.Equal(t, test.output, RemoveDuplicatesFromSlice(test.input))
},
)
}
}