diff --git a/rfqmsg/messages.go b/rfqmsg/messages.go index cc61b1ffa..1111b26e9 100644 --- a/rfqmsg/messages.go +++ b/rfqmsg/messages.go @@ -117,6 +117,9 @@ type WireMsgDataVersion uint8 const ( // V0 represents version 0 of the contents in a wire message data field. V0 WireMsgDataVersion = 0 + + // V1 represents version 1 of the contents in a wire message data field. + V1 WireMsgDataVersion = 1 ) // Record returns a TLV record that can be used to encode/decode a diff --git a/rfqmsg/reject.go b/rfqmsg/reject.go index 69a1a9ed6..44eef2a6a 100644 --- a/rfqmsg/reject.go +++ b/rfqmsg/reject.go @@ -9,74 +9,47 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) -const ( - // Reject message type field TLV types. - - TypeRejectVersion tlv.Type = 0 - TypeRejectID tlv.Type = 2 - TypeRejectErrCode tlv.Type = 3 - TypeRejectErrMsg tlv.Type = 5 -) - -func TypeRecordRejectVersion(version *WireMsgDataVersion) tlv.Record { - const recordSize = 1 - - return tlv.MakeStaticRecord( - TypeRejectVersion, version, recordSize, - WireMsgDataVersionEncoder, WireMsgDataVersionDecoder, - ) -} - -func TypeRecordRejectID(id *ID) tlv.Record { - const recordSize = 32 - - return tlv.MakeStaticRecord( - TypeRejectID, id, recordSize, IdEncoder, IdDecoder, - ) -} - -func TypeRecordRejectErrCode(errCode *uint8) tlv.Record { - return tlv.MakePrimitiveRecord(TypeRejectErrCode, errCode) -} - -func TypeRecordRejectErrMsg(errMsg *string) tlv.Record { - sizeFunc := func() uint64 { - return uint64(len(*errMsg)) - } - return tlv.MakeDynamicRecord( - TypeRejectErrMsg, errMsg, sizeFunc, - rejectErrMsgEncoder, rejectErrMsgDecoder, - ) -} +// rejectErrEncoder is a function that encodes a RejectErr into a writer. +func rejectErrEncoder(w io.Writer, val any, buf *[8]byte) error { + if typ, ok := val.(*RejectErr); ok { + if err := tlv.EUint8(w, &typ.Code, buf); err != nil { + return err + } -func rejectErrMsgEncoder(w io.Writer, val any, buf *[8]byte) error { - if typ, ok := val.(*string); ok { - msgBytes := []byte(*typ) - err := tlv.EVarBytes(w, &msgBytes, buf) - if err != nil { + msgBytes := []byte(typ.Msg) + if err := tlv.EVarBytes(w, &msgBytes, buf); err != nil { return err } return nil } - return tlv.NewTypeForEncodingErr(val, "RejectErrMsg") + return tlv.NewTypeForEncodingErr(val, "RejectErr") } -func rejectErrMsgDecoder(r io.Reader, val any, buf *[8]byte, l uint64) error { - if typ, ok := val.(*string); ok { +// rejectErrDecoder is a function that decodes a RejectErr from a reader. +func rejectErrDecoder(r io.Reader, val any, buf *[8]byte, l uint64) error { + if typ, ok := val.(*RejectErr); ok { + var rejectCode uint8 + if err := tlv.DUint8(r, &rejectCode, buf, 1); err != nil { + return err + } + var errMsgBytes []byte - err := tlv.DVarBytes(r, &errMsgBytes, buf, l) + err := tlv.DVarBytes(r, &errMsgBytes, buf, l-1) if err != nil { return err } - *typ = string(errMsgBytes) + *typ = RejectErr{ + Code: rejectCode, + Msg: string(errMsgBytes), + } return nil } - return tlv.NewTypeForDecodingErr(val, "RejectErrMsg", l, l) + return tlv.NewTypeForDecodingErr(val, "RejectErr", l, l) } // RejectErr is a struct that represents the error code and message of a quote @@ -89,6 +62,20 @@ type RejectErr struct { Msg string } +// Record returns a TLV record that can be used to encode/decode a RejectErr +// to/from a TLV stream. +func (v *RejectErr) Record() tlv.Record { + sizeFunc := func() uint64 { + return 1 + uint64(len(v.Msg)) + } + + // We set the type to zero here because the type parameter in + // tlv.RecordT will be used as the actual type. + return tlv.MakeDynamicRecord( + 0, v, sizeFunc, rejectErrEncoder, rejectErrDecoder, + ) +} + var ( // ErrUnknownReject is the error code for when the quote is rejected // for an unspecified reason. @@ -108,57 +95,45 @@ var ( const ( // latestRejectVersion is the latest supported reject wire message data // field version. - latestRejectVersion = V0 + latestRejectVersion = V1 ) -// rejectMsgData is a struct that represents the data field of a quote -// reject message. -type rejectMsgData struct { +// rejectWireMsgData is a struct that represents the data field of a quote +// reject wire message. +type rejectWireMsgData struct { // Version is the version of the message data. - Version WireMsgDataVersion + Version tlv.RecordT[tlv.TlvType0, WireMsgDataVersion] // ID represents the unique identifier of the quote request message that // this response is associated with. - ID ID + ID tlv.RecordT[tlv.TlvType2, ID] // Err is the error code and message that provides the reason for the // rejection. - Err RejectErr + Err tlv.RecordT[tlv.TlvType5, RejectErr] } -// EncodeRecords determines the non-nil records to include when encoding at -// runtime. -func (q *rejectMsgData) encodeRecords() []tlv.Record { +// records returns all records for encoding/decoding. +func (q *rejectWireMsgData) records() []tlv.Record { return []tlv.Record{ - TypeRecordRejectVersion(&q.Version), - TypeRecordRejectID(&q.ID), - TypeRecordRejectErrCode(&q.Err.Code), - TypeRecordRejectErrMsg(&q.Err.Msg), + q.Version.Record(), + q.ID.Record(), + q.Err.Record(), } } // Encode encodes the structure into a TLV stream. -func (q *rejectMsgData) Encode(writer io.Writer) error { - stream, err := tlv.NewStream(q.encodeRecords()...) +func (q *rejectWireMsgData) Encode(writer io.Writer) error { + stream, err := tlv.NewStream(q.records()...) if err != nil { return err } return stream.Encode(writer) } -// DecodeRecords provides all TLV records for decoding. -func (q *rejectMsgData) decodeRecords() []tlv.Record { - return []tlv.Record{ - TypeRecordRejectVersion(&q.Version), - TypeRecordRejectID(&q.ID), - TypeRecordRejectErrCode(&q.Err.Code), - TypeRecordRejectErrMsg(&q.Err.Msg), - } -} - // Decode decodes the structure from a TLV stream. -func (q *rejectMsgData) Decode(r io.Reader) error { - stream, err := tlv.NewStream(q.decodeRecords()...) +func (q *rejectWireMsgData) Decode(r io.Reader) error { + stream, err := tlv.NewStream(q.records()...) if err != nil { return err } @@ -166,7 +141,7 @@ func (q *rejectMsgData) Decode(r io.Reader) error { } // Bytes encodes the structure into a TLV stream and returns the bytes. -func (q *rejectMsgData) Bytes() ([]byte, error) { +func (q *rejectWireMsgData) Bytes() ([]byte, error) { var b bytes.Buffer err := q.Encode(&b) if err != nil { @@ -181,8 +156,9 @@ type Reject struct { // Peer is the peer that sent the quote request. Peer route.Vertex - // rejectMsgData is the message data for the quote reject message. - rejectMsgData + // rejectWireMsgData is the message data for the quote reject wire + // message. + rejectWireMsgData } // NewReject creates a new instance of a quote reject message. @@ -191,10 +167,12 @@ func NewReject(peer route.Vertex, requestId ID, return &Reject{ Peer: peer, - rejectMsgData: rejectMsgData{ - Version: latestRejectVersion, - ID: requestId, - Err: rejectErr, + rejectWireMsgData: rejectWireMsgData{ + Version: tlv.NewRecordT[tlv.TlvType0]( + latestRejectVersion, + ), + ID: tlv.NewRecordT[tlv.TlvType2](requestId), + Err: tlv.NewRecordT[tlv.TlvType5](rejectErr), }, } } @@ -208,7 +186,7 @@ func NewQuoteRejectFromWireMsg(wireMsg WireMessage) (*Reject, error) { } // Decode message data component from TLV bytes. - var msgData rejectMsgData + var msgData rejectWireMsgData err := msgData.Decode(bytes.NewReader(wireMsg.Data)) if err != nil { return nil, fmt.Errorf("unable to decode quote reject "+ @@ -216,21 +194,21 @@ func NewQuoteRejectFromWireMsg(wireMsg WireMessage) (*Reject, error) { } // Ensure that the message version is supported. - if msgData.Version > latestRejectVersion { + if msgData.Version.Val != latestRejectVersion { return nil, fmt.Errorf("unsupported reject message version: %d", - msgData.Version) + msgData.Version.Val) } return &Reject{ - Peer: wireMsg.Peer, - rejectMsgData: msgData, + Peer: wireMsg.Peer, + rejectWireMsgData: msgData, }, nil } // ToWire returns a wire message with a serialized data field. func (q *Reject) ToWire() (WireMessage, error) { // Encode message data component as TLV bytes. - msgDataBytes, err := q.rejectMsgData.Bytes() + msgDataBytes, err := q.rejectWireMsgData.Bytes() if err != nil { return WireMessage{}, fmt.Errorf("unable to encode message "+ "data: %w", err) @@ -250,13 +228,13 @@ func (q *Reject) MsgPeer() route.Vertex { // MsgID returns the quote request session ID. func (q *Reject) MsgID() ID { - return q.ID + return q.ID.Val } // String returns a human-readable string representation of the message. func (q *Reject) String() string { return fmt.Sprintf("Reject(id=%x, err_code=%d, err_msg=%s)", - q.ID[:], q.Err.Code, q.Err.Msg) + q.ID.Val[:], q.Err.Val.Code, q.Err.Val.Msg) } // Ensure that the message type implements the OutgoingMsg interface. diff --git a/rfqmsg/reject_test.go b/rfqmsg/reject_test.go index 8106135b0..a0cb82dca 100644 --- a/rfqmsg/reject_test.go +++ b/rfqmsg/reject_test.go @@ -46,6 +46,13 @@ func TestRejectEncodeDecode(t *testing.T) { id: id, err: ErrPriceOracleUnavailable, }, + { + testName: "empty error message", + peer: route.Vertex{1, 2, 3}, + version: 5, + id: id, + err: RejectErr{}, + }, } for _, tc := range testCases { @@ -64,7 +71,8 @@ func TestRejectEncodeDecode(t *testing.T) { // Assert that the decoded message is equal to the // original message. require.Equal( - tt, msg.rejectMsgData, decodedMsg.rejectMsgData, + tt, msg.rejectWireMsgData, + decodedMsg.rejectWireMsgData, ) }) } diff --git a/taprpc/marshal.go b/taprpc/marshal.go index 66c43b28a..c7198f847 100644 --- a/taprpc/marshal.go +++ b/taprpc/marshal.go @@ -612,9 +612,9 @@ func MarshalIncomingRejectQuoteEvent( return &rfqrpc.RejectedQuoteResponse{ Peer: event.Peer.String(), - Id: event.ID[:], - ErrorMessage: event.Err.Msg, - ErrorCode: uint32(event.Err.Code), + Id: event.ID.Val[:], + ErrorMessage: event.Err.Val.Msg, + ErrorCode: uint32(event.Err.Val.Code), } }