Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix readyz endpoint not returning correct status when using database dynamic registration #11160

Merged
merged 5 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions integration/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package integration

import (
"context"
"fmt"
"net"
"net/http"
"testing"
"time"

Expand All @@ -31,6 +33,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
Expand Down Expand Up @@ -588,6 +591,49 @@ func TestDatabaseAccessMongoSeparateListener(t *testing.T) {
require.NoError(t, err)
}

func TestDatabaseAgentState(t *testing.T) {
tests := map[string]struct {
agentParams databaseAgentStartParams
}{
"WithStaticDatabases": {
agentParams: databaseAgentStartParams{
databases: []service.Database{
{Name: "mysql", Protocol: defaults.ProtocolMySQL, URI: "localhost:3306"},
{Name: "pg", Protocol: defaults.ProtocolPostgres, URI: "localhost:5432"},
},
},
},
"WithResourceMatchers": {
agentParams: databaseAgentStartParams{
resourceMatchers: []services.ResourceMatcher{
{Labels: types.Labels{"*": []string{"*"}}},
},
},
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
pack := setupDatabaseTest(t)

// Start also ensures that the database agent has the “ready” state.
// If the agent can’t make it, this function will fail the test.
agent, _ := pack.startRootDatabaseAgent(t, test.agentParams)

// In addition to the checks performed during the agent start,
// we’ll request the diagnostic server to ensure the readyz route
// is returning to the proper state.
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%v/readyz", agent.Config.DiagnosticAddr.Addr), nil)
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

require.Equal(t, http.StatusOK, resp.StatusCode)
})
}
}

func waitForAuditEventTypeWithBackoff(t *testing.T, cli *auth.Server, startTime time.Time, eventType string) []apievents.AuditEvent {
max := time.Second
timeout := time.After(max)
Expand Down Expand Up @@ -1015,6 +1061,39 @@ func (p *databasePack) waitForLeaf(t *testing.T) {
}
}

// databaseAgentStartParams parameters used to configure a database agent.
type databaseAgentStartParams struct {
databases []service.Database
resourceMatchers []services.ResourceMatcher
}

// startRootDatabaseAgent starts a database agent with the provided
// configuration on the root cluster.
func (p *databasePack) startRootDatabaseAgent(t *testing.T, params databaseAgentStartParams) (*service.TeleportProcess, *auth.Client) {
conf := service.MakeDefaultConfig()
conf.DataDir = t.TempDir()
conf.Token = "static-token-value"
conf.DiagnosticAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("localhost", ports.Pop())}
conf.AuthServers = []utils.NetAddr{
{
AddrNetwork: "tcp",
Addr: net.JoinHostPort(Loopback, p.root.cluster.GetPortWeb()),
},
}
conf.Clock = p.clock
conf.Databases.Enabled = true
conf.Databases.Databases = params.databases
conf.Databases.ResourceMatchers = params.resourceMatchers

server, authClient, err := p.root.cluster.StartDatabase(conf)
require.NoError(t, err)
t.Cleanup(func() {
server.Close()
})

return server, authClient
}

func containsDB(servers []types.DatabaseServer, name string) bool {
for _, server := range servers {
if server.GetDatabase().GetName() == name {
Expand Down
3 changes: 3 additions & 0 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1700,6 +1700,8 @@ type agentParams struct {
NoStart bool
// GCPSQL defines the GCP Cloud SQL mock to use for GCP API calls.
GCPSQL *cloud.GCPSQLAdminClientMock
// OnHeartbeat defines a heartbeat function that generates heartbeat events.
OnHeartbeat func(error)
}

func (p *agentParams) setDefaults(c *testContext) {
Expand Down Expand Up @@ -1765,6 +1767,7 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p a
Limiter: connLimiter,
Auth: testAuth,
Databases: p.Databases,
OnHeartbeat: p.OnHeartbeat,
ResourceMatchers: p.ResourceMatchers,
GetServerInfoFn: p.GetServerInfoFn,
GetRotation: func(types.SystemRole) (*types.Rotation, error) {
Expand Down
6 changes: 6 additions & 0 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,12 @@ func (s *Server) Start(ctx context.Context) (err error) {
return trace.Wrap(err)
}

// If the agent doesn’t have any static databases configured, send a
// heartbeat without error to make the component “ready”.
if len(s.cfg.Databases) == 0 && s.cfg.OnHeartbeat != nil {
s.cfg.OnHeartbeat(nil)
}

return nil
}

Expand Down
67 changes: 67 additions & 0 deletions lib/srv/db/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ package db

import (
"context"
"sync/atomic"
"testing"
"time"

apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/limiter"

"github.com/jackc/pgconn"
Expand Down Expand Up @@ -177,3 +180,67 @@ func TestDatabaseServerLimiting(t *testing.T) {
require.FailNow(t, "we should exceed the connection limit by now")
})
}

func TestHeartbeatEvents(t *testing.T) {
ctx := context.Background()

dbOne, err := types.NewDatabaseV3(types.Metadata{
Name: "dbOne",
}, types.DatabaseSpecV3{
Protocol: defaults.ProtocolPostgres,
URI: "localhost:5432",
})
require.NoError(t, err)

dbTwo, err := types.NewDatabaseV3(types.Metadata{
Name: "dbOne",
}, types.DatabaseSpecV3{
Protocol: defaults.ProtocolMySQL,
URI: "localhost:3306",
})
require.NoError(t, err)

tests := map[string]struct {
staticDatabases types.Databases
heartbeatCount int64
}{
"SingleStaticDatabase": {
staticDatabases: types.Databases{dbOne},
heartbeatCount: 1,
},
"MultipleStaticDatabases": {
staticDatabases: types.Databases{dbOne, dbTwo},
heartbeatCount: 2,
},
"EmptyStaticDatabases": {
staticDatabases: types.Databases{},
heartbeatCount: 1,
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
var heartbeatEvents int64
heartbeatRecorder := func(err error) {
require.NoError(t, err)
atomic.AddInt64(&heartbeatEvents, 1)
}

testCtx := setupTestContext(ctx, t)
server := testCtx.setupDatabaseServer(ctx, t, agentParams{
NoStart: true,
OnHeartbeat: heartbeatRecorder,
Databases: test.staticDatabases,
})
require.NoError(t, server.Start(ctx))
t.Cleanup(func() {
server.Close()
})

require.NotNil(t, server)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&heartbeatEvents) == test.heartbeatCount
}, 2*time.Second, 500*time.Millisecond)
})
}
}