Skip to content

Commit

Permalink
Round trip check validation on privacy group message types
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Broadhurst <peter.broadhurst@kaleido.io>
  • Loading branch information
peterbroadhurst committed Feb 24, 2025
1 parent 87f155f commit 9641987
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 16 deletions.
127 changes: 127 additions & 0 deletions core/go/internal/transportmgr/peer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/hyperledger/firefly-signer/pkg/abi"
"github.com/kaleido-io/paladin/config/pkg/confutil"
"github.com/kaleido-io/paladin/config/pkg/pldconf"
"github.com/kaleido-io/paladin/core/internal/components"
Expand Down Expand Up @@ -503,3 +504,129 @@ func TestProcessReliableMsgPageInsertFail(t *testing.T) {
require.Regexp(t, "PD020302", err)

}

func TestProcessReliableMsgPagePrivacyGroup(t *testing.T) {

simpleABI := &abi.Parameter{
Type: "tuple", InternalType: "struct EmptyType;",
}
schemaID := tktypes.RandBytes32()
ctx, tm, tp, done := newTestTransport(t, false,
mockGetStateOk,
func(mc *mockComponents, conf *pldconf.TransportManagerConfig) {
mc.stateManager.On("GetSchemaByID", mock.Anything, mock.Anything, "domain1", schemaID, false).
Return(&pldapi.Schema{ID: schemaID, Definition: tktypes.JSONString(simpleABI)}, nil)

mc.db.Mock.ExpectExec("INSERT.*reliable_msgs").WillReturnResult(driver.ResultNoRows)
})
defer done()

p := &peer{
ctx: ctx,
tm: tm,
transport: tp.t,
}

sd := &components.StateDistribution{
Domain: "domain1",
ContractAddress: tktypes.RandAddress().String(),
SchemaID: schemaID.String(),
StateID: tktypes.RandHex(32),
}

rm := &components.ReliableMessage{
ID: uuid.New(),
Sequence: 50,
MessageType: components.RMTPrivacyGroup.Enum(),
Node: "node2",
Metadata: tktypes.JSONString(sd),
Created: tktypes.TimestampNow(),
}

sentMessages := make(chan *prototk.PaladinMsg, 1)
tp.Functions.SendMessage = func(ctx context.Context, req *prototk.SendMessageRequest) (*prototk.SendMessageResponse, error) {
sent := req.Message
sentMessages <- sent
return nil, nil
}

err := p.processReliableMsgPage(tm.persistence.NOTX(), []*components.ReliableMessage{rm})
require.NoError(t, err)

sentMsg := <-sentMessages

rMsg, err := parseReceivedMessage(ctx, "ndoe2", sentMsg)
require.NoError(t, err)
require.Equal(t, RMHMessageTypePrivacyGroup, rMsg.MessageType)

domain, genesisABI, genesisState, err := parsePrivacyGroupDistribution(ctx, rMsg.MessageID, rMsg.Payload)
require.NoError(t, err)
require.Equal(t, "domain1", domain)
require.Equal(t, simpleABI, genesisABI)
require.JSONEq(t, fmt.Sprintf(`{"dataFor": "%s"}`, genesisState.ID.HexString()), genesisState.Data.Pretty())
}

func TestProcessReliableMsgPagePrivacyGroupMessage(t *testing.T) {

origMsg := &pldapi.PrivacyGroupMessage{
ID: uuid.New(),
Sent: tktypes.TimestampNow(),
PrivacyGroupMessageInput: pldapi.PrivacyGroupMessageInput{
Domain: "domain1",
Group: tktypes.RandBytes(32),
Topic: "topic1",
Data: tktypes.JSONString("some data"),
},
}
ctx, tm, tp, done := newTestTransport(t, false,
func(mc *mockComponents, conf *pldconf.TransportManagerConfig) {
mc.groupManager.On("GetMessageByID", mock.Anything, mock.Anything, origMsg.ID, false).
Return(origMsg, nil)

mc.db.Mock.ExpectExec("INSERT.*reliable_msgs").WillReturnResult(driver.ResultNoRows)
})
defer done()

p := &peer{
ctx: ctx,
tm: tm,
transport: tp.t,
}

pmd := &components.PrivacyGroupMessageDistribution{
Domain: "domain1",
Group: tktypes.RandBytes(32),
ID: origMsg.ID,
}

rm := &components.ReliableMessage{
ID: origMsg.ID,
Sequence: 50,
MessageType: components.RMTPrivacyGroupMessage.Enum(),
Node: "node2",
Metadata: tktypes.JSONString(pmd),
Created: tktypes.TimestampNow(),
}

sentMessages := make(chan *prototk.PaladinMsg, 1)
tp.Functions.SendMessage = func(ctx context.Context, req *prototk.SendMessageRequest) (*prototk.SendMessageResponse, error) {
sent := req.Message
sentMessages <- sent
return nil, nil
}

err := p.processReliableMsgPage(tm.persistence.NOTX(), []*components.ReliableMessage{rm})
require.NoError(t, err)

sentMsg := <-sentMessages

rMsg, err := parseReceivedMessage(ctx, "node2", sentMsg)
require.NoError(t, err)
require.Equal(t, RMHMessageTypePrivacyGroupMessage, rMsg.MessageType)

receivedMsg, err := parsePrivacyGroupMessage(ctx, rMsg.FromNode, rMsg.MessageID, rMsg.Payload)
require.NoError(t, err)
origMsg.Received = receivedMsg.Received // expect to be changed on incoming message
origMsg.Node = receivedMsg.Node // expect to be changed on incoming message
require.Equal(t, origMsg, receivedMsg)
}
2 changes: 1 addition & 1 deletion core/go/internal/transportmgr/reliable_msg_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ func (tm *transportManager) buildPrivacyGroupDistributionMsg(ctx context.Context
return &prototk.PaladinMsg{
MessageId: rm.ID.String(),
Component: prototk.PaladinMsg_RELIABLE_MESSAGE_HANDLER,
MessageType: RMHMessageTypeStateDistribution,
MessageType: RMHMessageTypePrivacyGroup,
Payload: tktypes.JSONString(components.PrivacyGroupGenesisWithABI{
GenesisState: *sd,
GenesisABI: abiDefinition,
Expand Down
42 changes: 27 additions & 15 deletions core/go/internal/transportmgr/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,7 @@ func protoToJSON(m proto.Message) (s string) {
return
}

// Transport callback to the transport manager when a message is received
func (t *transport) ReceiveMessage(ctx context.Context, req *prototk.ReceiveMessageRequest) (*prototk.ReceiveMessageResponse, error) {
if err := t.checkInit(ctx); err != nil {
return nil, err
}

msg := req.Message
func parseReceivedMessage(ctx context.Context, fromNode string, msg *prototk.PaladinMsg) (*components.ReceivedMessage, error) {
if msg == nil || len(msg.Payload) == 0 || len(msg.MessageType) == 0 {
log.L(ctx).Errorf("Invalid message from transport: %s", protoToJSON(msg))
return nil, i18n.NewError(ctx, msgs.MsgTransportInvalidMessage)
Expand All @@ -153,24 +147,42 @@ func (t *transport) ReceiveMessage(ctx context.Context, req *prototk.ReceiveMess
correlationID = &parsedUUID
}

return &components.ReceivedMessage{
FromNode: fromNode,
MessageID: msgID,
CorrelationID: correlationID,
MessageType: msg.MessageType,
Payload: msg.Payload,
}, nil

}

// Transport callback to the transport manager when a message is received
func (t *transport) ReceiveMessage(ctx context.Context, req *prototk.ReceiveMessageRequest) (*prototk.ReceiveMessageResponse, error) {
if err := t.checkInit(ctx); err != nil {
return nil, err
}

msg := req.Message

rMsg, err := parseReceivedMessage(ctx, req.FromNode, msg)
if err != nil {
return nil, err
}

p, err := t.tm.getPeer(ctx, req.FromNode, false /* we do not require a connection for sending here */)
if err != nil {
return nil, err
}

p.updateReceivedStats(msg)

log.L(ctx).Debugf("transport %s message received from %s id=%s (cid=%s)", t.name, p.Name, msgID, tktypes.StrOrEmpty(msg.CorrelationId))
log.L(ctx).Debugf("transport %s message received from %s id=%s (cid=%s)", t.name, p.Name, rMsg.MessageID, tktypes.StrOrEmpty(msg.CorrelationId))
if log.IsTraceEnabled() {
log.L(ctx).Tracef("transport %s message received: %s", t.name, protoToJSON(msg))
}

if err := t.deliverMessage(ctx, p, msg.Component, &components.ReceivedMessage{
FromNode: req.FromNode,
MessageID: msgID,
CorrelationID: correlationID,
MessageType: msg.MessageType,
Payload: msg.Payload,
}); err != nil {
if err := t.deliverMessage(ctx, p, msg.Component, rMsg); err != nil {
return nil, err
}

Expand Down

0 comments on commit 9641987

Please sign in to comment.