Skip to content

Commit

Permalink
feat: add api key for authz
Browse files Browse the repository at this point in the history
  • Loading branch information
shrimalmadhur committed Jan 16, 2025
1 parent 6d0c13b commit 6dbd9ba
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 11 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,8 @@ docker: ## runs docker build
migrate: ## runs database migrations
go install -tags 'postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@latest
migrate -path internal/database/migrations/ -database "postgres://user:password@localhost:5432/cerberus?sslmode=disable" --verbose up

.PHONY: create-migration
create-migration: ## creates a new database migration
go install -tags 'postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@latest
migrate create -dir internal/database/migrations/ -ext sql $(name)
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ go 1.22.0

toolchain go1.22.3

replace github.com/Layr-Labs/cerberus-api => ../cerberus-api

require (
cloud.google.com/go/secretmanager v1.14.2
github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87
Expand All @@ -14,6 +16,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.6
github.com/consensys/gnark-crypto v0.12.1
github.com/golang-migrate/migrate/v4 v4.18.1
github.com/google/uuid v1.6.0
github.com/prometheus/client_golang v1.20.3
github.com/stretchr/testify v1.10.0
github.com/testcontainers/testcontainers-go v0.34.0
Expand All @@ -37,7 +40,6 @@ require (
github.com/docker/go-units v0.5.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg6
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87 h1:EkaBNT0o8RTgtFeYSKaoNHNbnCVxrcsAyRpUeN29hiQ=
github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87/go.mod h1:7J8hptSX8cFq7KmVb+rEO5aEifj7E44c3i0afIyr4WA=
github.com/Layr-Labs/cerberus-api v0.0.2-0.20250108174619-d5e1eb03fbd5 h1:s24M6HYObEuV9OSY36jUM09kp5fOhuz/g1ev2qWDPzU=
github.com/Layr-Labs/cerberus-api v0.0.2-0.20250108174619-d5e1eb03fbd5/go.mod h1:Lm4fhzy0S3P7GjerzuseGaBFVczsIKmEhIjcT52Hluo=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/aws/aws-sdk-go-v2 v1.32.5 h1:U8vdWJuY7ruAkzaOdD7guwJjD06YSKmnKCJs7s3IkIo=
Expand Down
11 changes: 11 additions & 0 deletions internal/common/common.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package common

import (
"crypto/sha256"
"encoding/hex"
)

func Trim0x(s string) string {
if len(s) >= 2 && s[0:2] == "0x" {
return s[2:]
Expand All @@ -13,3 +18,9 @@ func RemovePrefix(s string, prefix string) string {
}
return s
}

func CreateSHA256Hash(s string) string {
hash := sha256.New()
hash.Write([]byte(s))
return hex.EncodeToString(hash.Sum(nil))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE public.keys_metadata ADD COLUMN api_key_hash text;
ALTER TABLE public.keys_metadata ADD COLUMN locked boolean DEFAULT false;
2 changes: 2 additions & 0 deletions internal/database/model/key_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ type KeyMetadata struct {
PublicKeyG2 string `db:"public_key_g2"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
ApiKeyHash string `db:"api_key_hash"`
Locked bool `db:"locked"`
}
1 change: 1 addition & 0 deletions internal/database/repository/key_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type KeyMetadataRepository interface {
Create(ctx context.Context, metadata *model.KeyMetadata) error
Get(ctx context.Context, publicKeyG1 string) (*model.KeyMetadata, error)
Update(ctx context.Context, metadata *model.KeyMetadata) error
UpdateAPIKeyHash(ctx context.Context, metadata *model.KeyMetadata) error
Delete(ctx context.Context, publicKeyG1 string) error
List(ctx context.Context) ([]*model.KeyMetadata, error)
}
33 changes: 30 additions & 3 deletions internal/database/repository/postgres/key_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ func NewKeyMetadataRepository(db *sql.DB) repository.KeyMetadataRepository {
const (
createKeyMetadataQuery = `
INSERT INTO public.keys_metadata (
public_key_g1, public_key_g2, created_at, updated_at
) VALUES ($1, $2, $3, $4)
public_key_g1, public_key_g2, created_at, updated_at, api_key_hash
) VALUES ($1, $2, $3, $4, $5)
`

getKeyMetadataQuery = `
SELECT public_key_g1, public_key_g2, created_at, updated_at
SELECT public_key_g1, public_key_g2, created_at, updated_at, api_key_hash, locked
FROM public.keys_metadata
WHERE public_key_g1 = $1
`
Expand All @@ -39,6 +39,12 @@ const (
WHERE public_key_g1 = $2
`

updateAPIKeyHashQuery = `
UPDATE public.keys_metadata
SET api_key_hash = $1, updated_at = $2
WHERE public_key_g1 = $3
`

deleteKeyMetadataQuery = `
DELETE FROM public.keys_metadata
WHERE public_key_g1 = $1
Expand Down Expand Up @@ -68,6 +74,7 @@ func (r *keyMetadataRepo) Create(ctx context.Context, metadata *model.KeyMetadat
metadata.PublicKeyG2,
metadata.CreatedAt,
metadata.UpdatedAt,
metadata.ApiKeyHash,
)
return err
}
Expand All @@ -79,6 +86,8 @@ func (r *keyMetadataRepo) Get(ctx context.Context, publicKeyG1 string) (*model.K
&metadata.PublicKeyG2,
&metadata.CreatedAt,
&metadata.UpdatedAt,
&metadata.ApiKeyHash,
&metadata.Locked,
)
if err == sql.ErrNoRows {
return nil, errors.New("key metadata not found")
Expand Down Expand Up @@ -113,6 +122,24 @@ func (r *keyMetadataRepo) Update(ctx context.Context, metadata *model.KeyMetadat
return nil
}

func (r *keyMetadataRepo) UpdateAPIKeyHash(ctx context.Context, metadata *model.KeyMetadata) error {
if metadata.PublicKeyG1 == "" {
return errors.New("public key g1 is required")
}
if metadata.ApiKeyHash == "" {
return errors.New("api key hash is required")
}

metadata.UpdatedAt = time.Now().UTC()

_, err := r.db.ExecContext(ctx, updateAPIKeyHashQuery,
metadata.ApiKeyHash,
metadata.UpdatedAt,
metadata.PublicKeyG1,
)
return err
}

func (r *keyMetadataRepo) Delete(ctx context.Context, publicKeyG1 string) error {
if publicKeyG1 == "" {
return errors.New("public key g1 is required")
Expand Down
28 changes: 28 additions & 0 deletions internal/database/repository/postgres/key_metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,31 @@ func TestKeyMetadataRepository_List(t *testing.T) {
assert.Equal(t, "test_key_1", results[1].PublicKeyG1)
})
}

func TestKeyMetadataRepository_UpdateAPIKeyHash(t *testing.T) {
testDB := SetupTestDB(t)
// No need to defer db.Close() as it's handled by t.Cleanup

// Create initial test data
initialKey := &model.KeyMetadata{
PublicKeyG1: "test_key_1",
PublicKeyG2: "test_key_2",
ApiKeyHash: "test_api_key_hash",
}
err := testDB.Repo.Create(context.Background(), initialKey)
require.NoError(t, err)

metadata := &model.KeyMetadata{
PublicKeyG1: "test_key_1",
ApiKeyHash: "test_api_key_hash_2",
}

err = testDB.Repo.UpdateAPIKeyHash(context.Background(), metadata)
require.NoError(t, err)

// Verify the update
result, err := testDB.Repo.Get(context.Background(), metadata.PublicKeyG1)
assert.NoError(t, err)
assert.Equal(t, metadata.ApiKeyHash, result.ApiKeyHash)
assert.WithinDuration(t, time.Now(), result.UpdatedAt, 2*time.Second)
}
4 changes: 3 additions & 1 deletion internal/database/repository/postgres/test_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ func CreateTestContainer(t *testing.T) (*TestContainer, error) {
public_key_g1 VARCHAR(255) PRIMARY KEY,
public_key_g2 VARCHAR(255) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
api_key_hash text,
locked boolean DEFAULT false
);
`)
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions internal/database/sql/cerberus.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ CREATE TABLE IF NOT EXISTS public.keys_metadata (
public_key_g1 VARCHAR(255) PRIMARY KEY,
public_key_g2 VARCHAR(255) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
);
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
api_key_hash text,
locked boolean DEFAULT false
);
86 changes: 86 additions & 0 deletions internal/middleware/auth_interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package middleware

import (
"context"
"errors"
"strings"

v1 "github.com/Layr-Labs/cerberus-api/pkg/api/v1"
"github.com/Layr-Labs/cerberus/internal/common"
"github.com/Layr-Labs/cerberus/internal/database/repository"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// AuthInterceptor creates a selective authentication interceptor
func AuthInterceptor(
protectedServiceName string,
keyMetadataRepo repository.KeyMetadataRepository,
) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
// Check if the current service should be protected
if !strings.HasPrefix(info.FullMethod, "/"+protectedServiceName) {
// Skip auth for non-protected services
return handler(ctx, req)
}

// Get metadata from context
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.Unauthenticated, "missing metadata")
}

// Get authorization token
authHeader := md.Get("authorization")
if len(authHeader) == 0 {
return nil, status.Error(codes.Unauthenticated, "missing authorization header")
}

// Validate the token (implement your own validation logic)
valid, err := validateToken(ctx, authHeader[0], req, keyMetadataRepo)
if err != nil {
return nil, status.Error(codes.Unauthenticated, err.Error())
}

if !valid {
return nil, status.Error(codes.Unauthenticated, "invalid token")
}

// If authentication successful, proceed with the handler
return handler(ctx, req)
}
}

// Example token validation function - replace with your own implementation
func validateToken(
ctx context.Context,
token string,
req interface{},
keyMetadataRepo repository.KeyMetadataRepository,
) (bool, error) {
var pubKeyG1 string
switch r := req.(type) {
case *v1.SignGenericRequest:
pubKeyG1 = r.GetPublicKeyG1()
case *v1.SignG1Request:
pubKeyG1 = r.GetPublicKeyG1()
default:
return false, errors.New("invalid request type")
}

keyMetadata, err := keyMetadataRepo.Get(ctx, pubKeyG1)
if err != nil {
return false, err
}

requestAPIKeyHash := common.CreateSHA256Hash(token)

if keyMetadata.ApiKeyHash != requestAPIKeyHash {
return false, errors.New("invalid token")
}

return true, nil
}
6 changes: 5 additions & 1 deletion internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ func Start(config *configuration.Configuration, logger *slog.Logger) {

// Register metrics middleware
metricsMiddleware := middleware.NewMetricsMiddleware(registry, rpcMetrics)
opts = append(opts, grpc.UnaryInterceptor(metricsMiddleware.UnaryServerInterceptor()))
authInterceptor := middleware.AuthInterceptor("signer.v1.Signer", keyMetadataRepo)
opts = append(
opts,
grpc.ChainUnaryInterceptor(metricsMiddleware.UnaryServerInterceptor(), authInterceptor),
)

s := grpc.NewServer(opts...)
kmsService := kms.NewService(config, keystore, keyMetadataRepo, logger, rpcMetrics)
Expand Down
34 changes: 33 additions & 1 deletion internal/services/kms/kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"math/big"

v1 "github.com/Layr-Labs/cerberus-api/pkg/api/v1"
"github.com/google/uuid"

"github.com/Layr-Labs/cerberus/internal/common"
"github.com/Layr-Labs/cerberus/internal/configuration"
Expand Down Expand Up @@ -75,9 +76,17 @@ func (k *Service) GenerateKeyPair(
return nil, status.Error(codes.Internal, err.Error())
}

// Generate a new API key and hash
apiKey, apiKeyHash, err := generateNewAPIKeyAndHash()
if err != nil {
k.logger.Error(fmt.Sprintf("Failed to generate API key: %v", err))
return nil, status.Error(codes.Internal, err.Error())
}

err = k.keyMetadataRepo.Create(ctx, &model.KeyMetadata{
PublicKeyG1: pubKeyHex,
PublicKeyG2: g2PubKey,
ApiKeyHash: apiKeyHash,
})
if err != nil {
k.logger.Error(fmt.Sprintf("Failed to save key metadata: %v", err))
Expand All @@ -94,6 +103,7 @@ func (k *Service) GenerateKeyPair(
PublicKeyG2: g2PubKey,
PrivateKey: privKeyHex,
Mnemonic: keyPair.Mnemonic,
ApiKey: apiKey,
}, nil
}

Expand Down Expand Up @@ -149,16 +159,28 @@ func (k *Service) ImportKey(
return nil, status.Error(codes.Internal, err.Error())
}

// Generate a new API key and hash
apiKey, apiKeyHash, err := generateNewAPIKeyAndHash()
if err != nil {
k.logger.Error(fmt.Sprintf("Failed to generate API key: %v", err))
return nil, status.Error(codes.Internal, err.Error())
}

err = k.keyMetadataRepo.Create(ctx, &model.KeyMetadata{
PublicKeyG1: pubKeyHex,
PublicKeyG2: g2PubKey,
ApiKeyHash: apiKeyHash,
})
if err != nil {
k.logger.Error(fmt.Sprintf("Failed to save key metadata: %v", err))
return nil, status.Error(codes.Internal, err.Error())
}

return &v1.ImportKeyResponse{PublicKeyG1: pubKeyHex, PublicKeyG2: g2PubKey}, nil
return &v1.ImportKeyResponse{
PublicKeyG1: pubKeyHex,
PublicKeyG2: g2PubKey,
ApiKey: apiKey,
}, nil
}

func (k *Service) ListKeys(
Expand Down Expand Up @@ -197,3 +219,13 @@ func (k *Service) GetKeyMetadata(
UpdatedAt: metadata.UpdatedAt.Unix(),
}, nil
}

func generateNewAPIKeyAndHash() (string, string, error) {
newUUID, err := uuid.NewV7()
if err != nil {
return "", "", err
}
apiKey := newUUID.String()
apiKeyHash := common.CreateSHA256Hash(apiKey)
return apiKey, apiKeyHash, nil
}

0 comments on commit 6dbd9ba

Please sign in to comment.