diff --git a/api/clients/accountant.go b/api/clients/accountant.go index 375494166b..bd8192b31e 100644 --- a/api/clients/accountant.go +++ b/api/clients/accountant.go @@ -8,20 +8,14 @@ import ( "sync" "time" - commonpb "github.com/Layr-Labs/eigenda/api/grpc/common" + disperser_rpc "github.com/Layr-Labs/eigenda/api/grpc/disperser/v2" "github.com/Layr-Labs/eigenda/core" "github.com/Layr-Labs/eigenda/core/meterer" ) var requiredQuorums = []uint8{0, 1} -type Accountant interface { - AccountBlob(ctx context.Context, numSymbols uint64, quorums []uint8) (*commonpb.PaymentHeader, error) -} - -var _ Accountant = &accountant{} - -type accountant struct { +type Accountant struct { // on-chain states accountID string reservation *core.ActiveReservation @@ -45,7 +39,7 @@ type BinRecord struct { Usage uint64 } -func NewAccountant(accountID string, reservation *core.ActiveReservation, onDemand *core.OnDemandPayment, reservationWindow uint32, pricePerSymbol uint32, minNumSymbols uint32, numBins uint32) *accountant { +func NewAccountant(accountID string, reservation *core.ActiveReservation, onDemand *core.OnDemandPayment, reservationWindow uint32, pricePerSymbol uint32, minNumSymbols uint32, numBins uint32) *Accountant { //TODO: client storage; currently every instance starts fresh but on-chain or a small store makes more sense // Also client is currently responsible for supplying network params, we need to add RPC in order to be automatic // There's a subsequent PR that handles populating the accountant with on-chain state from the disperser @@ -53,7 +47,7 @@ func NewAccountant(accountID string, reservation *core.ActiveReservation, onDema for i := range binRecords { binRecords[i] = BinRecord{Index: uint32(i), Usage: 0} } - a := accountant{ + a := Accountant{ accountID: accountID, reservation: reservation, onDemand: onDemand, @@ -73,7 +67,7 @@ func NewAccountant(accountID string, reservation *core.ActiveReservation, onDema // then on-demand if the reservation is not available. The returned values are // bin index for reservation payments and cumulative payment for on-demand payments, // and both fields are used to create the payment header and signature -func (a *accountant) BlobPaymentInfo(ctx context.Context, numSymbols uint64, quorumNumbers []uint8) (uint32, *big.Int, error) { +func (a *Accountant) BlobPaymentInfo(ctx context.Context, numSymbols uint64, quorumNumbers []uint8) (uint32, *big.Int, error) { now := time.Now().Unix() currentBinIndex := meterer.GetBinIndex(uint64(now), a.reservationWindow) @@ -116,7 +110,7 @@ func (a *accountant) BlobPaymentInfo(ctx context.Context, numSymbols uint64, quo } // AccountBlob accountant provides and records payment information -func (a *accountant) AccountBlob(ctx context.Context, numSymbols uint64, quorums []uint8) (*commonpb.PaymentHeader, error) { +func (a *Accountant) AccountBlob(ctx context.Context, numSymbols uint64, quorums []uint8) (*core.PaymentMetadata, error) { binIndex, cumulativePayment, err := a.BlobPaymentInfo(ctx, numSymbols, quorums) if err != nil { return nil, err @@ -127,20 +121,19 @@ func (a *accountant) AccountBlob(ctx context.Context, numSymbols uint64, quorums BinIndex: binIndex, CumulativePayment: cumulativePayment, } - protoPaymentHeader := pm.ConvertToProtoPaymentHeader() - return protoPaymentHeader, nil + return pm, nil } // TODO: PaymentCharged and SymbolsCharged copied from meterer, should be refactored // PaymentCharged returns the chargeable price for a given data length -func (a *accountant) PaymentCharged(numSymbols uint) uint64 { +func (a *Accountant) PaymentCharged(numSymbols uint) uint64 { return uint64(a.SymbolsCharged(numSymbols)) * uint64(a.pricePerSymbol) } // SymbolsCharged returns the number of symbols charged for a given data length // being at least MinNumSymbols or the nearest rounded-up multiple of MinNumSymbols. -func (a *accountant) SymbolsCharged(numSymbols uint) uint32 { +func (a *Accountant) SymbolsCharged(numSymbols uint) uint32 { if numSymbols <= uint(a.minNumSymbols) { return a.minNumSymbols } @@ -148,7 +141,7 @@ func (a *accountant) SymbolsCharged(numSymbols uint) uint32 { return uint32(core.RoundUpDivide(uint(numSymbols), uint(a.minNumSymbols))) * a.minNumSymbols } -func (a *accountant) GetRelativeBinRecord(index uint32) *BinRecord { +func (a *Accountant) GetRelativeBinRecord(index uint32) *BinRecord { relativeIndex := index % a.numBins if a.binRecords[relativeIndex].Index != uint32(index) { a.binRecords[relativeIndex] = BinRecord{ @@ -160,6 +153,59 @@ func (a *accountant) GetRelativeBinRecord(index uint32) *BinRecord { return &a.binRecords[relativeIndex] } +func (a *Accountant) SetPaymentState(paymentState *disperser_rpc.GetPaymentStateReply) error { + if paymentState == nil { + return fmt.Errorf("payment state cannot be nil") + } else if paymentState.GetPaymentGlobalParams() == nil { + return fmt.Errorf("payment global params cannot be nil") + } else if paymentState.GetOnchainCumulativePayment() == nil { + return fmt.Errorf("onchain cumulative payment cannot be nil") + } else if paymentState.GetCumulativePayment() == nil { + return fmt.Errorf("cumulative payment cannot be nil") + } else if paymentState.GetReservation() == nil { + return fmt.Errorf("reservation cannot be nil") + } else if paymentState.GetReservation().GetQuorumNumbers() == nil { + return fmt.Errorf("reservation quorum numbers cannot be nil") + } else if paymentState.GetReservation().GetQuorumSplit() == nil { + return fmt.Errorf("reservation quorum split cannot be nil") + } else if paymentState.GetBinRecords() == nil { + return fmt.Errorf("bin records cannot be nil") + } + + a.minNumSymbols = uint32(paymentState.PaymentGlobalParams.MinNumSymbols) + a.onDemand.CumulativePayment = new(big.Int).SetBytes(paymentState.OnchainCumulativePayment) + a.cumulativePayment = new(big.Int).SetBytes(paymentState.CumulativePayment) + a.pricePerSymbol = uint32(paymentState.PaymentGlobalParams.PricePerSymbol) + + a.reservation.SymbolsPerSec = uint64(paymentState.PaymentGlobalParams.GlobalSymbolsPerSecond) + a.reservation.StartTimestamp = uint64(paymentState.Reservation.StartTimestamp) + a.reservation.EndTimestamp = uint64(paymentState.Reservation.EndTimestamp) + a.reservationWindow = uint32(paymentState.PaymentGlobalParams.ReservationWindow) + + quorumNumbers := make([]uint8, len(paymentState.Reservation.QuorumNumbers)) + for i, quorum := range paymentState.Reservation.QuorumNumbers { + quorumNumbers[i] = uint8(quorum) + } + a.reservation.QuorumNumbers = quorumNumbers + + quorumSplit := make([]uint8, len(paymentState.Reservation.QuorumSplit)) + for i, quorum := range paymentState.Reservation.QuorumSplit { + quorumSplit[i] = uint8(quorum) + } + a.reservation.QuorumSplit = quorumSplit + + binRecords := make([]BinRecord, len(paymentState.BinRecords)) + for i, record := range paymentState.BinRecords { + binRecords[i] = BinRecord{ + Index: record.Index, + Usage: record.Usage, + } + } + a.binRecords = binRecords + + return nil +} + // QuorumCheck eagerly returns error if the check finds a quorum number not an element of the allowed quorum numbers func QuorumCheck(quorumNumbers []uint8, allowedNumbers []uint8) error { if len(quorumNumbers) == 0 { diff --git a/api/clients/accountant_test.go b/api/clients/accountant_test.go index 979c087925..09b880664c 100644 --- a/api/clients/accountant_test.go +++ b/api/clients/accountant_test.go @@ -71,30 +71,27 @@ func TestAccountBlob_Reservation(t *testing.T) { quorums := []uint8{0, 1} header, err := accountant.AccountBlob(ctx, symbolLength, quorums) - metadata := core.ConvertPaymentHeader(header) assert.NoError(t, err) assert.Equal(t, meterer.GetBinIndex(uint64(time.Now().Unix()), reservationWindow), header.BinIndex) - assert.Equal(t, big.NewInt(0), metadata.CumulativePayment) + assert.Equal(t, big.NewInt(0), header.CumulativePayment) assert.Equal(t, isRotation([]uint64{500, 0, 0}, mapRecordUsage(accountant.binRecords)), true) symbolLength = uint64(700) header, err = accountant.AccountBlob(ctx, symbolLength, quorums) - metadata = core.ConvertPaymentHeader(header) assert.NoError(t, err) assert.NotEqual(t, 0, header.BinIndex) - assert.Equal(t, big.NewInt(0), metadata.CumulativePayment) + assert.Equal(t, big.NewInt(0), header.CumulativePayment) assert.Equal(t, isRotation([]uint64{1200, 0, 200}, mapRecordUsage(accountant.binRecords)), true) // Second call should use on-demand payment header, err = accountant.AccountBlob(ctx, 300, quorums) - metadata = core.ConvertPaymentHeader(header) assert.NoError(t, err) assert.Equal(t, uint32(0), header.BinIndex) - assert.Equal(t, big.NewInt(300), metadata.CumulativePayment) + assert.Equal(t, big.NewInt(300), header.CumulativePayment) } func TestAccountBlob_OnDemand(t *testing.T) { @@ -124,10 +121,9 @@ func TestAccountBlob_OnDemand(t *testing.T) { header, err := accountant.AccountBlob(ctx, numSymbols, quorums) assert.NoError(t, err) - metadata := core.ConvertPaymentHeader(header) expectedPayment := big.NewInt(int64(numSymbols * uint64(pricePerSymbol))) assert.Equal(t, uint32(0), header.BinIndex) - assert.Equal(t, expectedPayment, metadata.CumulativePayment) + assert.Equal(t, expectedPayment, header.CumulativePayment) assert.Equal(t, isRotation([]uint64{0, 0, 0}, mapRecordUsage(accountant.binRecords)), true) assert.Equal(t, expectedPayment, accountant.cumulativePayment) } @@ -180,24 +176,21 @@ func TestAccountBlobCallSeries(t *testing.T) { // First call: Use reservation header, err := accountant.AccountBlob(ctx, 800, quorums) - metadata := core.ConvertPaymentHeader(header) assert.NoError(t, err) assert.Equal(t, meterer.GetBinIndex(uint64(now), reservationWindow), header.BinIndex) - assert.Equal(t, big.NewInt(0), metadata.CumulativePayment) + assert.Equal(t, big.NewInt(0), header.CumulativePayment) // Second call: Use remaining reservation + overflow header, err = accountant.AccountBlob(ctx, 300, quorums) - metadata = core.ConvertPaymentHeader(header) assert.NoError(t, err) assert.Equal(t, meterer.GetBinIndex(uint64(now), reservationWindow), header.BinIndex) - assert.Equal(t, big.NewInt(0), metadata.CumulativePayment) + assert.Equal(t, big.NewInt(0), header.CumulativePayment) // Third call: Use on-demand header, err = accountant.AccountBlob(ctx, 500, quorums) - metadata = core.ConvertPaymentHeader(header) assert.NoError(t, err) assert.Equal(t, uint32(0), header.BinIndex) - assert.Equal(t, big.NewInt(500), metadata.CumulativePayment) + assert.Equal(t, big.NewInt(500), header.CumulativePayment) // Fourth call: Insufficient on-demand _, err = accountant.AccountBlob(ctx, 600, quorums) @@ -321,23 +314,20 @@ func TestAccountBlob_ReservationWithOneOverflow(t *testing.T) { header, err := accountant.AccountBlob(ctx, 800, quorums) assert.NoError(t, err) assert.Equal(t, meterer.GetBinIndex(uint64(now), reservationWindow), header.BinIndex) - metadata := core.ConvertPaymentHeader(header) - assert.Equal(t, big.NewInt(0), metadata.CumulativePayment) + assert.Equal(t, big.NewInt(0), header.CumulativePayment) assert.Equal(t, isRotation([]uint64{800, 0, 0}, mapRecordUsage(accountant.binRecords)), true) // Second call: Allow one overflow header, err = accountant.AccountBlob(ctx, 500, quorums) assert.NoError(t, err) - metadata = core.ConvertPaymentHeader(header) - assert.Equal(t, big.NewInt(0), metadata.CumulativePayment) + assert.Equal(t, big.NewInt(0), header.CumulativePayment) assert.Equal(t, isRotation([]uint64{1300, 0, 300}, mapRecordUsage(accountant.binRecords)), true) // Third call: Should use on-demand payment header, err = accountant.AccountBlob(ctx, 200, quorums) assert.NoError(t, err) assert.Equal(t, uint32(0), header.BinIndex) - metadata = core.ConvertPaymentHeader(header) - assert.Equal(t, big.NewInt(200), metadata.CumulativePayment) + assert.Equal(t, big.NewInt(200), header.CumulativePayment) assert.Equal(t, isRotation([]uint64{1300, 0, 300}, mapRecordUsage(accountant.binRecords)), true) } @@ -373,8 +363,7 @@ func TestAccountBlob_ReservationOverflowReset(t *testing.T) { header, err := accountant.AccountBlob(ctx, 500, quorums) assert.NoError(t, err) assert.Equal(t, isRotation([]uint64{1000, 0, 0}, mapRecordUsage(accountant.binRecords)), true) - metadata := core.ConvertPaymentHeader(header) - assert.Equal(t, big.NewInt(500), metadata.CumulativePayment) + assert.Equal(t, big.NewInt(500), header.CumulativePayment) // Wait for next reservation duration time.Sleep(time.Duration(reservationWindow) * time.Second) diff --git a/api/clients/disperser_client_v2.go b/api/clients/disperser_client_v2.go index e9d2e57f3e..ec34071923 100644 --- a/api/clients/disperser_client_v2.go +++ b/api/clients/disperser_client_v2.go @@ -3,7 +3,6 @@ package clients import ( "context" "fmt" - "math/big" "sync" "github.com/Layr-Labs/eigenda/api" @@ -30,12 +29,13 @@ type DisperserClientV2 interface { } type disperserClientV2 struct { - config *DisperserClientV2Config - signer corev2.BlobRequestSigner - initOnce sync.Once - conn *grpc.ClientConn - client disperser_rpc.DisperserClient - prover encoding.Prover + config *DisperserClientV2Config + signer corev2.BlobRequestSigner + initOnce sync.Once + conn *grpc.ClientConn + client disperser_rpc.DisperserClient + prover encoding.Prover + accountant *Accountant } var _ DisperserClientV2 = &disperserClientV2{} @@ -60,7 +60,7 @@ var _ DisperserClientV2 = &disperserClientV2{} // // // Subsequent calls will use the existing connection // status2, blobKey2, err := client.DisperseBlob(ctx, data, blobHeader) -func NewDisperserClientV2(config *DisperserClientV2Config, signer corev2.BlobRequestSigner, prover encoding.Prover) (*disperserClientV2, error) { +func NewDisperserClientV2(config *DisperserClientV2Config, signer corev2.BlobRequestSigner, prover encoding.Prover, accountant *Accountant) (*disperserClientV2, error) { if config == nil { return nil, api.NewErrorInvalidArg("config must be provided") } @@ -75,13 +75,28 @@ func NewDisperserClientV2(config *DisperserClientV2Config, signer corev2.BlobReq } return &disperserClientV2{ - config: config, - signer: signer, - prover: prover, + config: config, + signer: signer, + prover: prover, + accountant: accountant, // conn and client are initialized lazily }, nil } +// PopulateAccountant populates the accountant with the payment state from the disperser. +func (c *disperserClientV2) PopulateAccountant(ctx context.Context) error { + paymentState, err := c.GetPaymentState(ctx) + if err != nil { + return fmt.Errorf("error getting payment state for initializing accountant: %w", err) + } + + err = c.accountant.SetPaymentState(paymentState) + if err != nil { + return fmt.Errorf("error setting payment state for accountant: %w", err) + } + return nil +} + // Close closes the grpc connection to the disperser server. // It is thread safe and can be called multiple times. func (c *disperserClientV2) Close() error { @@ -108,16 +123,15 @@ func (c *disperserClientV2) DisperseBlob( if c.signer == nil { return nil, [32]byte{}, api.NewErrorInternal("uninitialized signer for authenticated dispersal") } + if c.accountant == nil { + return nil, [32]byte{}, api.NewErrorInternal("uninitialized accountant for paid dispersal; make sure to call PopulateAccountant after creating the client") + } - var payment core.PaymentMetadata - accountId, err := c.signer.GetAccountID() + symbolLength := encoding.GetBlobLengthPowerOf2(uint(len(data))) + payment, err := c.accountant.AccountBlob(ctx, uint64(symbolLength), quorums) if err != nil { - return nil, [32]byte{}, api.NewErrorInvalidArg(fmt.Sprintf("please configure signer key if you want to use authenticated endpoint %v", err)) + return nil, [32]byte{}, fmt.Errorf("error accounting blob: %w", err) } - payment.AccountID = accountId - // TODO: add payment metadata - payment.BinIndex = 0 - payment.CumulativePayment = big.NewInt(0) if len(quorums) == 0 { return nil, [32]byte{}, api.NewErrorInvalidArg("quorum numbers must be provided") @@ -160,7 +174,7 @@ func (c *disperserClientV2) DisperseBlob( BlobVersion: blobVersion, BlobCommitments: blobCommitments, QuorumNumbers: quorums, - PaymentMetadata: payment, + PaymentMetadata: *payment, } sig, err := c.signer.SignBlobRequest(blobHeader) if err != nil { @@ -202,6 +216,30 @@ func (c *disperserClientV2) GetBlobStatus(ctx context.Context, blobKey corev2.Bl return c.client.GetBlobStatus(ctx, request) } +// GetPaymentState returns the payment state of the disperser client +func (c *disperserClientV2) GetPaymentState(ctx context.Context) (*disperser_rpc.GetPaymentStateReply, error) { + err := c.initOnceGrpcConnection() + if err != nil { + return nil, api.NewErrorInternal(err.Error()) + } + + accountID, err := c.signer.GetAccountID() + if err != nil { + return nil, fmt.Errorf("error getting signer's account ID: %w", err) + } + + signature, err := c.signer.SignPaymentStateRequest() + if err != nil { + return nil, fmt.Errorf("error signing payment state request: %w", err) + } + + request := &disperser_rpc.GetPaymentStateRequest{ + AccountId: accountID, + Signature: signature, + } + return c.client.GetPaymentState(ctx, request) +} + // GetBlobCommitment is a utility method that calculates commitment for a blob payload. // While the blob commitment can be calculated by anyone, it requires SRS points to // be loaded. For service that does not have access to SRS points, this method can be diff --git a/disperser/apiserver/server_v2.go b/disperser/apiserver/server_v2.go index 2c502aaafa..e8dece0b36 100644 --- a/disperser/apiserver/server_v2.go +++ b/disperser/apiserver/server_v2.go @@ -264,6 +264,10 @@ func (s *DispersalServerV2) GetPaymentState(ctx context.Context, req *pb.GetPaym for i, v := range reservation.QuorumNumbers { quorumNumbers[i] = uint32(v) } + quorumSplit := make([]uint32, len(reservation.QuorumSplit)) + for i, v := range reservation.QuorumSplit { + quorumSplit[i] = uint32(v) + } // build reply reply := &pb.GetPaymentStateReply{ PaymentGlobalParams: &paymentGlobalParams, @@ -273,6 +277,7 @@ func (s *DispersalServerV2) GetPaymentState(ctx context.Context, req *pb.GetPaym StartTimestamp: uint32(reservation.StartTimestamp), EndTimestamp: uint32(reservation.EndTimestamp), QuorumNumbers: quorumNumbers, + QuorumSplit: quorumSplit, }, CumulativePayment: largestCumulativePayment.Bytes(), OnchainCumulativePayment: onDemandPayment.CumulativePayment.Bytes(),