From 6dbd9ba2231fcdd5ac0987463f41eac88de51618 Mon Sep 17 00:00:00 2001 From: Madhur Shrimal Date: Wed, 15 Jan 2025 18:10:55 -0800 Subject: [PATCH 1/4] feat: add api key for authz --- Makefile | 5 ++ go.mod | 4 +- go.sum | 2 - internal/common/common.go | 11 +++ .../20250115045232_add_api_key_details.up.sql | 2 + internal/database/model/key_metadata.go | 2 + internal/database/repository/key_metadata.go | 1 + .../repository/postgres/key_metadata.go | 33 ++++++- .../repository/postgres/key_metadata_test.go | 28 ++++++ .../repository/postgres/test_helper.go | 4 +- internal/database/sql/cerberus.sql | 6 +- internal/middleware/auth_interceptor.go | 86 +++++++++++++++++++ internal/server/server.go | 6 +- internal/services/kms/kms.go | 34 +++++++- 14 files changed, 213 insertions(+), 11 deletions(-) create mode 100644 internal/database/migrations/20250115045232_add_api_key_details.up.sql create mode 100644 internal/middleware/auth_interceptor.go diff --git a/Makefile b/Makefile index 4b25303..2959a39 100644 --- a/Makefile +++ b/Makefile @@ -41,3 +41,8 @@ docker: ## runs docker build migrate: ## runs database migrations go install -tags 'postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@latest migrate -path internal/database/migrations/ -database "postgres://user:password@localhost:5432/cerberus?sslmode=disable" --verbose up + +.PHONY: create-migration +create-migration: ## creates a new database migration + go install -tags 'postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@latest + migrate create -dir internal/database/migrations/ -ext sql $(name) diff --git a/go.mod b/go.mod index 87527af..2157461 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.22.0 toolchain go1.22.3 +replace github.com/Layr-Labs/cerberus-api => ../cerberus-api + require ( cloud.google.com/go/secretmanager v1.14.2 github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87 @@ -14,6 +16,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.6 github.com/consensys/gnark-crypto v0.12.1 github.com/golang-migrate/migrate/v4 v4.18.1 + github.com/google/uuid v1.6.0 github.com/prometheus/client_golang v1.20.3 github.com/stretchr/testify v1.10.0 github.com/testcontainers/testcontainers-go v0.34.0 @@ -37,7 +40,6 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect diff --git a/go.sum b/go.sum index 33356a2..d7e55bc 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,6 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg6 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87 h1:EkaBNT0o8RTgtFeYSKaoNHNbnCVxrcsAyRpUeN29hiQ= github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87/go.mod h1:7J8hptSX8cFq7KmVb+rEO5aEifj7E44c3i0afIyr4WA= -github.com/Layr-Labs/cerberus-api v0.0.2-0.20250108174619-d5e1eb03fbd5 h1:s24M6HYObEuV9OSY36jUM09kp5fOhuz/g1ev2qWDPzU= -github.com/Layr-Labs/cerberus-api v0.0.2-0.20250108174619-d5e1eb03fbd5/go.mod h1:Lm4fhzy0S3P7GjerzuseGaBFVczsIKmEhIjcT52Hluo= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/aws/aws-sdk-go-v2 v1.32.5 h1:U8vdWJuY7ruAkzaOdD7guwJjD06YSKmnKCJs7s3IkIo= diff --git a/internal/common/common.go b/internal/common/common.go index 57e3957..a692d64 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -1,5 +1,10 @@ package common +import ( + "crypto/sha256" + "encoding/hex" +) + func Trim0x(s string) string { if len(s) >= 2 && s[0:2] == "0x" { return s[2:] @@ -13,3 +18,9 @@ func RemovePrefix(s string, prefix string) string { } return s } + +func CreateSHA256Hash(s string) string { + hash := sha256.New() + hash.Write([]byte(s)) + return hex.EncodeToString(hash.Sum(nil)) +} diff --git a/internal/database/migrations/20250115045232_add_api_key_details.up.sql b/internal/database/migrations/20250115045232_add_api_key_details.up.sql new file mode 100644 index 0000000..325e29e --- /dev/null +++ b/internal/database/migrations/20250115045232_add_api_key_details.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE public.keys_metadata ADD COLUMN api_key_hash text; +ALTER TABLE public.keys_metadata ADD COLUMN locked boolean DEFAULT false; diff --git a/internal/database/model/key_metadata.go b/internal/database/model/key_metadata.go index e5ab661..76984a6 100644 --- a/internal/database/model/key_metadata.go +++ b/internal/database/model/key_metadata.go @@ -7,4 +7,6 @@ type KeyMetadata struct { PublicKeyG2 string `db:"public_key_g2"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` + ApiKeyHash string `db:"api_key_hash"` + Locked bool `db:"locked"` } diff --git a/internal/database/repository/key_metadata.go b/internal/database/repository/key_metadata.go index 6c8626e..31043c7 100644 --- a/internal/database/repository/key_metadata.go +++ b/internal/database/repository/key_metadata.go @@ -10,6 +10,7 @@ type KeyMetadataRepository interface { Create(ctx context.Context, metadata *model.KeyMetadata) error Get(ctx context.Context, publicKeyG1 string) (*model.KeyMetadata, error) Update(ctx context.Context, metadata *model.KeyMetadata) error + UpdateAPIKeyHash(ctx context.Context, metadata *model.KeyMetadata) error Delete(ctx context.Context, publicKeyG1 string) error List(ctx context.Context) ([]*model.KeyMetadata, error) } diff --git a/internal/database/repository/postgres/key_metadata.go b/internal/database/repository/postgres/key_metadata.go index 0edd0e4..3eea062 100644 --- a/internal/database/repository/postgres/key_metadata.go +++ b/internal/database/repository/postgres/key_metadata.go @@ -23,12 +23,12 @@ func NewKeyMetadataRepository(db *sql.DB) repository.KeyMetadataRepository { const ( createKeyMetadataQuery = ` INSERT INTO public.keys_metadata ( - public_key_g1, public_key_g2, created_at, updated_at - ) VALUES ($1, $2, $3, $4) + public_key_g1, public_key_g2, created_at, updated_at, api_key_hash + ) VALUES ($1, $2, $3, $4, $5) ` getKeyMetadataQuery = ` - SELECT public_key_g1, public_key_g2, created_at, updated_at + SELECT public_key_g1, public_key_g2, created_at, updated_at, api_key_hash, locked FROM public.keys_metadata WHERE public_key_g1 = $1 ` @@ -39,6 +39,12 @@ const ( WHERE public_key_g1 = $2 ` + updateAPIKeyHashQuery = ` + UPDATE public.keys_metadata + SET api_key_hash = $1, updated_at = $2 + WHERE public_key_g1 = $3 + ` + deleteKeyMetadataQuery = ` DELETE FROM public.keys_metadata WHERE public_key_g1 = $1 @@ -68,6 +74,7 @@ func (r *keyMetadataRepo) Create(ctx context.Context, metadata *model.KeyMetadat metadata.PublicKeyG2, metadata.CreatedAt, metadata.UpdatedAt, + metadata.ApiKeyHash, ) return err } @@ -79,6 +86,8 @@ func (r *keyMetadataRepo) Get(ctx context.Context, publicKeyG1 string) (*model.K &metadata.PublicKeyG2, &metadata.CreatedAt, &metadata.UpdatedAt, + &metadata.ApiKeyHash, + &metadata.Locked, ) if err == sql.ErrNoRows { return nil, errors.New("key metadata not found") @@ -113,6 +122,24 @@ func (r *keyMetadataRepo) Update(ctx context.Context, metadata *model.KeyMetadat return nil } +func (r *keyMetadataRepo) UpdateAPIKeyHash(ctx context.Context, metadata *model.KeyMetadata) error { + if metadata.PublicKeyG1 == "" { + return errors.New("public key g1 is required") + } + if metadata.ApiKeyHash == "" { + return errors.New("api key hash is required") + } + + metadata.UpdatedAt = time.Now().UTC() + + _, err := r.db.ExecContext(ctx, updateAPIKeyHashQuery, + metadata.ApiKeyHash, + metadata.UpdatedAt, + metadata.PublicKeyG1, + ) + return err +} + func (r *keyMetadataRepo) Delete(ctx context.Context, publicKeyG1 string) error { if publicKeyG1 == "" { return errors.New("public key g1 is required") diff --git a/internal/database/repository/postgres/key_metadata_test.go b/internal/database/repository/postgres/key_metadata_test.go index 46e9790..c4eca2c 100644 --- a/internal/database/repository/postgres/key_metadata_test.go +++ b/internal/database/repository/postgres/key_metadata_test.go @@ -265,3 +265,31 @@ func TestKeyMetadataRepository_List(t *testing.T) { assert.Equal(t, "test_key_1", results[1].PublicKeyG1) }) } + +func TestKeyMetadataRepository_UpdateAPIKeyHash(t *testing.T) { + testDB := SetupTestDB(t) + // No need to defer db.Close() as it's handled by t.Cleanup + + // Create initial test data + initialKey := &model.KeyMetadata{ + PublicKeyG1: "test_key_1", + PublicKeyG2: "test_key_2", + ApiKeyHash: "test_api_key_hash", + } + err := testDB.Repo.Create(context.Background(), initialKey) + require.NoError(t, err) + + metadata := &model.KeyMetadata{ + PublicKeyG1: "test_key_1", + ApiKeyHash: "test_api_key_hash_2", + } + + err = testDB.Repo.UpdateAPIKeyHash(context.Background(), metadata) + require.NoError(t, err) + + // Verify the update + result, err := testDB.Repo.Get(context.Background(), metadata.PublicKeyG1) + assert.NoError(t, err) + assert.Equal(t, metadata.ApiKeyHash, result.ApiKeyHash) + assert.WithinDuration(t, time.Now(), result.UpdatedAt, 2*time.Second) +} diff --git a/internal/database/repository/postgres/test_helper.go b/internal/database/repository/postgres/test_helper.go index 64653d2..d8d65de 100644 --- a/internal/database/repository/postgres/test_helper.go +++ b/internal/database/repository/postgres/test_helper.go @@ -82,7 +82,9 @@ func CreateTestContainer(t *testing.T) (*TestContainer, error) { public_key_g1 VARCHAR(255) PRIMARY KEY, public_key_g2 VARCHAR(255) NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT NOW(), - updated_at TIMESTAMP NOT NULL DEFAULT NOW() + updated_at TIMESTAMP NOT NULL DEFAULT NOW(), + api_key_hash text, + locked boolean DEFAULT false ); `) if err != nil { diff --git a/internal/database/sql/cerberus.sql b/internal/database/sql/cerberus.sql index 3add411..fd36981 100644 --- a/internal/database/sql/cerberus.sql +++ b/internal/database/sql/cerberus.sql @@ -4,5 +4,7 @@ CREATE TABLE IF NOT EXISTS public.keys_metadata ( public_key_g1 VARCHAR(255) PRIMARY KEY, public_key_g2 VARCHAR(255) NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT NOW(), - updated_at TIMESTAMP NOT NULL DEFAULT NOW() -); \ No newline at end of file + updated_at TIMESTAMP NOT NULL DEFAULT NOW(), + api_key_hash text, + locked boolean DEFAULT false +); diff --git a/internal/middleware/auth_interceptor.go b/internal/middleware/auth_interceptor.go new file mode 100644 index 0000000..f91d586 --- /dev/null +++ b/internal/middleware/auth_interceptor.go @@ -0,0 +1,86 @@ +package middleware + +import ( + "context" + "errors" + "strings" + + v1 "github.com/Layr-Labs/cerberus-api/pkg/api/v1" + "github.com/Layr-Labs/cerberus/internal/common" + "github.com/Layr-Labs/cerberus/internal/database/repository" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// AuthInterceptor creates a selective authentication interceptor +func AuthInterceptor( + protectedServiceName string, + keyMetadataRepo repository.KeyMetadataRepository, +) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + // Check if the current service should be protected + if !strings.HasPrefix(info.FullMethod, "/"+protectedServiceName) { + // Skip auth for non-protected services + return handler(ctx, req) + } + + // Get metadata from context + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.Unauthenticated, "missing metadata") + } + + // Get authorization token + authHeader := md.Get("authorization") + if len(authHeader) == 0 { + return nil, status.Error(codes.Unauthenticated, "missing authorization header") + } + + // Validate the token (implement your own validation logic) + valid, err := validateToken(ctx, authHeader[0], req, keyMetadataRepo) + if err != nil { + return nil, status.Error(codes.Unauthenticated, err.Error()) + } + + if !valid { + return nil, status.Error(codes.Unauthenticated, "invalid token") + } + + // If authentication successful, proceed with the handler + return handler(ctx, req) + } +} + +// Example token validation function - replace with your own implementation +func validateToken( + ctx context.Context, + token string, + req interface{}, + keyMetadataRepo repository.KeyMetadataRepository, +) (bool, error) { + var pubKeyG1 string + switch r := req.(type) { + case *v1.SignGenericRequest: + pubKeyG1 = r.GetPublicKeyG1() + case *v1.SignG1Request: + pubKeyG1 = r.GetPublicKeyG1() + default: + return false, errors.New("invalid request type") + } + + keyMetadata, err := keyMetadataRepo.Get(ctx, pubKeyG1) + if err != nil { + return false, err + } + + requestAPIKeyHash := common.CreateSHA256Hash(token) + + if keyMetadata.ApiKeyHash != requestAPIKeyHash { + return false, errors.New("invalid token") + } + + return true, nil +} diff --git a/internal/server/server.go b/internal/server/server.go index 3e48015..65d7c6f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -117,7 +117,11 @@ func Start(config *configuration.Configuration, logger *slog.Logger) { // Register metrics middleware metricsMiddleware := middleware.NewMetricsMiddleware(registry, rpcMetrics) - opts = append(opts, grpc.UnaryInterceptor(metricsMiddleware.UnaryServerInterceptor())) + authInterceptor := middleware.AuthInterceptor("signer.v1.Signer", keyMetadataRepo) + opts = append( + opts, + grpc.ChainUnaryInterceptor(metricsMiddleware.UnaryServerInterceptor(), authInterceptor), + ) s := grpc.NewServer(opts...) kmsService := kms.NewService(config, keystore, keyMetadataRepo, logger, rpcMetrics) diff --git a/internal/services/kms/kms.go b/internal/services/kms/kms.go index fab59dc..9eaf1b5 100644 --- a/internal/services/kms/kms.go +++ b/internal/services/kms/kms.go @@ -8,6 +8,7 @@ import ( "math/big" v1 "github.com/Layr-Labs/cerberus-api/pkg/api/v1" + "github.com/google/uuid" "github.com/Layr-Labs/cerberus/internal/common" "github.com/Layr-Labs/cerberus/internal/configuration" @@ -75,9 +76,17 @@ func (k *Service) GenerateKeyPair( return nil, status.Error(codes.Internal, err.Error()) } + // Generate a new API key and hash + apiKey, apiKeyHash, err := generateNewAPIKeyAndHash() + if err != nil { + k.logger.Error(fmt.Sprintf("Failed to generate API key: %v", err)) + return nil, status.Error(codes.Internal, err.Error()) + } + err = k.keyMetadataRepo.Create(ctx, &model.KeyMetadata{ PublicKeyG1: pubKeyHex, PublicKeyG2: g2PubKey, + ApiKeyHash: apiKeyHash, }) if err != nil { k.logger.Error(fmt.Sprintf("Failed to save key metadata: %v", err)) @@ -94,6 +103,7 @@ func (k *Service) GenerateKeyPair( PublicKeyG2: g2PubKey, PrivateKey: privKeyHex, Mnemonic: keyPair.Mnemonic, + ApiKey: apiKey, }, nil } @@ -149,16 +159,28 @@ func (k *Service) ImportKey( return nil, status.Error(codes.Internal, err.Error()) } + // Generate a new API key and hash + apiKey, apiKeyHash, err := generateNewAPIKeyAndHash() + if err != nil { + k.logger.Error(fmt.Sprintf("Failed to generate API key: %v", err)) + return nil, status.Error(codes.Internal, err.Error()) + } + err = k.keyMetadataRepo.Create(ctx, &model.KeyMetadata{ PublicKeyG1: pubKeyHex, PublicKeyG2: g2PubKey, + ApiKeyHash: apiKeyHash, }) if err != nil { k.logger.Error(fmt.Sprintf("Failed to save key metadata: %v", err)) return nil, status.Error(codes.Internal, err.Error()) } - return &v1.ImportKeyResponse{PublicKeyG1: pubKeyHex, PublicKeyG2: g2PubKey}, nil + return &v1.ImportKeyResponse{ + PublicKeyG1: pubKeyHex, + PublicKeyG2: g2PubKey, + ApiKey: apiKey, + }, nil } func (k *Service) ListKeys( @@ -197,3 +219,13 @@ func (k *Service) GetKeyMetadata( UpdatedAt: metadata.UpdatedAt.Unix(), }, nil } + +func generateNewAPIKeyAndHash() (string, string, error) { + newUUID, err := uuid.NewV7() + if err != nil { + return "", "", err + } + apiKey := newUUID.String() + apiKeyHash := common.CreateSHA256Hash(apiKey) + return apiKey, apiKeyHash, nil +} From 80905ccf8dff951cf7f7c8063a86ce16fc6e86b2 Mon Sep 17 00:00:00 2001 From: Madhur Shrimal Date: Thu, 16 Jan 2025 18:00:02 -0800 Subject: [PATCH 2/4] add admin api implementation --- Makefile | 2 +- cmd/cerberus/main.go | 32 ++- go.mod | 4 +- go.sum | 2 + internal/common/common.go | 12 + internal/configuration/configuration.go | 11 +- internal/database/repository/key_metadata.go | 3 +- .../repository/postgres/key_metadata.go | 47 +++- .../repository/postgres/key_metadata_test.go | 11 +- internal/server/config.go | 8 + internal/server/server.go | 262 +++++++++++------- internal/server/shared_resources.go | 165 +++++++++++ internal/services/admin/admin.go | 109 ++++++++ internal/services/kms/kms.go | 15 +- 14 files changed, 538 insertions(+), 145 deletions(-) create mode 100644 internal/server/config.go create mode 100644 internal/server/shared_resources.go create mode 100644 internal/services/admin/admin.go diff --git a/Makefile b/Makefile index 2959a39..9e94399 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ build: .PHONY: start start: make build - ./bin/$(APP_NAME) --log-level=debug + ./bin/$(APP_NAME) --log-level=debug --enable-admin .PHONY: fmt fmt: ## formats all go files diff --git a/cmd/cerberus/main.go b/cmd/cerberus/main.go index 9f97e5f..aaf27f3 100644 --- a/cmd/cerberus/main.go +++ b/cmd/cerberus/main.go @@ -23,20 +23,34 @@ var ( EnvVars: []string{"KEYSTORE_DIR"}, } - grpcPortFlag = &cli.StringFlag{ + grpcPortFlag = &cli.IntFlag{ Name: "grpc-port", Usage: "Port for the gRPC server", - Value: "50051", + Value: 50051, EnvVars: []string{"GRPC_PORT"}, } - metricsPortFlag = &cli.StringFlag{ + adminPortFlag = &cli.IntFlag{ + Name: "admin-port", + Usage: "Port for the admin server", + Value: 50052, + EnvVars: []string{"ADMIN_PORT"}, + } + + metricsPortFlag = &cli.IntFlag{ Name: "metrics-port", Usage: "Port for the metrics server", - Value: "9091", + Value: 9091, EnvVars: []string{"METRICS_PORT"}, } + enableAdminFlag = &cli.BoolFlag{ + Name: "enable-admin", + Usage: "Enable the admin server", + Value: false, + EnvVars: []string{"ENABLE_ADMIN"}, + } + logLevelFlag = &cli.StringFlag{ Name: "log-level", Usage: "Log level - supported levels: debug, info, warn, error", @@ -142,6 +156,7 @@ func main() { logFormatFlag, logLevelFlag, metricsPortFlag, + enableAdminFlag, tlsCaCertFlag, tlsServerKeyFlag, storageTypeFlag, @@ -152,6 +167,7 @@ func main() { awsSecretAccessKeyFlag, gcpProjectIDFlag, postgresDatabaseURLFlag, + adminPortFlag, } sort.Sort(cli.FlagsByName(app.Flags)) @@ -168,8 +184,9 @@ func main() { func start(c *cli.Context) error { keystoreDir := c.String(keystoreDirFlag.Name) - grpcPort := c.String(grpcPortFlag.Name) - metricsPort := c.String(metricsPortFlag.Name) + grpcPort := c.Int(grpcPortFlag.Name) + adminPort := c.Int(adminPortFlag.Name) + metricsPort := c.Int(metricsPortFlag.Name) logLevel := c.String(logLevelFlag.Name) logFormat := c.String(logFormatFlag.Name) tlsCaCert := c.String(tlsCaCertFlag.Name) @@ -182,9 +199,11 @@ func start(c *cli.Context) error { awsSecretAccessKey := c.String(awsSecretAccessKeyFlag.Name) gcpProjectID := c.String(gcpProjectIDFlag.Name) postgresDatabaseURL := c.String(postgresDatabaseURLFlag.Name) + enableAdmin := c.Bool(enableAdminFlag.Name) cfg := &configuration.Configuration{ KeystoreDir: keystoreDir, GrpcPort: grpcPort, + AdminPort: adminPort, MetricsPort: metricsPort, TLSCACert: tlsCaCert, TLSServerKey: tlsServerKey, @@ -196,6 +215,7 @@ func start(c *cli.Context) error { AWSSecretAccessKey: awsSecretAccessKey, GCPProjectID: gcpProjectID, PostgresDatabaseURL: postgresDatabaseURL, + EnableAdmin: enableAdmin, } if err := cfg.Validate(); err != nil { diff --git a/go.mod b/go.mod index 2157461..04febbe 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,10 @@ go 1.22.0 toolchain go1.22.3 -replace github.com/Layr-Labs/cerberus-api => ../cerberus-api - require ( cloud.google.com/go/secretmanager v1.14.2 github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87 - github.com/Layr-Labs/cerberus-api v0.0.2-0.20250108174619-d5e1eb03fbd5 + github.com/Layr-Labs/cerberus-api v0.0.2-0.20250117015901-0b1220ea735f github.com/aws/aws-sdk-go-v2 v1.32.5 github.com/aws/aws-sdk-go-v2/config v1.28.5 github.com/aws/aws-sdk-go-v2/credentials v1.17.46 diff --git a/go.sum b/go.sum index d7e55bc..c7561b6 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg6 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87 h1:EkaBNT0o8RTgtFeYSKaoNHNbnCVxrcsAyRpUeN29hiQ= github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87/go.mod h1:7J8hptSX8cFq7KmVb+rEO5aEifj7E44c3i0afIyr4WA= +github.com/Layr-Labs/cerberus-api v0.0.2-0.20250117015901-0b1220ea735f h1:Od50IBPPjsAF9w8QGUsTFcFwmKA0UEH9zNtUN0PDM68= +github.com/Layr-Labs/cerberus-api v0.0.2-0.20250117015901-0b1220ea735f/go.mod h1:Lm4fhzy0S3P7GjerzuseGaBFVczsIKmEhIjcT52Hluo= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/aws/aws-sdk-go-v2 v1.32.5 h1:U8vdWJuY7ruAkzaOdD7guwJjD06YSKmnKCJs7s3IkIo= diff --git a/internal/common/common.go b/internal/common/common.go index a692d64..167af39 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -3,6 +3,8 @@ package common import ( "crypto/sha256" "encoding/hex" + + "github.com/google/uuid" ) func Trim0x(s string) string { @@ -24,3 +26,13 @@ func CreateSHA256Hash(s string) string { hash.Write([]byte(s)) return hex.EncodeToString(hash.Sum(nil)) } + +func GenerateNewAPIKeyAndHash() (string, string, error) { + newUUID, err := uuid.NewV7() + if err != nil { + return "", "", err + } + apiKey := newUUID.String() + apiKeyHash := CreateSHA256Hash(apiKey) + return apiKey, apiKeyHash, nil +} diff --git a/internal/configuration/configuration.go b/internal/configuration/configuration.go index b779ff9..8422033 100644 --- a/internal/configuration/configuration.go +++ b/internal/configuration/configuration.go @@ -31,8 +31,11 @@ type Configuration struct { // Google Secrets Manager storage parameters GCPProjectID string - GrpcPort string - MetricsPort string + GrpcPort int + MetricsPort int + AdminPort int + + EnableAdmin bool TLSCACert string TLSServerKey string @@ -72,11 +75,11 @@ func (s *Configuration) Validate() error { return fmt.Errorf("unsupported storage type: %s", s.StorageType) } - if s.GrpcPort == "" { + if s.GrpcPort == 0 { return fmt.Errorf("gRPC port is required") } - if s.MetricsPort == "" { + if s.MetricsPort == 0 { return fmt.Errorf("metrics port is required") } diff --git a/internal/database/repository/key_metadata.go b/internal/database/repository/key_metadata.go index 31043c7..eb03376 100644 --- a/internal/database/repository/key_metadata.go +++ b/internal/database/repository/key_metadata.go @@ -10,7 +10,8 @@ type KeyMetadataRepository interface { Create(ctx context.Context, metadata *model.KeyMetadata) error Get(ctx context.Context, publicKeyG1 string) (*model.KeyMetadata, error) Update(ctx context.Context, metadata *model.KeyMetadata) error - UpdateAPIKeyHash(ctx context.Context, metadata *model.KeyMetadata) error + UpdateAPIKeyHash(ctx context.Context, publicKeyG1 string, apiKeyHash string) error + UpdateLockStatus(ctx context.Context, publicKeyG1 string, locked bool) error Delete(ctx context.Context, publicKeyG1 string) error List(ctx context.Context) ([]*model.KeyMetadata, error) } diff --git a/internal/database/repository/postgres/key_metadata.go b/internal/database/repository/postgres/key_metadata.go index 3eea062..c601447 100644 --- a/internal/database/repository/postgres/key_metadata.go +++ b/internal/database/repository/postgres/key_metadata.go @@ -45,13 +45,19 @@ const ( WHERE public_key_g1 = $3 ` + updateLockStatusQuery = ` + UPDATE public.keys_metadata + SET locked = $1, updated_at = $2 + WHERE public_key_g1 = $3 + ` + deleteKeyMetadataQuery = ` DELETE FROM public.keys_metadata WHERE public_key_g1 = $1 ` - listKeyMetadataQuery = ` - SELECT public_key_g1, public_key_g2, created_at, updated_at + listAllKeysQuery = ` + SELECT public_key_g1, public_key_g2, created_at, updated_at, locked FROM public.keys_metadata ORDER BY created_at DESC ` @@ -122,20 +128,25 @@ func (r *keyMetadataRepo) Update(ctx context.Context, metadata *model.KeyMetadat return nil } -func (r *keyMetadataRepo) UpdateAPIKeyHash(ctx context.Context, metadata *model.KeyMetadata) error { - if metadata.PublicKeyG1 == "" { +func (r *keyMetadataRepo) UpdateAPIKeyHash( + ctx context.Context, + publicKeyG1 string, + apiKeyHash string, +) error { + if publicKeyG1 == "" { return errors.New("public key g1 is required") } - if metadata.ApiKeyHash == "" { + + if apiKeyHash == "" { return errors.New("api key hash is required") } - metadata.UpdatedAt = time.Now().UTC() + updatedAt := time.Now().UTC() _, err := r.db.ExecContext(ctx, updateAPIKeyHashQuery, - metadata.ApiKeyHash, - metadata.UpdatedAt, - metadata.PublicKeyG1, + apiKeyHash, + updatedAt, + publicKeyG1, ) return err } @@ -161,7 +172,7 @@ func (r *keyMetadataRepo) Delete(ctx context.Context, publicKeyG1 string) error } func (r *keyMetadataRepo) List(ctx context.Context) ([]*model.KeyMetadata, error) { - rows, err := r.db.QueryContext(ctx, listKeyMetadataQuery) + rows, err := r.db.QueryContext(ctx, listAllKeysQuery) if err != nil { return nil, err } @@ -175,6 +186,7 @@ func (r *keyMetadataRepo) List(ctx context.Context) ([]*model.KeyMetadata, error &m.PublicKeyG2, &m.CreatedAt, &m.UpdatedAt, + &m.Locked, ) if err != nil { return nil, err @@ -187,3 +199,18 @@ func (r *keyMetadataRepo) List(ctx context.Context) ([]*model.KeyMetadata, error } return metadata, nil } + +func (r *keyMetadataRepo) UpdateLockStatus( + ctx context.Context, + publicKeyG1 string, + locked bool, +) error { + updatedAt := time.Now().UTC() + + _, err := r.db.ExecContext(ctx, updateLockStatusQuery, + locked, + updatedAt, + publicKeyG1, + ) + return err +} diff --git a/internal/database/repository/postgres/key_metadata_test.go b/internal/database/repository/postgres/key_metadata_test.go index c4eca2c..6b8e797 100644 --- a/internal/database/repository/postgres/key_metadata_test.go +++ b/internal/database/repository/postgres/key_metadata_test.go @@ -279,17 +279,14 @@ func TestKeyMetadataRepository_UpdateAPIKeyHash(t *testing.T) { err := testDB.Repo.Create(context.Background(), initialKey) require.NoError(t, err) - metadata := &model.KeyMetadata{ - PublicKeyG1: "test_key_1", - ApiKeyHash: "test_api_key_hash_2", - } + apiKeyHash := "test_api_key_hash_2" - err = testDB.Repo.UpdateAPIKeyHash(context.Background(), metadata) + err = testDB.Repo.UpdateAPIKeyHash(context.Background(), initialKey.PublicKeyG1, apiKeyHash) require.NoError(t, err) // Verify the update - result, err := testDB.Repo.Get(context.Background(), metadata.PublicKeyG1) + result, err := testDB.Repo.Get(context.Background(), initialKey.PublicKeyG1) assert.NoError(t, err) - assert.Equal(t, metadata.ApiKeyHash, result.ApiKeyHash) + assert.Equal(t, apiKeyHash, result.ApiKeyHash) assert.WithinDuration(t, time.Now(), result.UpdatedAt, 2*time.Second) } diff --git a/internal/server/config.go b/internal/server/config.go new file mode 100644 index 0000000..591fbc1 --- /dev/null +++ b/internal/server/config.go @@ -0,0 +1,8 @@ +package server + +type GrpcServerConfig struct { + Port int + EnableTLS bool + TLSCACert string + TLSServerKey string +} diff --git a/internal/server/server.go b/internal/server/server.go index 65d7c6f..5f93805 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,31 +1,23 @@ package server import ( - "database/sql" + "context" "fmt" "log" "log/slog" "net" - "net/http" "os" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/collectors" - "github.com/prometheus/client_golang/prometheus/promhttp" + "os/signal" + "sync" + "syscall" + "time" v1 "github.com/Layr-Labs/cerberus-api/pkg/api/v1" "github.com/Layr-Labs/cerberus/internal/configuration" - "github.com/Layr-Labs/cerberus/internal/database" - "github.com/Layr-Labs/cerberus/internal/database/repository/postgres" - "github.com/Layr-Labs/cerberus/internal/metrics" - "github.com/Layr-Labs/cerberus/internal/middleware" + "github.com/Layr-Labs/cerberus/internal/services/admin" "github.com/Layr-Labs/cerberus/internal/services/kms" "github.com/Layr-Labs/cerberus/internal/services/signing" - "github.com/Layr-Labs/cerberus/internal/store" - "github.com/Layr-Labs/cerberus/internal/store/awssecretmanager" - "github.com/Layr-Labs/cerberus/internal/store/filesystem" - "github.com/Layr-Labs/cerberus/internal/store/googlesm" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -34,115 +26,185 @@ import ( _ "github.com/lib/pq" ) -func Start(config *configuration.Configuration, logger *slog.Logger) { - lis, err := net.Listen("tcp", fmt.Sprintf(":%s", config.GrpcPort)) - if err != nil { - logger.Error(fmt.Sprintf("Failed to listen: %v", err)) - os.Exit(1) - } +type Server struct { + resources *SharedResources + servers []*grpc.Server + wg sync.WaitGroup +} - registry := prometheus.NewRegistry() - registry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) - registry.MustRegister(collectors.NewGoCollector()) - rpcMetrics := metrics.NewRPCServerMetrics("cerberus", registry) - - go startMetricsServer(registry, config.MetricsPort, logger) - - var keystore store.Store - switch config.StorageType { - case configuration.FileSystemStorageType: - keystore = filesystem.NewStore(config.KeystoreDir, logger) - case configuration.AWSSecretManagerStorageType: - switch config.AWSAuthenticationMode { - case configuration.EnvironmentAWSAuthenticationMode: - keystore, err = awssecretmanager.NewStoreWithEnv( - config.AWSRegion, - config.AWSProfile, - logger, - ) - if err != nil { - logger.Error(fmt.Sprintf("Failed to create AWS Secret Manager store: %v", err)) - os.Exit(1) - } - logger.Info("Using environment credentials for AWS Secret Manager") - case configuration.SpecifiedAWSAuthenticationMode: - keystore, err = awssecretmanager.NewStoreWithSpecifiedCredentials( - config.AWSRegion, - config.AWSAccessKeyID, - config.AWSSecretAccessKey, - logger, - ) - if err != nil { - logger.Error(fmt.Sprintf("Failed to create AWS Secret Manager store: %v", err)) - os.Exit(1) - } - logger.Info("Using specified credentials for AWS Secret Manager") - } - case configuration.GoogleSecretManagerStorageType: - keystore, err = googlesm.NewKeystore(config.GCPProjectID, logger) - if err != nil { - logger.Error(fmt.Sprintf("Failed to create Google Secret Manager store: %v", err)) - os.Exit(1) - } - default: - logger.Error(fmt.Sprintf("Unsupported storage type: %s", config.StorageType)) - os.Exit(1) +// RegisterService represents a function type for service registration +type RegisterService func(*grpc.Server, *SharedResources) + +// AddServiceOnPort adds a new gRPC service on the specified port +func (s *Server) AddServiceOnPort( + serverCfg *GrpcServerConfig, + registerService RegisterService, +) error { + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", serverCfg.Port)) + if err != nil { + return fmt.Errorf("failed to listen on port %d: %v", serverCfg.Port, err) } var opts []grpc.ServerOption - if config.TLSCACert != "" && config.TLSServerKey != "" { - creds, err := credentials.NewServerTLSFromFile(config.TLSCACert, config.TLSServerKey) + if serverCfg.TLSCACert != "" && serverCfg.TLSServerKey != "" { + creds, err := credentials.NewServerTLSFromFile(serverCfg.TLSCACert, serverCfg.TLSServerKey) if err != nil { log.Fatalf("Failed to load TLS certificates: %v", err) } - logger.Info("Server-side TLS support enabled") + s.resources.Logger.Info("Server-side TLS support enabled") opts = append(opts, grpc.Creds(creds)) } - // Initialize database - db, err := sql.Open("postgres", config.PostgresDatabaseURL) - if err != nil { - logger.Error(fmt.Sprintf("Failed to connect to database: %v", err)) - os.Exit(1) + opts = append( + opts, + grpc.ChainUnaryInterceptor(s.resources.GrpcMiddleware...), + ) + + grpcServer := grpc.NewServer(opts...) + + // Register the service with shared resources + registerService(grpcServer, s.resources) + + // Add to servers list + s.servers = append(s.servers, grpcServer) + + // Start the server in a goroutine + s.wg.Add(1) + go func() { + defer s.wg.Done() + if err := grpcServer.Serve(lis); err != nil { + log.Printf("Failed to serve on port %d: %v", serverCfg.Port, err) + } + }() + + return nil +} + +// Start starts all registered services +func (s *Server) Start(ctx context.Context) error { + // Create channel for shutdown signals + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Wait for either context cancellation or shutdown signal + select { + case <-ctx.Done(): + log.Println("Context cancelled, initiating shutdown...") + case sig := <-sigChan: + log.Printf("Received signal %v, initiating shutdown...", sig) } - defer db.Close() - if err := database.MigrateDB(config.PostgresDatabaseURL, logger); err != nil { - logger.Error(fmt.Sprintf("Failed to migrate database: %v", err)) - os.Exit(1) + // Create context with timeout for graceful shutdown + shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var wg sync.WaitGroup + for _, server := range s.servers { + wg.Add(1) + go func(srv *grpc.Server) { + defer wg.Done() + srv.GracefulStop() + }(server) } - keyMetadataRepo := postgres.NewKeyMetadataRepository(db) + // Wait for all servers to stop + serverDone := make(chan struct{}) + go func() { + wg.Wait() + close(serverDone) + }() + + // Wait for server shutdown with timeout + select { + case <-serverDone: + log.Println("All gRPC servers stopped successfully") + case <-shutdownCtx.Done(): + log.Println("Shutdown timeout reached, forcing server stop") + for _, server := range s.servers { + server.Stop() + } + } - // Register metrics middleware - metricsMiddleware := middleware.NewMetricsMiddleware(registry, rpcMetrics) - authInterceptor := middleware.AuthInterceptor("signer.v1.Signer", keyMetadataRepo) - opts = append( - opts, - grpc.ChainUnaryInterceptor(metricsMiddleware.UnaryServerInterceptor(), authInterceptor), + return nil +} + +func Start(config *configuration.Configuration, logger *slog.Logger) { + server := NewServer(config, logger) + + kmsService := kms.NewService( + config, + server.resources.KeyStore, + server.resources.KeyMetadataRepo, + logger, + server.resources.RpcMetrics, ) + signingService := signing.NewService( + config, + server.resources.KeyStore, + logger, + server.resources.RpcMetrics, + ) + + logger.Info(fmt.Sprintf("Starting gRPC server on port %d...", config.GrpcPort)) + err := server.AddServiceOnPort(&GrpcServerConfig{ + Port: config.GrpcPort, + }, func(s *grpc.Server, resources *SharedResources) { + v1.RegisterKeyManagerServer(s, kmsService) + v1.RegisterSignerServer(s, signingService) + + // Register reflection service + reflection.Register(s) + }) - s := grpc.NewServer(opts...) - kmsService := kms.NewService(config, keystore, keyMetadataRepo, logger, rpcMetrics) - signingService := signing.NewService(config, keystore, logger, rpcMetrics) + if err != nil { + logger.Error(fmt.Sprintf("Failed to add service on port %d: %v", config.GrpcPort, err)) + os.Exit(1) + } + + if config.EnableAdmin { + adminService := admin.NewService( + config, + logger, + server.resources.RpcMetrics, + server.resources.KeyMetadataRepo, + ) - v1.RegisterKeyManagerServer(s, kmsService) - v1.RegisterSignerServer(s, signingService) + logger.Info(fmt.Sprintf("Starting Admin server on port %d...", config.AdminPort)) + err = server.AddServiceOnPort(&GrpcServerConfig{ + Port: config.AdminPort, + }, func(s *grpc.Server, resources *SharedResources) { + v1.RegisterAdminServer(s, adminService) - // Register the reflection service - reflection.Register(s) + // Register reflection service + reflection.Register(s) + }) - logger.Info(fmt.Sprintf("Starting gRPC server on port %s...", config.GrpcPort)) - if err := s.Serve(lis); err != nil { - log.Fatalf("Failed to serve: %v", err) + if err != nil { + logger.Error( + fmt.Sprintf("Failed to admin service on port %d: %v", config.AdminPort, err), + ) + os.Exit(1) + } + } + + // Start all services + // Create a context that can be cancelled + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start all services + if err := server.Start(ctx); err != nil { + logger.Error(fmt.Sprintf("Failed to start servers: %v", err)) + os.Exit(1) } + } -func startMetricsServer(r *prometheus.Registry, port string, logger *slog.Logger) { - http.Handle("/metrics", promhttp.HandlerFor(r, promhttp.HandlerOpts{})) - logger.Info(fmt.Sprintf("Starting metrics server on port %s...", port)) - if err := http.ListenAndServe(fmt.Sprintf(":%s", port), nil); err != nil { - logger.Error(fmt.Sprintf("Failed to start metrics server: %v", err)) +// NewServer creates a new Server instance with shared resources +func NewServer(config *configuration.Configuration, logger *slog.Logger) *Server { + return &Server{ + resources: NewSharedResources(config, logger), + servers: make([]*grpc.Server, 0), } } diff --git a/internal/server/shared_resources.go b/internal/server/shared_resources.go new file mode 100644 index 0000000..f02fd86 --- /dev/null +++ b/internal/server/shared_resources.go @@ -0,0 +1,165 @@ +package server + +import ( + "database/sql" + "fmt" + "log/slog" + "net/http" + "os" + + "github.com/Layr-Labs/cerberus/internal/configuration" + "github.com/Layr-Labs/cerberus/internal/database" + "github.com/Layr-Labs/cerberus/internal/database/repository" + "github.com/Layr-Labs/cerberus/internal/database/repository/postgres" + "github.com/Layr-Labs/cerberus/internal/metrics" + "github.com/Layr-Labs/cerberus/internal/middleware" + "github.com/Layr-Labs/cerberus/internal/store" + "github.com/Layr-Labs/cerberus/internal/store/awssecretmanager" + "github.com/Layr-Labs/cerberus/internal/store/filesystem" + "github.com/Layr-Labs/cerberus/internal/store/googlesm" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" + + "google.golang.org/grpc" + + _ "github.com/lib/pq" +) + +type SharedResources struct { + KeyMetadataRepo repository.KeyMetadataRepository + KeyStore store.Store + GrpcMiddleware []grpc.UnaryServerInterceptor + RpcMetrics *metrics.RPCServerMetrics + Logger *slog.Logger + + // Private fields + db *sql.DB +} + +func NewSharedResources( + config *configuration.Configuration, + logger *slog.Logger, +) *SharedResources { + + // Initialize store + keystore, err := initializeStore(config, logger) + if err != nil { + logger.Error(fmt.Sprintf("Failed to initialize store: %v", err)) + os.Exit(1) + } + + // Initialize database + db, err := initializeDatabase(config, logger) + if err != nil { + logger.Error(fmt.Sprintf("Failed to initialize database: %v", err)) + os.Exit(1) + } + + // Initialize key metadata repository + keyMetadataRepo := postgres.NewKeyMetadataRepository(db) + + // Initialize prometheus registry + registry := prometheus.NewRegistry() + registry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) + registry.MustRegister(collectors.NewGoCollector()) + rpcMetrics := metrics.NewRPCServerMetrics("cerberus", registry) + + // Start metrics server + go startMetricsServer(registry, config.MetricsPort, logger) + + // Initialize grpc middleware + grpcMiddleware := initializeGrpcMiddleware(registry, rpcMetrics, keyMetadataRepo) + + return &SharedResources{ + db: db, + KeyMetadataRepo: keyMetadataRepo, + KeyStore: keystore, + GrpcMiddleware: grpcMiddleware, + RpcMetrics: rpcMetrics, + Logger: logger, + } +} + +func initializeDatabase( + config *configuration.Configuration, + logger *slog.Logger, +) (*sql.DB, error) { + db, err := sql.Open("postgres", config.PostgresDatabaseURL) + if err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + + if err := database.MigrateDB(config.PostgresDatabaseURL, logger); err != nil { + return nil, fmt.Errorf("failed to migrate database: %w", err) + } + + return db, nil +} + +func initializeStore( + config *configuration.Configuration, + logger *slog.Logger, +) (store.Store, error) { + var keystore store.Store + var err error + switch config.StorageType { + case configuration.FileSystemStorageType: + keystore = filesystem.NewStore(config.KeystoreDir, logger) + case configuration.AWSSecretManagerStorageType: + switch config.AWSAuthenticationMode { + case configuration.EnvironmentAWSAuthenticationMode: + keystore, err = awssecretmanager.NewStoreWithEnv( + config.AWSRegion, + config.AWSProfile, + logger, + ) + if err != nil { + return nil, fmt.Errorf("failed to create AWS Secret Manager store: %w", err) + } + logger.Info("Using environment credentials for AWS Secret Manager") + case configuration.SpecifiedAWSAuthenticationMode: + keystore, err = awssecretmanager.NewStoreWithSpecifiedCredentials( + config.AWSRegion, + config.AWSAccessKeyID, + config.AWSSecretAccessKey, + logger, + ) + if err != nil { + return nil, fmt.Errorf("failed to create AWS Secret Manager store: %w", err) + } + logger.Info("Using specified credentials for AWS Secret Manager") + } + case configuration.GoogleSecretManagerStorageType: + keystore, err = googlesm.NewKeystore(config.GCPProjectID, logger) + if err != nil { + return nil, fmt.Errorf("failed to create Google Secret Manager store: %w", err) + } + default: + return nil, fmt.Errorf("unsupported storage type: %s", config.StorageType) + } + + return keystore, nil +} + +func initializeGrpcMiddleware( + registry *prometheus.Registry, + rpcMetrics *metrics.RPCServerMetrics, + keyMetadataRepo repository.KeyMetadataRepository, +) []grpc.UnaryServerInterceptor { + metricsMiddleware := middleware.NewMetricsMiddleware(registry, rpcMetrics) + authInterceptor := middleware.AuthInterceptor("signer.v1.Signer", keyMetadataRepo) + return []grpc.UnaryServerInterceptor{ + metricsMiddleware.UnaryServerInterceptor(), + authInterceptor, + } +} + +func startMetricsServer(r *prometheus.Registry, port int, logger *slog.Logger) { + http.Handle("/metrics", promhttp.HandlerFor(r, promhttp.HandlerOpts{})) + logger.Info(fmt.Sprintf("Starting metrics server on port %d...", port)) + if err := http.ListenAndServe(fmt.Sprintf(":%d", port), nil); err != nil { + logger.Error(fmt.Sprintf("Failed to start metrics server: %v", err)) + } +} diff --git a/internal/services/admin/admin.go b/internal/services/admin/admin.go new file mode 100644 index 0000000..3a2da50 --- /dev/null +++ b/internal/services/admin/admin.go @@ -0,0 +1,109 @@ +package admin + +import ( + "context" + "log/slog" + "time" + + v1 "github.com/Layr-Labs/cerberus-api/pkg/api/v1" + "github.com/Layr-Labs/cerberus/internal/common" + "github.com/Layr-Labs/cerberus/internal/configuration" + "github.com/Layr-Labs/cerberus/internal/database/repository" + "github.com/Layr-Labs/cerberus/internal/metrics" +) + +var _ v1.AdminServer = (*Service)(nil) + +type Service struct { + config *configuration.Configuration + logger *slog.Logger + metrics metrics.Recorder + keyMetadataRepo repository.KeyMetadataRepository + + v1.UnimplementedAdminServer +} + +func NewService( + config *configuration.Configuration, + logger *slog.Logger, + metrics metrics.Recorder, + keyMetadataRepo repository.KeyMetadataRepository, +) *Service { + return &Service{ + config: config, + logger: logger.With("component", "admin"), + metrics: metrics, + keyMetadataRepo: keyMetadataRepo, + } +} + +func (s *Service) GenerateNewApiKey( + ctx context.Context, + req *v1.GenerateNewApiKeyRequest, +) (*v1.GenerateNewApiKeyResponse, error) { + metadata, err := s.keyMetadataRepo.Get(ctx, req.PublicKeyG1) + if err != nil { + return nil, err + } + + apiKey, apiKeyHash, err := common.GenerateNewAPIKeyAndHash() + if err != nil { + return nil, err + } + + err = s.keyMetadataRepo.UpdateAPIKeyHash(ctx, metadata.PublicKeyG1, apiKeyHash) + if err != nil { + return nil, err + } + + return &v1.GenerateNewApiKeyResponse{ + ApiKey: apiKey, + PublicKeyG1: metadata.PublicKeyG1, + }, nil +} + +func (s *Service) LockKey( + ctx context.Context, + req *v1.LockKeyRequest, +) (*v1.LockKeyResponse, error) { + err := s.keyMetadataRepo.UpdateLockStatus(ctx, req.PublicKeyG1, true) + if err != nil { + return nil, err + } + return &v1.LockKeyResponse{}, nil +} + +func (s *Service) UnlockKey( + ctx context.Context, + req *v1.UnlockKeyRequest, +) (*v1.UnlockKeyResponse, error) { + err := s.keyMetadataRepo.UpdateLockStatus(ctx, req.PublicKeyG1, false) + if err != nil { + return nil, err + } + return &v1.UnlockKeyResponse{}, nil +} + +func (s *Service) ListAllKeys( + ctx context.Context, + req *v1.ListAllKeysRequest, +) (*v1.ListAllKeysResponse, error) { + keys, err := s.keyMetadataRepo.List(ctx) + if err != nil { + return nil, err + } + + response := &v1.ListAllKeysResponse{ + Keys: make([]*v1.KeyMetadata, 0, len(keys)), + } + for _, key := range keys { + response.Keys = append(response.Keys, &v1.KeyMetadata{ + PublicKeyG1: key.PublicKeyG1, + PublicKeyG2: key.PublicKeyG2, + CreatedAt: key.CreatedAt.Format(time.RFC3339), + UpdatedAt: key.UpdatedAt.Format(time.RFC3339), + Locked: key.Locked, + }) + } + return response, nil +} diff --git a/internal/services/kms/kms.go b/internal/services/kms/kms.go index 9eaf1b5..5642d81 100644 --- a/internal/services/kms/kms.go +++ b/internal/services/kms/kms.go @@ -8,7 +8,6 @@ import ( "math/big" v1 "github.com/Layr-Labs/cerberus-api/pkg/api/v1" - "github.com/google/uuid" "github.com/Layr-Labs/cerberus/internal/common" "github.com/Layr-Labs/cerberus/internal/configuration" @@ -77,7 +76,7 @@ func (k *Service) GenerateKeyPair( } // Generate a new API key and hash - apiKey, apiKeyHash, err := generateNewAPIKeyAndHash() + apiKey, apiKeyHash, err := common.GenerateNewAPIKeyAndHash() if err != nil { k.logger.Error(fmt.Sprintf("Failed to generate API key: %v", err)) return nil, status.Error(codes.Internal, err.Error()) @@ -160,7 +159,7 @@ func (k *Service) ImportKey( } // Generate a new API key and hash - apiKey, apiKeyHash, err := generateNewAPIKeyAndHash() + apiKey, apiKeyHash, err := common.GenerateNewAPIKeyAndHash() if err != nil { k.logger.Error(fmt.Sprintf("Failed to generate API key: %v", err)) return nil, status.Error(codes.Internal, err.Error()) @@ -219,13 +218,3 @@ func (k *Service) GetKeyMetadata( UpdatedAt: metadata.UpdatedAt.Unix(), }, nil } - -func generateNewAPIKeyAndHash() (string, string, error) { - newUUID, err := uuid.NewV7() - if err != nil { - return "", "", err - } - apiKey := newUUID.String() - apiKeyHash := common.CreateSHA256Hash(apiKey) - return apiKey, apiKeyHash, nil -} From af80c40cb885d4688f524573505a5e659a1c1427 Mon Sep 17 00:00:00 2001 From: Madhur Shrimal Date: Fri, 17 Jan 2025 10:46:58 -0800 Subject: [PATCH 3/4] return error if key exists --- internal/database/repository/errors.go | 7 ++++ .../repository/postgres/key_metadata.go | 2 +- internal/services/kms/kms.go | 33 ++++++++++++++----- 3 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 internal/database/repository/errors.go diff --git a/internal/database/repository/errors.go b/internal/database/repository/errors.go new file mode 100644 index 0000000..f3d8314 --- /dev/null +++ b/internal/database/repository/errors.go @@ -0,0 +1,7 @@ +package repository + +import "errors" + +var ( + ErrKeyNotFound = errors.New("key not found") +) diff --git a/internal/database/repository/postgres/key_metadata.go b/internal/database/repository/postgres/key_metadata.go index c601447..4fc3f98 100644 --- a/internal/database/repository/postgres/key_metadata.go +++ b/internal/database/repository/postgres/key_metadata.go @@ -96,7 +96,7 @@ func (r *keyMetadataRepo) Get(ctx context.Context, publicKeyG1 string) (*model.K &metadata.Locked, ) if err == sql.ErrNoRows { - return nil, errors.New("key metadata not found") + return nil, repository.ErrKeyNotFound } if err != nil { return nil, err diff --git a/internal/services/kms/kms.go b/internal/services/kms/kms.go index 5642d81..c600157 100644 --- a/internal/services/kms/kms.go +++ b/internal/services/kms/kms.go @@ -139,25 +139,40 @@ func (k *Service) ImportKey( } } - pubKeyHex, err := k.store.StoreKey( - ctx, - &keystore.KeyPair{PrivateKey: pkBytes, Password: password}, - ) - if err != nil { - k.logger.Error(fmt.Sprintf("Failed to save BLS key pair to file: %v", err)) - return nil, status.Error(codes.Internal, err.Error()) - } - ks := &keystore.KeyPair{ PrivateKey: pkBytes, } + g1PubKey, err := ks.GetG1PublicKey(curve.BN254) + if err != nil { + k.logger.Error(fmt.Sprintf("Failed to get G1 public key: %v", err)) + return nil, status.Error(codes.Internal, err.Error()) + } + g2PubKey, err := ks.GetG2PublicKey(curve.BN254) if err != nil { k.logger.Error(fmt.Sprintf("Failed to get G2 public key: %v", err)) return nil, status.Error(codes.Internal, err.Error()) } + _, err = k.keyMetadataRepo.Get(ctx, g1PubKey) + if err == nil { + return nil, status.Error(codes.AlreadyExists, "key already exists") + } + if err != repository.ErrKeyNotFound { + k.logger.Error(fmt.Sprintf("Failed to get key metadata: %v", err)) + return nil, status.Error(codes.Internal, err.Error()) + } + + pubKeyHex, err := k.store.StoreKey( + ctx, + &keystore.KeyPair{PrivateKey: pkBytes, Password: password}, + ) + if err != nil { + k.logger.Error(fmt.Sprintf("Failed to save BLS key pair to file: %v", err)) + return nil, status.Error(codes.Internal, err.Error()) + } + // Generate a new API key and hash apiKey, apiKeyHash, err := common.GenerateNewAPIKeyAndHash() if err != nil { From 4ca423d719b4b471c3d6464c504a836a5db695ca Mon Sep 17 00:00:00 2001 From: Madhur Shrimal Date: Fri, 17 Jan 2025 11:38:02 -0800 Subject: [PATCH 4/4] get master commit --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 04febbe..5b04e49 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.22.3 require ( cloud.google.com/go/secretmanager v1.14.2 github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87 - github.com/Layr-Labs/cerberus-api v0.0.2-0.20250117015901-0b1220ea735f + github.com/Layr-Labs/cerberus-api v0.0.2-0.20250117193600-e69c5e8b08fd github.com/aws/aws-sdk-go-v2 v1.32.5 github.com/aws/aws-sdk-go-v2/config v1.28.5 github.com/aws/aws-sdk-go-v2/credentials v1.17.46 diff --git a/go.sum b/go.sum index c7561b6..d8a1096 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,8 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg6 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87 h1:EkaBNT0o8RTgtFeYSKaoNHNbnCVxrcsAyRpUeN29hiQ= github.com/Layr-Labs/bn254-keystore-go v0.0.0-20250107020618-26bd412fae87/go.mod h1:7J8hptSX8cFq7KmVb+rEO5aEifj7E44c3i0afIyr4WA= -github.com/Layr-Labs/cerberus-api v0.0.2-0.20250117015901-0b1220ea735f h1:Od50IBPPjsAF9w8QGUsTFcFwmKA0UEH9zNtUN0PDM68= -github.com/Layr-Labs/cerberus-api v0.0.2-0.20250117015901-0b1220ea735f/go.mod h1:Lm4fhzy0S3P7GjerzuseGaBFVczsIKmEhIjcT52Hluo= +github.com/Layr-Labs/cerberus-api v0.0.2-0.20250117193600-e69c5e8b08fd h1:prMzW4BY6KZtWEanf5EIsyHzIZKCNV2mVIXrE6glRRM= +github.com/Layr-Labs/cerberus-api v0.0.2-0.20250117193600-e69c5e8b08fd/go.mod h1:Lm4fhzy0S3P7GjerzuseGaBFVczsIKmEhIjcT52Hluo= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/aws/aws-sdk-go-v2 v1.32.5 h1:U8vdWJuY7ruAkzaOdD7guwJjD06YSKmnKCJs7s3IkIo=