Skip to content

Commit

Permalink
Merge pull request #94 from fqutishat/update
Browse files Browse the repository at this point in the history
chore: add api token
  • Loading branch information
fqutishat authored Apr 5, 2022
2 parents 81fdd1a + 3fefbee commit f578b6b
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 12 deletions.
70 changes: 69 additions & 1 deletion cmd/vct/startcmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package startcmd

import (
"context"
"crypto/subtle"
"crypto/tls"
"encoding/json"
"errors"
Expand Down Expand Up @@ -186,8 +187,18 @@ const (

trillianDBConnFlagName = "trillian-db-conn"
trillianDBConnFlagUsage = "Trillian db conn" +
" Alternatively, this can be set with the following environment variable: " + contextProviderEnvKey
" Alternatively, this can be set with the following environment variable: " + trillianDBConnEnvKey
trillianDBConnEnvKey = envPrefix + "TRILLIAN_DB_CONN"

readTokenFlagName = "api-read-token"
readTokenFlagUsage = "Check for bearer token in the authorization header (optional). " +
" Alternatively, this can be set with the following environment variable: " + readTokenEnvKey
readTokenEnvKey = envPrefix + "API_READ_TOKEN"

writeTokenFlagName = "api-write-token"
writeTokenFlagUsage = "Check for bearer token in the authorization header (optional). " +
" Alternatively, this can be set with the following environment variable: " + writeTokenEnvKey
writeTokenEnvKey = envPrefix + "API_WRITE_TOKEN"
)

const (
Expand All @@ -204,6 +215,8 @@ const (
embeddedLogSignerHost = "0.0.0.0:8099"
defaultTimeout = "0"
defaultSyncTimeout = "3"
healthCheckEndpoint = "/healthcheck"
addVCEndpoint = "/add-vc"
)

type (
Expand Down Expand Up @@ -289,6 +302,8 @@ type agentParameters struct {
server server
devMode bool
kmsParams *kmsParameters
readToken string
writeToken string
}

type tlsParameters struct {
Expand Down Expand Up @@ -391,6 +406,9 @@ func createStartCMD(server server) *cobra.Command { //nolint: funlen,gocognit,go
return err
}

readToken := cmdutils.GetUserSetOptionalVarFromString(cmd, readTokenFlagName, readTokenEnvKey)
writeToken := cmdutils.GetUserSetOptionalVarFromString(cmd, writeTokenFlagName, writeTokenEnvKey)

if datasourceName == "" {
datasourceName = "mem://test"
}
Expand Down Expand Up @@ -515,6 +533,8 @@ func createStartCMD(server server) *cobra.Command { //nolint: funlen,gocognit,go
devMode: devMode,
contextProviderURLs: contextProviderURLs,
kmsParams: kmsParams,
readToken: readToken,
writeToken: writeToken,
}

return startAgent(parameters)
Expand Down Expand Up @@ -804,6 +824,10 @@ func startAgent(parameters *agentParameters) error { //nolint:funlen,gocyclo,cyc
}
}

if parameters.readToken != "" || parameters.writeToken != "" {
router.Use(authorizationMiddleware(parameters.readToken, parameters.writeToken))
}

go startMetrics(parameters, metricsRouter)

logger.Infof("Starting vct on host [%s]", parameters.host)
Expand Down Expand Up @@ -943,6 +967,8 @@ func createFlags(startCmd *cobra.Command) {
startCmd.Flags().String(kmsTypeFlagName, "", kmsTypeFlagUsage)
startCmd.Flags().String(kmsEndpointFlagName, "", kmsEndpointFlagUsage)
startCmd.Flags().String(logSignActiveKeyIDFlagName, "", logSignActiveKeyIDFlagUsage)
startCmd.Flags().String(readTokenFlagName, "", readTokenFlagUsage)
startCmd.Flags().String(writeTokenFlagName, "", writeTokenFlagUsage)
}

func getTLS(cmd *cobra.Command) (*tlsParameters, error) {
Expand Down Expand Up @@ -1089,6 +1115,48 @@ func createJSONLDDocumentLoader(ldStore *ldStoreProvider, httpClient *http.Clien
return loader, nil
}

func validateAuthorizationBearerToken(w http.ResponseWriter, r *http.Request, readToken, writeToken string) bool {
if r.RequestURI == healthCheckEndpoint {
return true
}

token := readToken

if strings.Contains(r.RequestURI, addVCEndpoint) {
if writeToken == "" {
return true
}

token = writeToken
}

if token != "" {
actHdr := r.Header.Get("Authorization")
expHdr := "Bearer " + token

if subtle.ConstantTimeCompare([]byte(actHdr), []byte(expHdr)) != 1 {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Unauthorised.\n")) // nolint:gosec,errcheck

return false
}
}

return true
}

func authorizationMiddleware(readToken, writeToken string) mux.MiddlewareFunc {
middleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if validateAuthorizationBearerToken(w, r, readToken, writeToken) {
next.ServeHTTP(w, r)
}
})
}

return middleware
}

// AWSMetricsProvider aws metrics provider.
type AWSMetricsProvider struct {
signCount monitoring.Counter
Expand Down
20 changes: 20 additions & 0 deletions cmd/vct/startcmd/start_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const (
tlsCACertsFlagName = "tls-cacerts"
timeoutFlagName = "timeout"
syncTimeoutFlagName = "sync-timeout"
readTokenFlagName = "api-read-token"
)

type mockServer struct{}
Expand Down Expand Up @@ -89,6 +90,7 @@ func TestCmd(t *testing.T) {
"--" + agentHostFlagName, "",
"--" + logsFlagName, "maple2021:rw@localhost:50051",
"--" + kmsTypeFlagName, "local",
"--" + readTokenFlagName, "tk1",
}
startCmd.SetArgs(args)
require.NoError(t, startCmd.Execute())
Expand Down Expand Up @@ -333,6 +335,24 @@ func TestCmd(t *testing.T) {
require.Contains(t, err.Error(), "unsupported kms type")
})

t.Run("kms type empty", func(t *testing.T) {
startCmd, err := startcmd.Cmd(&mockServer{})
require.NoError(t, err)

args := []string{
"--" + agentHostFlagName, ":98989",
"--" + logsFlagName, "11111:rw@https://vct.example.com",
"--" + datasourceNameFlagName, "mem://test",
"--" + tlsCACertsFlagName, "invalid",
}
startCmd.SetArgs(args)

err = startCmd.Execute()
require.Error(t, err)
require.Contains(t, err.Error(),
"Neither kms-type (command line flag) nor VCT_KMS_TYPE (environment variable) have been set.")
})

t.Run("failed to get region", func(t *testing.T) {
startCmd, err := startcmd.Cmd(&mockServer{})
require.NoError(t, err)
Expand Down
54 changes: 45 additions & 9 deletions pkg/client/vct/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ import (
)

type clientOptions struct {
http HTTPClient
http HTTPClient
authReadToken string
authWriteToken string
}

// ClientOpt represents client option func.
Expand All @@ -43,15 +45,31 @@ func WithHTTPClient(client HTTPClient) ClientOpt {
}
}

// WithAuthReadToken add auth token.
func WithAuthReadToken(authToken string) ClientOpt {
return func(o *clientOptions) {
o.authReadToken = authToken
}
}

// WithAuthWriteToken add auth token.
func WithAuthWriteToken(authToken string) ClientOpt {
return func(o *clientOptions) {
o.authWriteToken = authToken
}
}

// HTTPClient represents HTTP client.
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}

// Client represents VCT REST client.
type Client struct {
endpoint string
http HTTPClient
endpoint string
http HTTPClient
authReadToken string
authWriteToken string
}

// New returns VCT REST client.
Expand All @@ -65,15 +83,18 @@ func New(endpoint string, opts ...ClientOpt) *Client {
}

return &Client{
endpoint: endpoint,
http: op.http,
endpoint: endpoint,
http: op.http,
authReadToken: op.authReadToken,
authWriteToken: op.authWriteToken,
}
}

// AddVC adds verifiable credential to log.
func (c *Client) AddVC(ctx context.Context, credential []byte) (*command.AddVCResponse, error) {
var result *command.AddVCResponse
if err := c.do(ctx, rest.AddVCPath, &result, withMethod(http.MethodPost), withBody(credential)); err != nil {
if err := c.do(ctx, rest.AddVCPath, &result, withMethod(http.MethodPost), withBody(credential),
withToken(c.authWriteToken)); err != nil {
return nil, fmt.Errorf("add VC: %w", err)
}

Expand All @@ -83,7 +104,7 @@ func (c *Client) AddVC(ctx context.Context, credential []byte) (*command.AddVCRe
// Webfinger returns discovery info.
func (c *Client) Webfinger(ctx context.Context) (*command.WebFingerResponse, error) {
var result *command.WebFingerResponse
if err := c.do(ctx, rest.WebfingerPath, &result); err != nil {
if err := c.do(ctx, rest.WebfingerPath, &result, withToken(c.authReadToken)); err != nil {
return nil, fmt.Errorf("webfinger: %w", err)
}

Expand All @@ -93,7 +114,7 @@ func (c *Client) Webfinger(ctx context.Context) (*command.WebFingerResponse, err
// GetIssuers returns issuers.
func (c *Client) GetIssuers(ctx context.Context) ([]string, error) {
var result []string
if err := c.do(ctx, rest.GetIssuersPath, &result); err != nil {
if err := c.do(ctx, rest.GetIssuersPath, &result, withToken(c.authReadToken)); err != nil {
return nil, fmt.Errorf("get issuers: %w", err)
}

Expand All @@ -103,7 +124,7 @@ func (c *Client) GetIssuers(ctx context.Context) ([]string, error) {
// GetSTH retrieves latest signed tree head.
func (c *Client) GetSTH(ctx context.Context) (*command.GetSTHResponse, error) {
var result *command.GetSTHResponse
if err := c.do(ctx, rest.GetSTHPath, &result); err != nil {
if err := c.do(ctx, rest.GetSTHPath, &result, withToken(c.authReadToken)); err != nil {
return nil, fmt.Errorf("get STH: %w", err)
}

Expand All @@ -120,6 +141,7 @@ func (c *Client) GetSTHConsistency(ctx context.Context, first, second uint64) (*
opts := []opt{
withValueAdd(firstParamName, strconv.FormatUint(first, 10)),
withValueAdd(secondParamName, strconv.FormatUint(second, 10)),
withToken(c.authReadToken),
}

var result *command.GetSTHConsistencyResponse
Expand All @@ -140,6 +162,7 @@ func (c *Client) GetProofByHash(ctx context.Context, hash string, treeSize uint6
opts := []opt{
withValueAdd(hashParamName, hash),
withValueAdd(treeSizeParamName, strconv.FormatUint(treeSize, 10)),
withToken(c.authReadToken),
}

var result *command.GetProofByHashResponse
Expand All @@ -160,6 +183,7 @@ func (c *Client) GetEntries(ctx context.Context, start, end uint64) (*command.Ge
opts := []opt{
withValueAdd(startParamName, strconv.FormatUint(start, 10)),
withValueAdd(endParamName, strconv.FormatUint(end, 10)),
withToken(c.authReadToken),
}

var result *command.GetEntriesResponse
Expand All @@ -180,6 +204,7 @@ func (c *Client) GetEntryAndProof(ctx context.Context, leafIndex, treeSize uint6
opts := []opt{
withValueAdd(leafIndexParamName, strconv.FormatUint(leafIndex, 10)),
withValueAdd(treeSizeParamName, strconv.FormatUint(treeSize, 10)),
withToken(c.authReadToken),
}

var result *command.GetEntryAndProofResponse
Expand Down Expand Up @@ -235,6 +260,7 @@ type options struct {
method string
body io.Reader
values url.Values
token string
}

type opt func(*options)
Expand All @@ -257,6 +283,12 @@ func withMethod(val string) opt {
}
}

func withToken(val string) opt {
return func(o *options) {
o.token = val
}
}

func (c *Client) do(ctx context.Context, path string, v interface{}, opts ...opt) error {
op := &options{method: http.MethodGet, values: url.Values{}}
for _, fn := range opts {
Expand All @@ -270,6 +302,10 @@ func (c *Client) do(ctx context.Context, path string, v interface{}, opts ...opt
return fmt.Errorf("new request with context: %w", err)
}

if op.token != "" {
req.Header.Add("Authorization", "Bearer "+op.token)
}

resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("http do: %w", err)
Expand Down
3 changes: 2 additions & 1 deletion pkg/client/vct/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ func TestClient_AddVC(t *testing.T) {
StatusCode: http.StatusOK,
}, nil)

client := vct.New(endpoint, vct.WithHTTPClient(httpClient))
client := vct.New(endpoint, vct.WithHTTPClient(httpClient), vct.WithAuthReadToken("tk1"),
vct.WithAuthWriteToken("tk2"))
resp, err := client.AddVC(context.Background(), expectedCredential)
require.NoError(t, err)

Expand Down
4 changes: 4 additions & 0 deletions test/bdd/fixtures/vct/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ services:
- AWS_ACCESS_KEY_ID=mock
- AWS_SECRET_ACCESS_KEY=mock
- VCT_TIMEOUT=60
- VCT_API_READ_TOKEN=tk1
- VCT_API_WRITE_TOKEN=tk2
- VCT_DSN=postgres://postgres:password@vct.postgres:5432
- VCT_DATABASE_PREFIX=vctdb_
- VCT_ISSUERS=maple2021@did:key:zUC724vuGvHpnCGFG1qqpXb81SiBLu3KLSqVzenwEZNPoY35i2Bscb8DLaVwHvRFs6F2NkNNXRcPWvqnPDUd9ukdjLkjZd3u9zzL4wDZDUpkPAatLDGLEYVo8kkAzuAKJQMr7N2,maple2020@did:key:zUC724vuGvHpnCGFG1qqpXb81SiBLu3KLSqVzenwEZNPoY35i2Bscb8DLaVwHvRFs6F2NkNNXRcPWvqnPDUd9ukdjLkjZd3u9zzL4wDZDUpkPAatLDGLEYVo8kkAzuAKJQMr7N7
Expand Down Expand Up @@ -137,6 +139,8 @@ services:
- AWS_ACCESS_KEY_ID=mock
- AWS_SECRET_ACCESS_KEY=mock
- VCT_TIMEOUT=60
- VCT_API_READ_TOKEN=tk1
- VCT_API_WRITE_TOKEN=tk2
- VCT_DSN=postgres://postgres:password@vct.postgres:5432
- VCT_DATABASE_PREFIX=vctdb_
- VCT_ISSUERS=maple2021@did:key:zUC724vuGvHpnCGFG1qqpXb81SiBLu3KLSqVzenwEZNPoY35i2Bscb8DLaVwHvRFs6F2NkNNXRcPWvqnPDUd9ukdjLkjZd3u9zzL4wDZDUpkPAatLDGLEYVo8kkAzuAKJQMr7N2,maple2020@did:key:zUC724vuGvHpnCGFG1qqpXb81SiBLu3KLSqVzenwEZNPoY35i2Bscb8DLaVwHvRFs6F2NkNNXRcPWvqnPDUd9ukdjLkjZd3u9zzL4wDZDUpkPAatLDGLEYVo8kkAzuAKJQMr7N7
Expand Down
3 changes: 2 additions & 1 deletion test/bdd/pkg/controller/rest/rest_controller_steps.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ func (s *Steps) issuerIsNotSupported(issuer string) error {
}

func (s *Steps) setVCTClient(endpoint string) error {
s.vct = vct.New(endpoint, vct.WithHTTPClient(s.client))
s.vct = vct.New(endpoint, vct.WithHTTPClient(s.client), vct.WithAuthReadToken("tk1"),
vct.WithAuthWriteToken("tk2"))

return backoff.Retry(func() error { // nolint: wrapcheck
resp, err := s.vct.GetSTH(context.Background())
Expand Down

0 comments on commit f578b6b

Please sign in to comment.