Skip to content

Commit

Permalink
chore: refactor performRefresh to use a cfg struct (#23)
Browse files Browse the repository at this point in the history
* chore: refactor performRefresh to use a cfg struct
* chore: refactor performRefresh into refresher object
  • Loading branch information
kurtisvg authored May 14, 2021
1 parent 65073d0 commit 18c58a1
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 77 deletions.
89 changes: 12 additions & 77 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,8 @@ func (r *refreshResult) IsValid() bool {
// before the previous certificate expires (every 55 minutes).
type Instance struct {
connName

clientLimiter *rate.Limiter
client *sqladmin.Service
key *rsa.PrivateKey
refreshTimeout time.Duration
key *rsa.PrivateKey
r refresher

resultGuard sync.RWMutex
// cur represents the current refreshResult that will be used to create connections. If a valid complete
Expand All @@ -140,11 +137,13 @@ func NewInstance(instance string, client *sqladmin.Service, key *rsa.PrivateKey,
return nil, err
}
i := &Instance{
connName: cn,
clientLimiter: rate.NewLimiter(rate.Every(30*time.Second), 2),
client: client,
key: key,
refreshTimeout: refreshTimeout,
connName: cn,
key: key,
r: refresher{
timeout: refreshTimeout,
clientLimiter: rate.NewLimiter(rate.Every(30*time.Second), 2),
client: client,
},
}
// For the initial refresh operation, set cur = next so that connection requests block
// until the first refresh is complete.
Expand Down Expand Up @@ -185,11 +184,10 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshResult {
res := &refreshResult{}
res.ready = make(chan struct{})
res.timer = time.AfterFunc(d, func() {
ctx, cancel := context.WithTimeout(context.Background(), i.refreshTimeout)
res.md, res.tlsCfg, res.expiry, res.err = performRefresh(ctx, i.client, i.clientLimiter, i.connName, i.key)
cancel()

ctx := context.Background() // TODO: store this in Instance
res.md, res.tlsCfg, res.expiry, res.err = i.r.performRefresh(ctx, i.connName, i.key)
close(res.ready)

// Once the refresh is complete, update "current" with working result and schedule a new refresh
i.resultGuard.Lock()
defer i.resultGuard.Unlock()
Expand All @@ -212,66 +210,3 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshResult {
})
return res
}

// performRefresh immediately performs a full refresh operation using the Cloud SQL Admin API.
func performRefresh(ctx context.Context, client *sqladmin.Service, l *rate.Limiter, cn connName, k *rsa.PrivateKey) (metadata, *tls.Config, time.Time, error) {
// avoid refreshing too often to try not to tax the SQL Admin API quotas
err := l.Wait(ctx)
if err != nil {
return metadata{}, nil, time.Time{}, fmt.Errorf("refresh was throttled until context expired: %w", err)
}

// start async fetching the instance's metadata
type mdRes struct {
md metadata
err error
}
mdC := make(chan mdRes, 1)
go func() {
defer close(mdC)
md, err := fetchMetadata(ctx, client, cn)
mdC <- mdRes{md, err}
}()

// start async fetching the certs
type ecRes struct {
ec tls.Certificate
err error
}
ecC := make(chan ecRes, 1)
go func() {
defer close(ecC)
ec, err := fetchEphemeralCert(ctx, client, cn, k)
ecC <- ecRes{ec, err}
}()

// wait for the results of each operations
var md metadata
select {
case r := <-mdC:
if r.err != nil {
return md, nil, time.Time{}, fmt.Errorf("fetch metadata failed: %w", r.err)
}
md = r.md
case <-ctx.Done():
return md, nil, time.Time{}, fmt.Errorf("refresh failed: %w", ctx.Err())
}
var ec tls.Certificate
select {
case r := <-ecC:
if r.err != nil {
return md, nil, time.Time{}, fmt.Errorf("fetch ephemeral cert failed: %w", r.err)
}
ec = r.ec
case <-ctx.Done():
return md, nil, time.Time{}, fmt.Errorf("refresh failed: %w", ctx.Err())
}

c := createTLSConfig(cn, md, ec)
// This should never not be the case, but we check to avoid a potential nil-pointer
expiry := time.Time{}
if len(c.Certificates) > 0 {
expiry = c.Certificates[0].Leaf.NotAfter
}
return md, c, expiry, nil
}
79 changes: 79 additions & 0 deletions internal/cloudsql/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import (
"encoding/pem"
"errors"
"fmt"
"time"

"golang.org/x/time/rate"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)

Expand Down Expand Up @@ -170,3 +172,80 @@ func genVerifyPeerCertificateFunc(cn connName, pool *x509.CertPool) func(rawCert
return nil
}
}

type refresher struct {
// timeout is the maximum amount of time a refresh operation should be allowed to take.
timeout time.Duration

clientLimiter *rate.Limiter
client *sqladmin.Service
}

// performRefresh immediately performs a full refresh operation using the Cloud SQL Admin API.
func (r refresher) performRefresh(ctx context.Context, cn connName, k *rsa.PrivateKey) (metadata, *tls.Config, time.Time, error) {
ctx, cancel := context.WithTimeout(ctx, r.timeout)
defer cancel()
if ctx.Err() == context.Canceled {
return metadata{}, nil, time.Time{}, ctx.Err()
}

// avoid refreshing too often to try not to tax the SQL Admin API quotas
err := r.clientLimiter.Wait(ctx)
if err != nil {
return metadata{}, nil, time.Time{}, fmt.Errorf("refresh was throttled until context expired: %w", err)
}

// start async fetching the instance's metadata
type mdRes struct {
md metadata
err error
}
mdC := make(chan mdRes, 1)
go func() {
defer close(mdC)
md, err := fetchMetadata(ctx, r.client, cn)
mdC <- mdRes{md, err}
}()

// start async fetching the certs
type ecRes struct {
ec tls.Certificate
err error
}
ecC := make(chan ecRes, 1)
go func() {
defer close(ecC)
ec, err := fetchEphemeralCert(ctx, r.client, cn, k)
ecC <- ecRes{ec, err}
}()

// wait for the results of each operations
var md metadata
select {
case r := <-mdC:
if r.err != nil {
return md, nil, time.Time{}, fmt.Errorf("fetch metadata failed: %w", r.err)
}
md = r.md
case <-ctx.Done():
return md, nil, time.Time{}, fmt.Errorf("refresh failed: %w", ctx.Err())
}
var ec tls.Certificate
select {
case r := <-ecC:
if r.err != nil {
return md, nil, time.Time{}, fmt.Errorf("fetch ephemeral cert failed: %w", r.err)
}
ec = r.ec
case <-ctx.Done():
return md, nil, time.Time{}, fmt.Errorf("refresh failed: %w", ctx.Err())
}

c := createTLSConfig(cn, md, ec)
// This should never not be the case, but we check to avoid a potential nil-pointer
expiry := time.Time{}
if len(c.Certificates) > 0 {
expiry = c.Certificates[0].Leaf.NotAfter
}
return md, c, expiry, nil
}

0 comments on commit 18c58a1

Please sign in to comment.