Skip to content

Commit

Permalink
Add subnetID arg to warp APIs (ava-labs#1008)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronbuchwald authored Nov 28, 2023
1 parent 551d07c commit 6883253
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 20 deletions.
4 changes: 2 additions & 2 deletions tests/warp/warp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ var _ = ginkgo.Describe("[Warp]", ginkgo.Ordered, func() {
gomega.Expect(err).Should(gomega.BeNil())

// Specify WarpQuorumDenominator to retrieve signatures from every validator
signedWarpMessageBytes, err := client.GetMessageAggregateSignature(ctx, unsignedWarpMessageID, params.WarpQuorumDenominator)
signedWarpMessageBytes, err := client.GetMessageAggregateSignature(ctx, unsignedWarpMessageID, params.WarpQuorumDenominator, "")
gomega.Expect(err).Should(gomega.BeNil())
gomega.Expect(signedWarpMessageBytes).Should(gomega.Equal(signedWarpMsg.Bytes()))
})
Expand All @@ -344,7 +344,7 @@ var _ = ginkgo.Describe("[Warp]", ginkgo.Ordered, func() {
gomega.Expect(err).Should(gomega.BeNil())

// Specify WarpQuorumDenominator to retrieve signatures from every validator
signedWarpBlockBytes, err := client.GetBlockAggregateSignature(ctx, warpBlockID, params.WarpQuorumDenominator)
signedWarpBlockBytes, err := client.GetBlockAggregateSignature(ctx, warpBlockID, params.WarpQuorumDenominator, "")
gomega.Expect(err).Should(gomega.BeNil())
gomega.Expect(signedWarpBlockBytes).Should(gomega.Equal(warpBlockHashSignedMsg.Bytes()))
})
Expand Down
2 changes: 1 addition & 1 deletion warp/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (b *backend) GetMessageSignature(messageID ids.ID) ([bls.SignatureLen]byte,

unsignedMessage, err := b.GetMessage(messageID)
if err != nil {
return [bls.SignatureLen]byte{}, fmt.Errorf("failed to get warp message %s from db: %w", messageID.String(), err)
return [bls.SignatureLen]byte{}, err
}

var signature [bls.SignatureLen]byte
Expand Down
12 changes: 6 additions & 6 deletions warp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ var _ Client = (*client)(nil)

type Client interface {
GetMessageSignature(ctx context.Context, messageID ids.ID) ([]byte, error)
GetMessageAggregateSignature(ctx context.Context, messageID ids.ID, quorumNum uint64) ([]byte, error)
GetMessageAggregateSignature(ctx context.Context, messageID ids.ID, quorumNum uint64, subnetIDStr string) ([]byte, error)
GetBlockSignature(ctx context.Context, blockID ids.ID) ([]byte, error)
GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, quorumNum uint64) ([]byte, error)
GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, quorumNum uint64, subnetIDStr string) ([]byte, error)
}

// client implementation for interacting with EVM [chain]
Expand All @@ -45,9 +45,9 @@ func (c *client) GetMessageSignature(ctx context.Context, messageID ids.ID) ([]b
return res, nil
}

func (c *client) GetMessageAggregateSignature(ctx context.Context, messageID ids.ID, quorumNum uint64) ([]byte, error) {
func (c *client) GetMessageAggregateSignature(ctx context.Context, messageID ids.ID, quorumNum uint64, subnetIDStr string) ([]byte, error) {
var res hexutil.Bytes
if err := c.client.CallContext(ctx, &res, "warp_getMessageAggregateSignature", messageID, quorumNum); err != nil {
if err := c.client.CallContext(ctx, &res, "warp_getMessageAggregateSignature", messageID, quorumNum, subnetIDStr); err != nil {
return nil, fmt.Errorf("call to warp_getMessageAggregateSignature failed. err: %w", err)
}
return res, nil
Expand All @@ -61,9 +61,9 @@ func (c *client) GetBlockSignature(ctx context.Context, blockID ids.ID) ([]byte,
return res, nil
}

func (c *client) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, quorumNum uint64) ([]byte, error) {
func (c *client) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, quorumNum uint64, subnetIDStr string) ([]byte, error) {
var res hexutil.Bytes
if err := c.client.CallContext(ctx, &res, "warp_getBlockAggregateSignature", blockID, quorumNum); err != nil {
if err := c.client.CallContext(ctx, &res, "warp_getBlockAggregateSignature", blockID, quorumNum, subnetIDStr); err != nil {
return nil, fmt.Errorf("call to warp_getBlockAggregateSignature failed. err: %w", err)
}
return res, nil
Expand Down
33 changes: 22 additions & 11 deletions warp/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ func (a *API) GetBlockSignature(ctx context.Context, blockID ids.ID) (hexutil.By
}

// GetMessageAggregateSignature fetches the aggregate signature for the requested [messageID]
func (a *API) GetMessageAggregateSignature(ctx context.Context, messageID ids.ID, quorumNum uint64) (signedMessageBytes hexutil.Bytes, err error) {
func (a *API) GetMessageAggregateSignature(ctx context.Context, messageID ids.ID, quorumNum uint64, subnetIDStr string) (signedMessageBytes hexutil.Bytes, err error) {
unsignedMessage, err := a.backend.GetMessage(messageID)
if err != nil {
return nil, err
}
return a.aggregateSignatures(ctx, unsignedMessage, quorumNum)
return a.aggregateSignatures(ctx, unsignedMessage, quorumNum, subnetIDStr)
}

// GetBlockAggregateSignature fetches the aggregate signature for the requested [blockID]
func (a *API) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, quorumNum uint64) (signedMessageBytes hexutil.Bytes, err error) {
func (a *API) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, quorumNum uint64, subnetIDStr string) (signedMessageBytes hexutil.Bytes, err error) {
blockHashPayload, err := payload.NewHash(blockID)
if err != nil {
return nil, err
Expand All @@ -78,27 +78,38 @@ func (a *API) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, qu
return nil, err
}

return a.aggregateSignatures(ctx, unsignedMessage, quorumNum)
return a.aggregateSignatures(ctx, unsignedMessage, quorumNum, subnetIDStr)
}

func (a *API) aggregateSignatures(ctx context.Context, unsignedMessage *warp.UnsignedMessage, quorumNum uint64) (hexutil.Bytes, error) {
func (a *API) aggregateSignatures(ctx context.Context, unsignedMessage *warp.UnsignedMessage, quorumNum uint64, subnetIDStr string) (hexutil.Bytes, error) {
subnetID := a.sourceSubnetID
if len(subnetIDStr) > 0 {
sid, err := ids.FromString(subnetIDStr)
if err != nil {
return nil, fmt.Errorf("failed to parse subnetID: %q", subnetIDStr)
}
subnetID = sid
}
pChainHeight, err := a.state.GetCurrentHeight(ctx)
if err != nil {
return nil, err
}

log.Debug("Fetching signature",
"a.subnetID", a.sourceSubnetID,
"height", pChainHeight,
)
validators, totalWeight, err := warp.GetCanonicalValidatorSet(ctx, a.state, pChainHeight, a.sourceSubnetID)
validators, totalWeight, err := warp.GetCanonicalValidatorSet(ctx, a.state, pChainHeight, subnetID)
if err != nil {
return nil, fmt.Errorf("failed to get validator set: %w", err)
}
if len(validators) == 0 {
return nil, fmt.Errorf("%w (SubnetID: %s, Height: %d)", errNoValidators, a.sourceSubnetID, pChainHeight)
return nil, fmt.Errorf("%w (SubnetID: %s, Height: %d)", errNoValidators, subnetID, pChainHeight)
}

log.Debug("Fetching signature",
"sourceSubnetID", subnetID,
"height", pChainHeight,
"numValidators", len(validators),
"totalWeight", totalWeight,
)

agg := aggregator.New(aggregator.NewSignatureGetter(a.client), validators, totalWeight)
signatureResult, err := agg.AggregateSignatures(ctx, unsignedMessage, quorumNum)
if err != nil {
Expand Down

0 comments on commit 6883253

Please sign in to comment.