Skip to content

Commit

Permalink
feat: add payment request validation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Nov 22, 2024
1 parent f9ad944 commit ee752c0
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 6 deletions.
96 changes: 96 additions & 0 deletions core/auth/v2/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package v2_test

import (
"crypto/sha256"
"math/big"
"testing"

pb "github.com/Layr-Labs/eigenda/api/grpc/disperser/v2"
"github.com/Layr-Labs/eigenda/core"
auth "github.com/Layr-Labs/eigenda/core/auth/v2"
corev2 "github.com/Layr-Labs/eigenda/core/v2"
"github.com/Layr-Labs/eigenda/encoding"
"github.com/consensys/gnark-crypto/ecc/bn254/fp"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -114,3 +117,96 @@ func testHeader(t *testing.T, accountID string) *corev2.BlobHeader {
Signature: []byte{},
}
}

func TestAuthenticatePaymentStateRequestValid(t *testing.T) {
signer := auth.NewLocalBlobRequestSigner(privateKeyHex)
authenticator := auth.NewAuthenticator()

signature, err := signer.SignPaymentStateRequest()
assert.NoError(t, err)

accountId, err := signer.GetAccountID()

assert.NoError(t, err)

request := &pb.GetPaymentStateRequest{
AccountId: accountId,
Signature: signature,
}

err = authenticator.AuthenticatePaymentStateRequest(request)
assert.NoError(t, err)
}

func TestAuthenticatePaymentStateRequestInvalidSignatureLength(t *testing.T) {
authenticator := auth.NewAuthenticator()

request := &pb.GetPaymentStateRequest{
AccountId: "0x123",
Signature: []byte{1, 2, 3}, // Invalid length
}

err := authenticator.AuthenticatePaymentStateRequest(request)
assert.Error(t, err)
assert.Contains(t, err.Error(), "signature length is unexpected")
}

func TestAuthenticatePaymentStateRequestInvalidPublicKey(t *testing.T) {
authenticator := auth.NewAuthenticator()

request := &pb.GetPaymentStateRequest{
AccountId: "not-hex-encoded",
Signature: make([]byte, 65),
}

err := authenticator.AuthenticatePaymentStateRequest(request)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode public key")
}

func TestAuthenticatePaymentStateRequestSignatureMismatch(t *testing.T) {
signer := auth.NewLocalBlobRequestSigner(privateKeyHex)
authenticator := auth.NewAuthenticator()

// Create a different signer with wrong private key
wrongSigner := auth.NewLocalBlobRequestSigner("0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcded")

// Sign with wrong key
accountId, err := signer.GetAccountID()
assert.NoError(t, err)

signature, err := wrongSigner.SignPaymentStateRequest()
assert.NoError(t, err)

request := &pb.GetPaymentStateRequest{
AccountId: accountId,
Signature: signature,
}

err = authenticator.AuthenticatePaymentStateRequest(request)
assert.Error(t, err)
assert.Contains(t, err.Error(), "signature doesn't match with provided public key")
}

func TestAuthenticatePaymentStateRequestCorruptedSignature(t *testing.T) {
signer := auth.NewLocalBlobRequestSigner(privateKeyHex)
authenticator := auth.NewAuthenticator()

accountId, err := signer.GetAccountID()
assert.NoError(t, err)

hash := sha256.Sum256([]byte(accountId))
signature, err := crypto.Sign(hash[:], signer.PrivateKey)
assert.NoError(t, err)

// Corrupt the signature
signature[0] ^= 0x01

request := &pb.GetPaymentStateRequest{
AccountId: accountId,
Signature: signature,
}

err = authenticator.AuthenticatePaymentStateRequest(request)
assert.Error(t, err)
}
38 changes: 36 additions & 2 deletions core/auth/v2/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package v2

import (
"bytes"
"crypto/sha256"
"errors"
"fmt"

pb "github.com/Layr-Labs/eigenda/api/grpc/disperser/v2"
core "github.com/Layr-Labs/eigenda/core/v2"

"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
)
Expand Down Expand Up @@ -40,7 +41,7 @@ func (*authenticator) AuthenticateBlobRequest(header *core.BlobHeader) error {
// Decode public key
pubKey, err := crypto.UnmarshalPubkey(publicKeyBytes)
if err != nil {
return fmt.Errorf("failed to decode public key (%v): %v", header.PaymentMetadata.AccountID, err)
return fmt.Errorf("failed to convert bytes to public key (%v): %v", header.PaymentMetadata.AccountID, err)
}

// Verify the signature
Expand All @@ -55,3 +56,36 @@ func (*authenticator) AuthenticateBlobRequest(header *core.BlobHeader) error {

return nil
}

func (*authenticator) AuthenticatePaymentStateRequest(request *pb.GetPaymentStateRequest) error {
// Ensure the signature is 65 bytes (Recovery ID is the last byte)
sig := request.GetSignature()
if len(sig) != 65 {
return fmt.Errorf("signature length is unexpected: %d", len(sig))
}

// Decode public key
publicKeyBytes, err := hexutil.Decode(request.GetAccountId())
if err != nil {
return fmt.Errorf("failed to decode public key (%v): %v", request.GetAccountId(), err)
}

// Convert bytes to public key
pubKey, err := crypto.UnmarshalPubkey(publicKeyBytes)
if err != nil {
return fmt.Errorf("failed to convert bytes to public key (%v): %v", request.GetAccountId(), err)
}

// Verify the signature
hash := sha256.Sum256([]byte(request.GetAccountId()))
sigPublicKeyECDSA, err := crypto.SigToPub(hash[:], sig)
if err != nil {
return fmt.Errorf("failed to recover public key from signature: %v", err)
}

if !bytes.Equal(pubKey.X.Bytes(), sigPublicKeyECDSA.X.Bytes()) || !bytes.Equal(pubKey.Y.Bytes(), sigPublicKeyECDSA.Y.Bytes()) {
return errors.New("signature doesn't match with provided public key")
}

return nil
}
21 changes: 21 additions & 0 deletions core/auth/v2/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package v2

import (
"crypto/ecdsa"
"crypto/sha256"
"fmt"
"log"

Expand Down Expand Up @@ -44,6 +45,22 @@ func (s *LocalBlobRequestSigner) SignBlobRequest(header *core.BlobHeader) ([]byt
return sig, nil
}

func (s *LocalBlobRequestSigner) SignPaymentStateRequest() ([]byte, error) {
accountId, err := s.GetAccountID()
if err != nil {
return nil, fmt.Errorf("failed to get account ID: %v", err)
}

hash := sha256.Sum256([]byte(accountId))
// Sign the blob key using the private key
sig, err := crypto.Sign(hash[:], s.PrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to sign hash: %v", err)
}

return sig, nil
}

func (s *LocalBlobRequestSigner) GetAccountID() (string, error) {

publicKeyBytes := crypto.FromECDSAPub(&s.PrivateKey.PublicKey)
Expand All @@ -63,6 +80,10 @@ func (s *LocalNoopSigner) SignBlobRequest(header *core.BlobHeader) ([]byte, erro
return nil, fmt.Errorf("noop signer cannot sign blob request")
}

func (s *LocalNoopSigner) SignPaymentStateRequest() ([]byte, error) {
return nil, fmt.Errorf("noop signer cannot sign payment state request")
}

func (s *LocalNoopSigner) GetAccountID() (string, error) {
return "", fmt.Errorf("noop signer cannot get accountID")
}
4 changes: 4 additions & 0 deletions core/v2/auth.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package v2

import pb "github.com/Layr-Labs/eigenda/api/grpc/disperser/v2"

type BlobRequestAuthenticator interface {
AuthenticateBlobRequest(header *BlobHeader) error
AuthenticatePaymentStateRequest(request *pb.GetPaymentStateRequest) error
}

type BlobRequestSigner interface {
SignBlobRequest(header *BlobHeader) ([]byte, error)
SignPaymentStateRequest() ([]byte, error)
GetAccountID() (string, error)
}
6 changes: 2 additions & 4 deletions disperser/apiserver/server_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/Layr-Labs/eigenda/common"
healthcheck "github.com/Layr-Labs/eigenda/common/healthcheck"
"github.com/Layr-Labs/eigenda/core"
"github.com/Layr-Labs/eigenda/core/auth"
"github.com/Layr-Labs/eigenda/core/meterer"
corev2 "github.com/Layr-Labs/eigenda/core/v2"
"github.com/Layr-Labs/eigenda/disperser"
Expand Down Expand Up @@ -205,10 +204,9 @@ func (s *DispersalServerV2) RefreshOnchainState(ctx context.Context) error {

func (s *DispersalServerV2) GetPaymentState(ctx context.Context, req *pb.GetPaymentStateRequest) (*pb.GetPaymentStateReply, error) {
// validate the signature
if !auth.VerifyAccountSignature(req.AccountId, req.Signature) {
return nil, api.NewErrorInvalidArg("invalid signature")
if err := s.authenticator.AuthenticatePaymentStateRequest(req); err != nil {
return nil, api.NewErrorInvalidArg(fmt.Sprintf("authentication failed: %s", err.Error()))
}

// on-chain global payment parameters
globalSymbolsPerSecond := s.meterer.ChainPaymentState.GetGlobalSymbolsPerSecond()
minNumSymbols := s.meterer.ChainPaymentState.GetMinNumSymbols()
Expand Down

0 comments on commit ee752c0

Please sign in to comment.