diff --git a/Makefile b/Makefile index 4b25303..2959a39 100644 --- a/Makefile +++ b/Makefile @@ -41,3 +41,8 @@ docker: ## runs docker build migrate: ## runs database migrations go install -tags 'postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@latest migrate -path internal/database/migrations/ -database "postgres://user:password@localhost:5432/cerberus?sslmode=disable" --verbose up + +.PHONY: create-migration +create-migration: ## creates a new database migration + go install -tags 'postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@latest + migrate create -dir internal/database/migrations/ -ext sql $(name) diff --git a/go.mod b/go.mod index 87527af..2157461 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.22.0 toolchain go1.22.3 +replace github.com/Layr-Labs/cerberus-api => ../cerberus-api + require ( cloud.google.com/go/secretmanager v1.14.2 github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87 @@ -14,6 +16,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.6 github.com/consensys/gnark-crypto v0.12.1 github.com/golang-migrate/migrate/v4 v4.18.1 + github.com/google/uuid v1.6.0 github.com/prometheus/client_golang v1.20.3 github.com/stretchr/testify v1.10.0 github.com/testcontainers/testcontainers-go v0.34.0 @@ -37,7 +40,6 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect diff --git a/go.sum b/go.sum index 33356a2..d7e55bc 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,6 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg6 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87 h1:EkaBNT0o8RTgtFeYSKaoNHNbnCVxrcsAyRpUeN29hiQ= github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87/go.mod h1:7J8hptSX8cFq7KmVb+rEO5aEifj7E44c3i0afIyr4WA= -github.com/Layr-Labs/cerberus-api v0.0.2-0.20250108174619-d5e1eb03fbd5 h1:s24M6HYObEuV9OSY36jUM09kp5fOhuz/g1ev2qWDPzU= -github.com/Layr-Labs/cerberus-api v0.0.2-0.20250108174619-d5e1eb03fbd5/go.mod h1:Lm4fhzy0S3P7GjerzuseGaBFVczsIKmEhIjcT52Hluo= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/aws/aws-sdk-go-v2 v1.32.5 h1:U8vdWJuY7ruAkzaOdD7guwJjD06YSKmnKCJs7s3IkIo= diff --git a/internal/common/common.go b/internal/common/common.go index 57e3957..a692d64 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -1,5 +1,10 @@ package common +import ( + "crypto/sha256" + "encoding/hex" +) + func Trim0x(s string) string { if len(s) >= 2 && s[0:2] == "0x" { return s[2:] @@ -13,3 +18,9 @@ func RemovePrefix(s string, prefix string) string { } return s } + +func CreateSHA256Hash(s string) string { + hash := sha256.New() + hash.Write([]byte(s)) + return hex.EncodeToString(hash.Sum(nil)) +} diff --git a/internal/database/migrations/20250115045232_add_api_key_details.up.sql b/internal/database/migrations/20250115045232_add_api_key_details.up.sql new file mode 100644 index 0000000..325e29e --- /dev/null +++ b/internal/database/migrations/20250115045232_add_api_key_details.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE public.keys_metadata ADD COLUMN api_key_hash text; +ALTER TABLE public.keys_metadata ADD COLUMN locked boolean DEFAULT false; diff --git a/internal/database/model/key_metadata.go b/internal/database/model/key_metadata.go index e5ab661..76984a6 100644 --- a/internal/database/model/key_metadata.go +++ b/internal/database/model/key_metadata.go @@ -7,4 +7,6 @@ type KeyMetadata struct { PublicKeyG2 string `db:"public_key_g2"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` + ApiKeyHash string `db:"api_key_hash"` + Locked bool `db:"locked"` } diff --git a/internal/database/repository/key_metadata.go b/internal/database/repository/key_metadata.go index 6c8626e..31043c7 100644 --- a/internal/database/repository/key_metadata.go +++ b/internal/database/repository/key_metadata.go @@ -10,6 +10,7 @@ type KeyMetadataRepository interface { Create(ctx context.Context, metadata *model.KeyMetadata) error Get(ctx context.Context, publicKeyG1 string) (*model.KeyMetadata, error) Update(ctx context.Context, metadata *model.KeyMetadata) error + UpdateAPIKeyHash(ctx context.Context, metadata *model.KeyMetadata) error Delete(ctx context.Context, publicKeyG1 string) error List(ctx context.Context) ([]*model.KeyMetadata, error) } diff --git a/internal/database/repository/postgres/key_metadata.go b/internal/database/repository/postgres/key_metadata.go index 0edd0e4..3eea062 100644 --- a/internal/database/repository/postgres/key_metadata.go +++ b/internal/database/repository/postgres/key_metadata.go @@ -23,12 +23,12 @@ func NewKeyMetadataRepository(db *sql.DB) repository.KeyMetadataRepository { const ( createKeyMetadataQuery = ` INSERT INTO public.keys_metadata ( - public_key_g1, public_key_g2, created_at, updated_at - ) VALUES ($1, $2, $3, $4) + public_key_g1, public_key_g2, created_at, updated_at, api_key_hash + ) VALUES ($1, $2, $3, $4, $5) ` getKeyMetadataQuery = ` - SELECT public_key_g1, public_key_g2, created_at, updated_at + SELECT public_key_g1, public_key_g2, created_at, updated_at, api_key_hash, locked FROM public.keys_metadata WHERE public_key_g1 = $1 ` @@ -39,6 +39,12 @@ const ( WHERE public_key_g1 = $2 ` + updateAPIKeyHashQuery = ` + UPDATE public.keys_metadata + SET api_key_hash = $1, updated_at = $2 + WHERE public_key_g1 = $3 + ` + deleteKeyMetadataQuery = ` DELETE FROM public.keys_metadata WHERE public_key_g1 = $1 @@ -68,6 +74,7 @@ func (r *keyMetadataRepo) Create(ctx context.Context, metadata *model.KeyMetadat metadata.PublicKeyG2, metadata.CreatedAt, metadata.UpdatedAt, + metadata.ApiKeyHash, ) return err } @@ -79,6 +86,8 @@ func (r *keyMetadataRepo) Get(ctx context.Context, publicKeyG1 string) (*model.K &metadata.PublicKeyG2, &metadata.CreatedAt, &metadata.UpdatedAt, + &metadata.ApiKeyHash, + &metadata.Locked, ) if err == sql.ErrNoRows { return nil, errors.New("key metadata not found") @@ -113,6 +122,24 @@ func (r *keyMetadataRepo) Update(ctx context.Context, metadata *model.KeyMetadat return nil } +func (r *keyMetadataRepo) UpdateAPIKeyHash(ctx context.Context, metadata *model.KeyMetadata) error { + if metadata.PublicKeyG1 == "" { + return errors.New("public key g1 is required") + } + if metadata.ApiKeyHash == "" { + return errors.New("api key hash is required") + } + + metadata.UpdatedAt = time.Now().UTC() + + _, err := r.db.ExecContext(ctx, updateAPIKeyHashQuery, + metadata.ApiKeyHash, + metadata.UpdatedAt, + metadata.PublicKeyG1, + ) + return err +} + func (r *keyMetadataRepo) Delete(ctx context.Context, publicKeyG1 string) error { if publicKeyG1 == "" { return errors.New("public key g1 is required") diff --git a/internal/database/repository/postgres/key_metadata_test.go b/internal/database/repository/postgres/key_metadata_test.go index 46e9790..c4eca2c 100644 --- a/internal/database/repository/postgres/key_metadata_test.go +++ b/internal/database/repository/postgres/key_metadata_test.go @@ -265,3 +265,31 @@ func TestKeyMetadataRepository_List(t *testing.T) { assert.Equal(t, "test_key_1", results[1].PublicKeyG1) }) } + +func TestKeyMetadataRepository_UpdateAPIKeyHash(t *testing.T) { + testDB := SetupTestDB(t) + // No need to defer db.Close() as it's handled by t.Cleanup + + // Create initial test data + initialKey := &model.KeyMetadata{ + PublicKeyG1: "test_key_1", + PublicKeyG2: "test_key_2", + ApiKeyHash: "test_api_key_hash", + } + err := testDB.Repo.Create(context.Background(), initialKey) + require.NoError(t, err) + + metadata := &model.KeyMetadata{ + PublicKeyG1: "test_key_1", + ApiKeyHash: "test_api_key_hash_2", + } + + err = testDB.Repo.UpdateAPIKeyHash(context.Background(), metadata) + require.NoError(t, err) + + // Verify the update + result, err := testDB.Repo.Get(context.Background(), metadata.PublicKeyG1) + assert.NoError(t, err) + assert.Equal(t, metadata.ApiKeyHash, result.ApiKeyHash) + assert.WithinDuration(t, time.Now(), result.UpdatedAt, 2*time.Second) +} diff --git a/internal/database/repository/postgres/test_helper.go b/internal/database/repository/postgres/test_helper.go index 64653d2..d8d65de 100644 --- a/internal/database/repository/postgres/test_helper.go +++ b/internal/database/repository/postgres/test_helper.go @@ -82,7 +82,9 @@ func CreateTestContainer(t *testing.T) (*TestContainer, error) { public_key_g1 VARCHAR(255) PRIMARY KEY, public_key_g2 VARCHAR(255) NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT NOW(), - updated_at TIMESTAMP NOT NULL DEFAULT NOW() + updated_at TIMESTAMP NOT NULL DEFAULT NOW(), + api_key_hash text, + locked boolean DEFAULT false ); `) if err != nil { diff --git a/internal/database/sql/cerberus.sql b/internal/database/sql/cerberus.sql index 3add411..fd36981 100644 --- a/internal/database/sql/cerberus.sql +++ b/internal/database/sql/cerberus.sql @@ -4,5 +4,7 @@ CREATE TABLE IF NOT EXISTS public.keys_metadata ( public_key_g1 VARCHAR(255) PRIMARY KEY, public_key_g2 VARCHAR(255) NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT NOW(), - updated_at TIMESTAMP NOT NULL DEFAULT NOW() -); \ No newline at end of file + updated_at TIMESTAMP NOT NULL DEFAULT NOW(), + api_key_hash text, + locked boolean DEFAULT false +); diff --git a/internal/middleware/auth_interceptor.go b/internal/middleware/auth_interceptor.go new file mode 100644 index 0000000..f91d586 --- /dev/null +++ b/internal/middleware/auth_interceptor.go @@ -0,0 +1,86 @@ +package middleware + +import ( + "context" + "errors" + "strings" + + v1 "github.com/Layr-Labs/cerberus-api/pkg/api/v1" + "github.com/Layr-Labs/cerberus/internal/common" + "github.com/Layr-Labs/cerberus/internal/database/repository" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// AuthInterceptor creates a selective authentication interceptor +func AuthInterceptor( + protectedServiceName string, + keyMetadataRepo repository.KeyMetadataRepository, +) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + // Check if the current service should be protected + if !strings.HasPrefix(info.FullMethod, "/"+protectedServiceName) { + // Skip auth for non-protected services + return handler(ctx, req) + } + + // Get metadata from context + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.Unauthenticated, "missing metadata") + } + + // Get authorization token + authHeader := md.Get("authorization") + if len(authHeader) == 0 { + return nil, status.Error(codes.Unauthenticated, "missing authorization header") + } + + // Validate the token (implement your own validation logic) + valid, err := validateToken(ctx, authHeader[0], req, keyMetadataRepo) + if err != nil { + return nil, status.Error(codes.Unauthenticated, err.Error()) + } + + if !valid { + return nil, status.Error(codes.Unauthenticated, "invalid token") + } + + // If authentication successful, proceed with the handler + return handler(ctx, req) + } +} + +// Example token validation function - replace with your own implementation +func validateToken( + ctx context.Context, + token string, + req interface{}, + keyMetadataRepo repository.KeyMetadataRepository, +) (bool, error) { + var pubKeyG1 string + switch r := req.(type) { + case *v1.SignGenericRequest: + pubKeyG1 = r.GetPublicKeyG1() + case *v1.SignG1Request: + pubKeyG1 = r.GetPublicKeyG1() + default: + return false, errors.New("invalid request type") + } + + keyMetadata, err := keyMetadataRepo.Get(ctx, pubKeyG1) + if err != nil { + return false, err + } + + requestAPIKeyHash := common.CreateSHA256Hash(token) + + if keyMetadata.ApiKeyHash != requestAPIKeyHash { + return false, errors.New("invalid token") + } + + return true, nil +} diff --git a/internal/server/server.go b/internal/server/server.go index 3e48015..65d7c6f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -117,7 +117,11 @@ func Start(config *configuration.Configuration, logger *slog.Logger) { // Register metrics middleware metricsMiddleware := middleware.NewMetricsMiddleware(registry, rpcMetrics) - opts = append(opts, grpc.UnaryInterceptor(metricsMiddleware.UnaryServerInterceptor())) + authInterceptor := middleware.AuthInterceptor("signer.v1.Signer", keyMetadataRepo) + opts = append( + opts, + grpc.ChainUnaryInterceptor(metricsMiddleware.UnaryServerInterceptor(), authInterceptor), + ) s := grpc.NewServer(opts...) kmsService := kms.NewService(config, keystore, keyMetadataRepo, logger, rpcMetrics) diff --git a/internal/services/kms/kms.go b/internal/services/kms/kms.go index fab59dc..9eaf1b5 100644 --- a/internal/services/kms/kms.go +++ b/internal/services/kms/kms.go @@ -8,6 +8,7 @@ import ( "math/big" v1 "github.com/Layr-Labs/cerberus-api/pkg/api/v1" + "github.com/google/uuid" "github.com/Layr-Labs/cerberus/internal/common" "github.com/Layr-Labs/cerberus/internal/configuration" @@ -75,9 +76,17 @@ func (k *Service) GenerateKeyPair( return nil, status.Error(codes.Internal, err.Error()) } + // Generate a new API key and hash + apiKey, apiKeyHash, err := generateNewAPIKeyAndHash() + if err != nil { + k.logger.Error(fmt.Sprintf("Failed to generate API key: %v", err)) + return nil, status.Error(codes.Internal, err.Error()) + } + err = k.keyMetadataRepo.Create(ctx, &model.KeyMetadata{ PublicKeyG1: pubKeyHex, PublicKeyG2: g2PubKey, + ApiKeyHash: apiKeyHash, }) if err != nil { k.logger.Error(fmt.Sprintf("Failed to save key metadata: %v", err)) @@ -94,6 +103,7 @@ func (k *Service) GenerateKeyPair( PublicKeyG2: g2PubKey, PrivateKey: privKeyHex, Mnemonic: keyPair.Mnemonic, + ApiKey: apiKey, }, nil } @@ -149,16 +159,28 @@ func (k *Service) ImportKey( return nil, status.Error(codes.Internal, err.Error()) } + // Generate a new API key and hash + apiKey, apiKeyHash, err := generateNewAPIKeyAndHash() + if err != nil { + k.logger.Error(fmt.Sprintf("Failed to generate API key: %v", err)) + return nil, status.Error(codes.Internal, err.Error()) + } + err = k.keyMetadataRepo.Create(ctx, &model.KeyMetadata{ PublicKeyG1: pubKeyHex, PublicKeyG2: g2PubKey, + ApiKeyHash: apiKeyHash, }) if err != nil { k.logger.Error(fmt.Sprintf("Failed to save key metadata: %v", err)) return nil, status.Error(codes.Internal, err.Error()) } - return &v1.ImportKeyResponse{PublicKeyG1: pubKeyHex, PublicKeyG2: g2PubKey}, nil + return &v1.ImportKeyResponse{ + PublicKeyG1: pubKeyHex, + PublicKeyG2: g2PubKey, + ApiKey: apiKey, + }, nil } func (k *Service) ListKeys( @@ -197,3 +219,13 @@ func (k *Service) GetKeyMetadata( UpdatedAt: metadata.UpdatedAt.Unix(), }, nil } + +func generateNewAPIKeyAndHash() (string, string, error) { + newUUID, err := uuid.NewV7() + if err != nil { + return "", "", err + } + apiKey := newUUID.String() + apiKeyHash := common.CreateSHA256Hash(apiKey) + return apiKey, apiKeyHash, nil +}