diff --git a/channeldb/channel.go b/channeldb/channel.go index bb63477147..18db1d2078 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -35,6 +35,10 @@ const ( // begins to be interpreted as an absolute block height, rather than a // relative one. AbsoluteThawHeightThreshold uint32 = 500000 + + // HTLCBlindingPointTLV is the tlv type used for storing blinding + // points with HTLCs. + HTLCBlindingPointTLV tlv.Type = 0 ) var ( @@ -2316,7 +2320,56 @@ type HTLC struct { // Note that this extra data is stored inline with the OnionBlob for // legacy reasons, see serialization/deserialization functions for // detail. - ExtraData []byte + ExtraData lnwire.ExtraOpaqueData + + // BlindingPoint is an optional blinding point included with the HTLC. + // + // Note: this field is not a part of on-disk representation of the + // HTLC. It is stored in the ExtraData field, which is used to store + // a TLV stream of additional information associated with the HTLC. + BlindingPoint lnwire.BlindingPointRecord +} + +// serializeExtraData encodes a TLV stream of extra data to be stored with a +// HTLC. It uses the update_add_htlc TLV types, because this is where extra +// data is passed with a HTLC. At present blinding points are the only extra +// data that we will store, and the function is a no-op if a nil blinding +// point is provided. +// +// This function MUST be called to persist all HTLC values when they are +// serialized. +func (h *HTLC) serializeExtraData() error { + var records []tlv.RecordProducer + h.BlindingPoint.WhenSome(func(b tlv.RecordT[lnwire.BlindingPointTlvType, + *btcec.PublicKey]) { + + records = append(records, &b) + }) + + return h.ExtraData.PackRecords(records...) +} + +// deserializeExtraData extracts TLVs from the extra data persisted for the +// htlc and populates values in the struct accordingly. +// +// This function MUST be called to populate the struct properly when HTLCs +// are deserialized. +func (h *HTLC) deserializeExtraData() error { + if len(h.ExtraData) == 0 { + return nil + } + + blindingPoint := h.BlindingPoint.Zero() + tlvMap, err := h.ExtraData.ExtractRecords(&blindingPoint) + if err != nil { + return err + } + + if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil { + h.BlindingPoint = tlv.SomeRecordT(blindingPoint) + } + + return nil } // SerializeHtlcs writes out the passed set of HTLC's into the passed writer @@ -2340,6 +2393,12 @@ func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { } for _, htlc := range htlcs { + // Populate TLV stream for any additional fields contained + // in the TLV. + if err := htlc.serializeExtraData(); err != nil { + return err + } + // The onion blob and hltc data are stored as a single var // bytes blob. onionAndExtraData := make( @@ -2425,6 +2484,12 @@ func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { onionAndExtraData[lnwire.OnionPacketSize:], ) } + + // Finally, deserialize any TLVs contained in that extra data + // if they are present. + if err := htlcs[i].deserializeExtraData(); err != nil { + return nil, err + } } return htlcs, nil @@ -2440,6 +2505,7 @@ func (h *HTLC) Copy() HTLC { } copy(clone.Signature[:], h.Signature) copy(clone.RHash[:], h.RHash[:]) + copy(clone.ExtraData, h.ExtraData) return clone } diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 6047a1e67e..981ddf688b 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -23,6 +23,7 @@ import ( "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -651,8 +652,7 @@ func TestChannelStateTransition(t *testing.T) { { LogIndex: 2, UpdateMsg: &lnwire.UpdateAddHTLC{ - ChanID: lnwire.ChannelID{1, 2, 3}, - ExtraData: make([]byte, 0), + ChanID: lnwire.ChannelID{1, 2, 3}, }, }, } @@ -710,25 +710,22 @@ func TestChannelStateTransition(t *testing.T) { wireSig, wireSig, }, - ExtraData: make([]byte, 0), }, LogUpdates: []LogUpdate{ { LogIndex: 1, UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 1, - Amount: lnwire.NewMSatFromSatoshis(100), - Expiry: 25, - ExtraData: make([]byte, 0), + ID: 1, + Amount: lnwire.NewMSatFromSatoshis(100), + Expiry: 25, }, }, { LogIndex: 2, UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 2, - Amount: lnwire.NewMSatFromSatoshis(200), - Expiry: 50, - ExtraData: make([]byte, 0), + ID: 2, + Amount: lnwire.NewMSatFromSatoshis(200), + Expiry: 50, }, }, }, @@ -1610,9 +1607,25 @@ func TestHTLCsExtraData(t *testing.T) { OnionBlob: lnmock.MockOnion(), } + // Add a blinding point to a htlc. + blindingPointHTLC := HTLC{ + Signature: testSig.Serialize(), + Incoming: false, + Amt: 10, + RHash: key, + RefundTimeout: 1, + OnionBlob: lnmock.MockOnion(), + BlindingPoint: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + pubKey, + ), + ), + } + testCases := []struct { - name string - htlcs []HTLC + name string + htlcs []HTLC + blindingIdx int }{ { // Serialize multiple HLTCs with no extra data to @@ -1624,30 +1637,12 @@ func TestHTLCsExtraData(t *testing.T) { }, }, { + // Some HTLCs with extra data, some without. name: "mixed extra data", htlcs: []HTLC{ mockHtlc, - { - Signature: testSig.Serialize(), - Incoming: false, - Amt: 10, - RHash: key, - RefundTimeout: 1, - OnionBlob: lnmock.MockOnion(), - ExtraData: []byte{1, 2, 3}, - }, + blindingPointHTLC, mockHtlc, - { - Signature: testSig.Serialize(), - Incoming: false, - Amt: 10, - RHash: key, - RefundTimeout: 1, - OnionBlob: lnmock.MockOnion(), - ExtraData: bytes.Repeat( - []byte{9}, 999, - ), - }, }, }, } @@ -1665,7 +1660,15 @@ func TestHTLCsExtraData(t *testing.T) { r := bytes.NewReader(b.Bytes()) htlcs, err := DeserializeHtlcs(r) require.NoError(t, err) - require.Equal(t, testCase.htlcs, htlcs) + + require.EqualValues(t, len(testCase.htlcs), len(htlcs)) + for i, htlc := range htlcs { + // We use the extra data field when we + // serialize, so we set to nil to be able to + // assert on equal for the test. + htlc.ExtraData = nil + require.Equal(t, testCase.htlcs[i], htlc) + } }) } } diff --git a/cmd/lncli/cmd_payments.go b/cmd/lncli/cmd_payments.go index 1ba3b42c88..9ba6f296d1 100644 --- a/cmd/lncli/cmd_payments.go +++ b/cmd/lncli/cmd_payments.go @@ -1250,9 +1250,9 @@ func parseBlindedPaymentParameters(ctx *cli.Context) ( BaseFeeMsat: ctx.Uint64( blindedBaseFlag.Name, ), - ProportionalFeeMsat: ctx.Uint64( + ProportionalFeeRate: uint32(ctx.Uint64( blindedPPMFlag.Name, - ), + )), TotalCltvDelta: uint32(ctx.Uint64( blindedCLTVFlag.Name, )), diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 9c45ee7bae..d43a50d906 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -7,6 +7,7 @@ import ( "fmt" "io" + "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/lightningnetwork/lnd/channeldb" @@ -17,6 +18,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/queue" + "github.com/lightningnetwork/lnd/tlv" ) // htlcIncomingContestResolver is a ContractResolver that's able to resolve an @@ -520,9 +522,18 @@ func (h *htlcIncomingContestResolver) Supplement(htlc channeldb.HTLC) { func (h *htlcIncomingContestResolver) decodePayload() (*hop.Payload, []byte, error) { + var blindingPoint *btcec.PublicKey + h.htlc.BlindingPoint.WhenSome( + func(b tlv.RecordT[lnwire.BlindingPointTlvType, + *btcec.PublicKey]) { + + blindingPoint = b.Val + }, + ) + onionReader := bytes.NewReader(h.htlc.OnionBlob[:]) iterator, err := h.OnionProcessor.ReconstructHopIterator( - onionReader, h.htlc.RHash[:], + onionReader, h.htlc.RHash[:], blindingPoint, ) if err != nil { return nil, nil, err diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index e8532abe58..d789858fb4 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "testing" + "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" @@ -288,8 +289,8 @@ type mockOnionProcessor struct { offeredOnionBlob []byte } -func (o *mockOnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte) ( - hop.Iterator, error) { +func (o *mockOnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte, + blindingPoint *btcec.PublicKey) (hop.Iterator, error) { data, err := ioutil.ReadAll(r) if err != nil { diff --git a/contractcourt/interfaces.go b/contractcourt/interfaces.go index 45cd75735a..146670a414 100644 --- a/contractcourt/interfaces.go +++ b/contractcourt/interfaces.go @@ -4,6 +4,7 @@ import ( "context" "io" + "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" @@ -40,7 +41,8 @@ type Registry interface { type OnionProcessor interface { // ReconstructHopIterator attempts to decode a valid sphinx packet from // the passed io.Reader instance. - ReconstructHopIterator(r io.Reader, rHash []byte) (hop.Iterator, error) + ReconstructHopIterator(r io.Reader, rHash []byte, + blindingKey *btcec.PublicKey) (hop.Iterator, error) } // UtxoSweeper defines the sweep functions that contract court requires. diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index bc72de4c53..9bb68ab4f8 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -172,6 +172,9 @@ * When computing a minimum fee for transaction construction, `lnd` [now takes our bitcoin peers' feefilter values into account](https://github.com/lightningnetwork/lnd/pull/8418). +* [Preparatory work](https://github.com/lightningnetwork/lnd/pull/8159) for + forwarding of blinded routes was added. + ## RPC Additions * [Deprecated](https://github.com/lightningnetwork/lnd/pull/7175) diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index 96a5c5f2bf..1829522f48 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -147,67 +147,30 @@ func (p *OnionProcessor) Stop() error { return nil } -// DecodeHopIterator attempts to decode a valid sphinx packet from the passed io.Reader -// instance using the rHash as the associated data when checking the relevant -// MACs during the decoding process. -func (p *OnionProcessor) DecodeHopIterator(r io.Reader, rHash []byte, - incomingCltv uint32) (Iterator, lnwire.FailCode) { - - onionPkt := &sphinx.OnionPacket{} - if err := onionPkt.Decode(r); err != nil { - switch err { - case sphinx.ErrInvalidOnionVersion: - return nil, lnwire.CodeInvalidOnionVersion - case sphinx.ErrInvalidOnionKey: - return nil, lnwire.CodeInvalidOnionKey - default: - log.Errorf("unable to decode onion packet: %v", err) - return nil, lnwire.CodeInvalidOnionKey - } - } - - // Attempt to process the Sphinx packet. We include the payment hash of - // the HTLC as it's authenticated within the Sphinx packet itself as - // associated data in order to thwart attempts a replay attacks. In the - // case of a replay, an attacker is *forced* to use the same payment - // hash twice, thereby losing their money entirely. - sphinxPacket, err := p.router.ProcessOnionPacket( - onionPkt, rHash, incomingCltv, - ) - if err != nil { - switch err { - case sphinx.ErrInvalidOnionVersion: - return nil, lnwire.CodeInvalidOnionVersion - case sphinx.ErrInvalidOnionHMAC: - return nil, lnwire.CodeInvalidOnionHmac - case sphinx.ErrInvalidOnionKey: - return nil, lnwire.CodeInvalidOnionKey - default: - log.Errorf("unable to process onion packet: %v", err) - return nil, lnwire.CodeInvalidOnionKey - } - } - - return makeSphinxHopIterator(onionPkt, sphinxPacket), lnwire.CodeNone -} - // ReconstructHopIterator attempts to decode a valid sphinx packet from the passed io.Reader // instance using the rHash as the associated data when checking the relevant // MACs during the decoding process. -func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte) ( - Iterator, error) { +func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte, + blindingPoint *btcec.PublicKey) (Iterator, error) { onionPkt := &sphinx.OnionPacket{} if err := onionPkt.Decode(r); err != nil { return nil, err } + var opts []sphinx.ProcessOnionOpt + if blindingPoint != nil { + opts = append(opts, sphinx.WithBlindingPoint(blindingPoint)) + } + // Attempt to process the Sphinx packet. We include the payment hash of // the HTLC as it's authenticated within the Sphinx packet itself as // associated data in order to thwart attempts a replay attacks. In the // case of a replay, an attacker is *forced* to use the same payment // hash twice, thereby losing their money entirely. - sphinxPacket, err := p.router.ReconstructOnionPacket(onionPkt, rHash) + sphinxPacket, err := p.router.ReconstructOnionPacket( + onionPkt, rHash, opts..., + ) if err != nil { return nil, err } @@ -219,9 +182,11 @@ func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte) ( // packet, perform sphinx replay detection, and schedule the entry for garbage // collection. type DecodeHopIteratorRequest struct { - OnionReader io.Reader - RHash []byte - IncomingCltv uint32 + OnionReader io.Reader + RHash []byte + IncomingCltv uint32 + IncomingAmount lnwire.MilliSatoshi + BlindingPoint *btcec.PublicKey } // DecodeHopIteratorResponse encapsulates the outcome of a batched sphinx onion @@ -277,8 +242,15 @@ func (p *OnionProcessor) DecodeHopIterators(id []byte, return lnwire.CodeInvalidOnionKey } + var opts []sphinx.ProcessOnionOpt + if req.BlindingPoint != nil { + opts = append(opts, sphinx.WithBlindingPoint( + req.BlindingPoint, + )) + } + err = tx.ProcessOnionPacket( - seqNum, onionPkt, req.RHash, req.IncomingCltv, + seqNum, onionPkt, req.RHash, req.IncomingCltv, opts..., ) switch err { case nil: diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index a49a264fe9..cbd8d08a57 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -28,6 +28,10 @@ const ( // RequiredViolation indicates that an unknown even type was found in // the payload that we could not process. RequiredViolation + + // InsufficientViolation indicates that the provided type does + // not satisfy constraints. + InsufficientViolation ) // String returns a human-readable description of the violation as a verb. @@ -42,6 +46,9 @@ func (v PayloadViolation) String() string { case RequiredViolation: return "required" + case InsufficientViolation: + return "insufficient" + default: return "unknown violation" } @@ -410,3 +417,70 @@ func getMinRequiredViolation(set tlv.TypeMap) *tlv.Type { return nil } + +// ValidateBlindedRouteData performs the additional validation that is +// required for payments that rely on data provided in an encrypted blob to +// be forwarded. We enforce the blinded route's maximum expiry height so that +// the route "expires" and a malicious party does not have endless opportunity +// to probe the blinded route and compare it to updated channel policies in +// the network. +// +// Note that this function only validates blinded route data for forwarding +// nodes, as LND does not yet support receiving via a blinded route (which has +// different validation rules). +func ValidateBlindedRouteData(blindedData *record.BlindedRouteData, + incomingAmount lnwire.MilliSatoshi, incomingTimelock uint32) error { + + // Bolt 04 notes that we should enforce payment constraints _if_ they + // are present, so we do not fail if not provided. + var err error + blindedData.Constraints.WhenSome( + func(c tlv.RecordT[tlv.TlvType12, record.PaymentConstraints]) { + // MUST fail if the expiry is greater than + // max_cltv_expiry. + if incomingTimelock > c.Val.MaxCltvExpiry { + err = ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Violation: InsufficientViolation, + } + } + + // MUST fail if the amount is below htlc_minimum_msat. + if incomingAmount < c.Val.HtlcMinimumMsat { + err = ErrInvalidPayload{ + Type: record.AmtOnionType, + Violation: InsufficientViolation, + } + } + }, + ) + if err != nil { + return err + } + + // Fail if we don't understand any features (even or odd), because we + // expect the features to have been set from our announcement. If the + // feature vector TLV is not included, it's interpreted as an empty + // vector (no validation required). + // expect the features to have been set from our announcement. + // + // Note that we do not yet check the features that the blinded payment + // is using against our own features, because there are currently no + // payment-related features that they utilize other than tlv-onion, + // which is implicitly supported. + blindedData.Features.WhenSome( + func(f tlv.RecordT[tlv.TlvType14, lnwire.FeatureVector]) { + if f.Val.UnknownFeatures() { + err = ErrInvalidPayload{ + Type: 14, + Violation: IncludedViolation, + } + } + }, + ) + if err != nil { + return err + } + + return nil +} diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index 63c9ceeef4..148b806f96 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -557,3 +557,141 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { t.Fatalf("invalid custom records") } } + +// TestValidateBlindedRouteData tests validation of the values provided in a +// blinded route. +func TestValidateBlindedRouteData(t *testing.T) { + scid := lnwire.NewShortChanIDFromInt(1) + + tests := []struct { + name string + data *record.BlindedRouteData + incomingAmount lnwire.MilliSatoshi + incomingTimelock uint32 + err error + }{ + { + name: "max cltv expired", + data: record.NewBlindedRouteData( + scid, + nil, + record.PaymentRelayInfo{}, + &record.PaymentConstraints{ + MaxCltvExpiry: 100, + }, + nil, + ), + incomingTimelock: 200, + err: hop.ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Violation: hop.InsufficientViolation, + }, + }, + { + name: "zero max cltv", + data: record.NewBlindedRouteData( + scid, + nil, + record.PaymentRelayInfo{}, + &record.PaymentConstraints{ + MaxCltvExpiry: 0, + HtlcMinimumMsat: 10, + }, + nil, + ), + incomingAmount: 100, + incomingTimelock: 10, + err: hop.ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Violation: hop.InsufficientViolation, + }, + }, + { + name: "amount below minimum", + data: record.NewBlindedRouteData( + scid, + nil, + record.PaymentRelayInfo{}, + &record.PaymentConstraints{ + HtlcMinimumMsat: 15, + }, + nil, + ), + incomingAmount: 10, + err: hop.ErrInvalidPayload{ + Type: record.AmtOnionType, + Violation: hop.InsufficientViolation, + }, + }, + { + name: "valid, no features", + data: record.NewBlindedRouteData( + scid, + nil, + record.PaymentRelayInfo{}, + &record.PaymentConstraints{ + MaxCltvExpiry: 100, + HtlcMinimumMsat: 20, + }, + nil, + ), + incomingAmount: 40, + incomingTimelock: 80, + }, + { + name: "unknown features", + data: record.NewBlindedRouteData( + scid, + nil, + record.PaymentRelayInfo{}, + &record.PaymentConstraints{ + MaxCltvExpiry: 100, + HtlcMinimumMsat: 20, + }, + lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector( + lnwire.FeatureBit(9999), + ), + lnwire.Features, + ), + ), + incomingAmount: 40, + incomingTimelock: 80, + err: hop.ErrInvalidPayload{ + Type: 14, + Violation: hop.IncludedViolation, + }, + }, + { + name: "valid data", + data: record.NewBlindedRouteData( + scid, + nil, + record.PaymentRelayInfo{ + CltvExpiryDelta: 10, + FeeRate: 10, + BaseFee: 100, + }, + &record.PaymentConstraints{ + MaxCltvExpiry: 100, + HtlcMinimumMsat: 20, + }, + nil, + ), + incomingAmount: 40, + incomingTimelock: 80, + }, + } + + for _, testCase := range tests { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + err := hop.ValidateBlindedRouteData( + testCase.data, testCase.incomingAmount, + testCase.incomingTimelock, + ) + require.Equal(t, testCase.err, err) + }) + } +} diff --git a/htlcswitch/link.go b/htlcswitch/link.go index e80c7650c7..c06ca5324b 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -3160,9 +3160,11 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, onionReader := bytes.NewReader(pd.OnionBlob) req := hop.DecodeHopIteratorRequest{ - OnionReader: onionReader, - RHash: pd.RHash[:], - IncomingCltv: pd.Timeout, + OnionReader: onionReader, + RHash: pd.RHash[:], + IncomingCltv: pd.Timeout, + IncomingAmount: pd.Amount, + BlindingPoint: pd.BlindingPoint, } decodeReqs = append(decodeReqs, req) diff --git a/lnrpc/lightning.pb.go b/lnrpc/lightning.pb.go index fb6804b718..81d1f9fb26 100644 --- a/lnrpc/lightning.pb.go +++ b/lnrpc/lightning.pb.go @@ -11992,8 +11992,9 @@ type BlindedPaymentPath struct { BlindedPath *BlindedPath `protobuf:"bytes,1,opt,name=blinded_path,json=blindedPath,proto3" json:"blinded_path,omitempty"` // The base fee for the blinded path provided, expressed in msat. BaseFeeMsat uint64 `protobuf:"varint,2,opt,name=base_fee_msat,json=baseFeeMsat,proto3" json:"base_fee_msat,omitempty"` - // The proportional fee for the blinded path provided, expressed in msat. - ProportionalFeeMsat uint64 `protobuf:"varint,3,opt,name=proportional_fee_msat,json=proportionalFeeMsat,proto3" json:"proportional_fee_msat,omitempty"` + // The proportional fee for the blinded path provided, expressed in parts + // per million. + ProportionalFeeRate uint32 `protobuf:"varint,3,opt,name=proportional_fee_rate,json=proportionalFeeRate,proto3" json:"proportional_fee_rate,omitempty"` // The total CLTV delta for the blinded path provided, including the // final CLTV delta for the receiving node. TotalCltvDelta uint32 `protobuf:"varint,4,opt,name=total_cltv_delta,json=totalCltvDelta,proto3" json:"total_cltv_delta,omitempty"` @@ -12053,9 +12054,9 @@ func (x *BlindedPaymentPath) GetBaseFeeMsat() uint64 { return 0 } -func (x *BlindedPaymentPath) GetProportionalFeeMsat() uint64 { +func (x *BlindedPaymentPath) GetProportionalFeeRate() uint32 { if x != nil { - return x.ProportionalFeeMsat + return x.ProportionalFeeRate } return 0 } @@ -19611,9 +19612,9 @@ var file_lightning_proto_rawDesc = []byte{ 0x74, 0x68, 0x12, 0x22, 0x0a, 0x0d, 0x62, 0x61, 0x73, 0x65, 0x5f, 0x66, 0x65, 0x65, 0x5f, 0x6d, 0x73, 0x61, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0b, 0x62, 0x61, 0x73, 0x65, 0x46, 0x65, 0x65, 0x4d, 0x73, 0x61, 0x74, 0x12, 0x32, 0x0a, 0x15, 0x70, 0x72, 0x6f, 0x70, 0x6f, 0x72, - 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x5f, 0x66, 0x65, 0x65, 0x5f, 0x6d, 0x73, 0x61, 0x74, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x04, 0x52, 0x13, 0x70, 0x72, 0x6f, 0x70, 0x6f, 0x72, 0x74, 0x69, 0x6f, - 0x6e, 0x61, 0x6c, 0x46, 0x65, 0x65, 0x4d, 0x73, 0x61, 0x74, 0x12, 0x28, 0x0a, 0x10, 0x74, 0x6f, + 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x5f, 0x66, 0x65, 0x65, 0x5f, 0x72, 0x61, 0x74, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x13, 0x70, 0x72, 0x6f, 0x70, 0x6f, 0x72, 0x74, 0x69, 0x6f, + 0x6e, 0x61, 0x6c, 0x46, 0x65, 0x65, 0x52, 0x61, 0x74, 0x65, 0x12, 0x28, 0x0a, 0x10, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x5f, 0x63, 0x6c, 0x74, 0x76, 0x5f, 0x64, 0x65, 0x6c, 0x74, 0x61, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x43, 0x6c, 0x74, 0x76, 0x44, 0x65, 0x6c, 0x74, 0x61, 0x12, 0x22, 0x0a, 0x0d, 0x68, 0x74, 0x6c, 0x63, 0x5f, 0x6d, 0x69, 0x6e, diff --git a/lnrpc/lightning.proto b/lnrpc/lightning.proto index 079f2cddd8..644a9585db 100644 --- a/lnrpc/lightning.proto +++ b/lnrpc/lightning.proto @@ -3545,8 +3545,11 @@ message BlindedPaymentPath { // The base fee for the blinded path provided, expressed in msat. uint64 base_fee_msat = 2; - // The proportional fee for the blinded path provided, expressed in msat. - uint64 proportional_fee_msat = 3; + /* + The proportional fee for the blinded path provided, expressed in parts + per million. + */ + uint32 proportional_fee_rate = 3; /* The total CLTV delta for the blinded path provided, including the diff --git a/lnrpc/lightning.swagger.json b/lnrpc/lightning.swagger.json index 93eeb2ea4f..db5d3581ff 100644 --- a/lnrpc/lightning.swagger.json +++ b/lnrpc/lightning.swagger.json @@ -3523,10 +3523,10 @@ "format": "uint64", "description": "The base fee for the blinded path provided, expressed in msat." }, - "proportional_fee_msat": { - "type": "string", - "format": "uint64", - "description": "The proportional fee for the blinded path provided, expressed in msat." + "proportional_fee_rate": { + "type": "integer", + "format": "int64", + "description": "The proportional fee for the blinded path provided, expressed in parts\nper million." }, "total_cltv_delta": { "type": "integer", diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 64dd6917f0..18233991d0 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -487,15 +487,13 @@ func unmarshalBlindedPayment(rpcPayment *lnrpc.BlindedPaymentPath) ( } return &routing.BlindedPayment{ - BlindedPath: path, - CltvExpiryDelta: uint16(rpcPayment.TotalCltvDelta), - BaseFee: uint32(rpcPayment.BaseFeeMsat), - ProportionalFee: uint32( - rpcPayment.ProportionalFeeMsat, - ), - HtlcMinimum: rpcPayment.HtlcMinMsat, - HtlcMaximum: rpcPayment.HtlcMaxMsat, - Features: features, + BlindedPath: path, + CltvExpiryDelta: uint16(rpcPayment.TotalCltvDelta), + BaseFee: uint32(rpcPayment.BaseFeeMsat), + ProportionalFeeRate: rpcPayment.ProportionalFeeRate, + HtlcMinimum: rpcPayment.HtlcMinMsat, + HtlcMaximum: rpcPayment.HtlcMaxMsat, + Features: features, }, nil } diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 40bdbebb5b..1b6e71ffd6 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -31,6 +31,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -371,6 +372,12 @@ type PaymentDescriptor struct { // isForwarded denotes if an incoming HTLC has been forwarded to any // possible upstream peers in the route. isForwarded bool + + // BlindingPoint is an optional ephemeral key used in route blinding. + // This value is set for nodes that are relaying payments inside of a + // blinded route (ie, not the introduction node) from update_add_htlc's + // TLVs. + BlindingPoint *btcec.PublicKey } // PayDescsFromRemoteLogUpdates converts a slice of LogUpdates received from the @@ -411,6 +418,7 @@ func PayDescsFromRemoteLogUpdates(chanID lnwire.ShortChannelID, height uint64, Height: height, Index: uint16(i), }, + BlindingPoint: wireMsg.BlingingPointOrNil(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) @@ -736,6 +744,14 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { Incoming: false, } copy(h.OnionBlob[:], htlc.OnionBlob) + if htlc.BlindingPoint != nil { + h.BlindingPoint = tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + htlc.BlindingPoint, + ), + ) + } if ourCommit && htlc.sig != nil { h.Signature = htlc.sig.Serialize() @@ -760,7 +776,14 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { Incoming: true, } copy(h.OnionBlob[:], htlc.OnionBlob) - + if htlc.BlindingPoint != nil { + h.BlindingPoint = tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + htlc.BlindingPoint, + ), + ) + } if ourCommit && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -859,6 +882,12 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, theirWitnessScript: theirWitnessScript, } + htlc.BlindingPoint.WhenSome(func(b tlv.RecordT[ + lnwire.BlindingPointTlvType, *btcec.PublicKey]) { + + pd.BlindingPoint = b.Val + }) + return pd, nil } @@ -1548,6 +1577,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, HtlcIndex: wireMsg.ID, LogIndex: logUpdate.LogIndex, addCommitHeightRemote: commitHeight, + BlindingPoint: wireMsg.BlingingPointOrNil(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) @@ -1745,6 +1775,7 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd HtlcIndex: wireMsg.ID, LogIndex: logUpdate.LogIndex, addCommitHeightLocal: commitHeight, + BlindingPoint: wireMsg.BlingingPointOrNil(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob, wireMsg.OnionBlob[:]) @@ -3607,6 +3638,14 @@ func (lc *LightningChannel) createCommitDiff( PaymentHash: pd.RHash, } copy(htlc.OnionBlob[:], pd.OnionBlob) + if pd.BlindingPoint != nil { + htlc.BlindingPoint = tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + pd.BlindingPoint, + ), + ) + } logUpdate.UpdateMsg = htlc // Gather any references for circuits opened by this Add @@ -3736,12 +3775,21 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate { // four messages that it corresponds to. switch pd.EntryType { case Add: + var b lnwire.BlindingPointRecord + if pd.BlindingPoint != nil { + tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](pd.BlindingPoint), + ) + } + htlc := &lnwire.UpdateAddHTLC{ - ChanID: chanID, - ID: pd.HtlcIndex, - Amount: pd.Amount, - Expiry: pd.Timeout, - PaymentHash: pd.RHash, + ChanID: chanID, + ID: pd.HtlcIndex, + Amount: pd.Amount, + Expiry: pd.Timeout, + PaymentHash: pd.RHash, + BlindingPoint: b, } copy(htlc.OnionBlob[:], pd.OnionBlob) logUpdate.UpdateMsg = htlc @@ -5742,6 +5790,14 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( Expiry: pd.Timeout, PaymentHash: pd.RHash, } + if pd.BlindingPoint != nil { + htlc.BlindingPoint = tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + pd.BlindingPoint, + ), + ) + } copy(htlc.OnionBlob[:], pd.OnionBlob) logUpdate.UpdateMsg = htlc addUpdates = append(addUpdates, logUpdate) @@ -6079,6 +6135,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, HtlcIndex: lc.localUpdateLog.htlcCounter, OnionBlob: htlc.OnionBlob[:], OpenCircuitKey: openKey, + BlindingPoint: htlc.BlingingPointOrNil(), } } @@ -6129,13 +6186,14 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err } pd := &PaymentDescriptor{ - EntryType: Add, - RHash: PaymentHash(htlc.PaymentHash), - Timeout: htlc.Expiry, - Amount: htlc.Amount, - LogIndex: lc.remoteUpdateLog.logIndex, - HtlcIndex: lc.remoteUpdateLog.htlcCounter, - OnionBlob: htlc.OnionBlob[:], + EntryType: Add, + RHash: PaymentHash(htlc.PaymentHash), + Timeout: htlc.Expiry, + Amount: htlc.Amount, + LogIndex: lc.remoteUpdateLog.logIndex, + HtlcIndex: lc.remoteUpdateLog.htlcCounter, + OnionBlob: htlc.OnionBlob[:], + BlindingPoint: htlc.BlingingPointOrNil(), } localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 53c93aab2e..d224b45983 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -25,6 +25,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -3037,7 +3038,6 @@ func TestChanSyncOweCommitment(t *testing.T) { Amount: htlcAmt, Expiry: uint32(10), OnionBlob: fakeOnionBlob, - ExtraData: make([]byte, 0), } htlcIndex, err := bobChannel.AddHTLC(h, nil) @@ -3082,7 +3082,6 @@ func TestChanSyncOweCommitment(t *testing.T) { Amount: htlcAmt, Expiry: uint32(10), OnionBlob: fakeOnionBlob, - ExtraData: make([]byte, 0), } aliceHtlcIndex, err := aliceChannel.AddHTLC(aliceHtlc, nil) require.NoError(t, err, "unable to add alice's htlc") @@ -10421,8 +10420,9 @@ func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC { _, err = rand.Read(sig) require.NoError(t, err) - extra := make([]byte, 1000) - _, err = rand.Read(extra) + blinding, err := pubkeyFromHex( + "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d48236c39", //nolint:lll + ) require.NoError(t, err) return channeldb.HTLC{ @@ -10435,7 +10435,10 @@ func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC { OnionBlob: onionBlob, HtlcIndex: rand.Uint64(), LogIndex: rand.Uint64(), - ExtraData: extra, + BlindingPoint: tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](blinding), + ), } } @@ -11002,3 +11005,61 @@ func TestEnforceFeeBuffer(t *testing.T) { require.Equal(t, aliceBalance, expectedAmt) } + +// TestBlindingPointPersistence tests persistence of blinding points attached +// to htlcs across restarts. +func TestBlindingPointPersistence(t *testing.T) { + // Create a test channel which will be used for the duration of this + // test. The channel will be funded evenly with Alice having 5 BTC, + // and Bob having 5 BTC. + aliceChannel, bobChannel, err := CreateTestChannels( + t, channeldb.SingleFunderTweaklessBit, + ) + require.NoError(t, err, "unable to create test channels") + + // Send a HTLC from Alice to Bob that has a blinding point populated. + htlc, _ := createHTLC(0, 100_000_000) + blinding, err := pubkeyFromHex( + "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d48236c39", //nolint:lll + ) + require.NoError(t, err) + htlc.BlindingPoint = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](blinding), + ) + + _, err = aliceChannel.AddHTLC(htlc, nil) + + require.NoError(t, err) + _, err = bobChannel.ReceiveHTLC(htlc) + require.NoError(t, err) + + // Now, Alice will send a new commitment to Bob, which will persist our + // pending HTLC to disk. + aliceCommit, err := aliceChannel.SignNextCommitment() + require.NoError(t, err, "unable to sign commitment") + + // Restart alice to force fetching state from disk. + aliceChannel, err = restartChannel(aliceChannel) + require.NoError(t, err, "unable to restart alice") + + // Assert that the blinding point is restored from disk. + remoteCommit := aliceChannel.remoteCommitChain.tip() + require.Len(t, remoteCommit.outgoingHTLCs, 1) + require.Equal(t, blinding, remoteCommit.outgoingHTLCs[0].BlindingPoint) + + // Next, update bob's commitment and assert that we can still retrieve + // his incoming blinding point after restart. + err = bobChannel.ReceiveNewCommitment(aliceCommit.CommitSigs) + require.NoError(t, err, "bob unable to receive new commitment") + + _, _, _, err = bobChannel.RevokeCurrentCommitment() + require.NoError(t, err, "bob unable to revoke current commitment") + + bobChannel, err = restartChannel(bobChannel) + require.NoError(t, err, "unable to restart bob's channel") + + // Assert that Bob is able to recover the blinding point from disk. + bobCommit := bobChannel.localCommitChain.tip() + require.Len(t, bobCommit.incomingHTLCs, 1) + require.Equal(t, blinding, bobCommit.incomingHTLCs[0].BlindingPoint) +} diff --git a/lnwire/channel_type.go b/lnwire/channel_type.go index a0696048be..de755e135a 100644 --- a/lnwire/channel_type.go +++ b/lnwire/channel_type.go @@ -19,7 +19,7 @@ type ChannelType RawFeatureVector // featureBitLen returns the length in bytes of the encoded feature bits. func (c ChannelType) featureBitLen() uint64 { fv := RawFeatureVector(c) - return uint64(fv.SerializeSize()) + return fv.sizeFunc() } // Record returns a TLV record that can be used to encode/decode the channel @@ -34,25 +34,27 @@ func (c *ChannelType) Record() tlv.Record { // channelTypeEncoder is a custom TLV encoder for the ChannelType record. func channelTypeEncoder(w io.Writer, val interface{}, buf *[8]byte) error { if v, ok := val.(*ChannelType); ok { - // Encode the feature bits as a byte slice without its length - // prepended, as that's already taken care of by the TLV record. fv := RawFeatureVector(*v) - return fv.encode(w, fv.SerializeSize(), 8) + return rawFeatureEncoder(w, &fv, buf) } - return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelType") + return tlv.NewTypeForEncodingErr(val, "*lnwire.ChannelType") } // channelTypeDecoder is a custom TLV decoder for the ChannelType record. -func channelTypeDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { +func channelTypeDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + if v, ok := val.(*ChannelType); ok { fv := NewRawFeatureVector() - if err := fv.decode(r, int(l), 8); err != nil { + + if err := rawFeatureDecoder(r, fv, buf, l); err != nil { return err } + *v = ChannelType(*fv) return nil } - return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelType") + return tlv.NewTypeForEncodingErr(val, "*lnwire.ChannelType") } diff --git a/lnwire/features.go b/lnwire/features.go index 81472a7544..ab6facc75f 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "io" + + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -612,6 +614,41 @@ func (fv *RawFeatureVector) decode(r io.Reader, length, width int) error { return nil } +// sizeFunc returns the length required to encode the feature vector. +func (fv *RawFeatureVector) sizeFunc() uint64 { + return uint64(fv.SerializeSize()) +} + +// Record returns a TLV record that can be used to encode/decode raw feature +// vectors. Note that the length of the feature vector is not included, because +// it is covered by the TLV record's length field. +func (fv *RawFeatureVector) Record(recordType tlv.Type) tlv.Record { + return tlv.MakeDynamicRecord( + recordType, fv, fv.sizeFunc, rawFeatureEncoder, + rawFeatureDecoder, + ) +} + +// rawFeatureEncoder is a custom TLV encoder for raw feature vectors. +func rawFeatureEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if f, ok := val.(*RawFeatureVector); ok { + return f.encode(w, f.SerializeSize(), 8) + } + + return tlv.NewTypeForEncodingErr(val, "*lnwire.RawFeatureVector") +} + +// rawFeatureDecoder is a custom TLV decoder for raw feature vectors. +func rawFeatureDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if f, ok := val.(*RawFeatureVector); ok { + return f.decode(r, int(l), 8) + } + + return tlv.NewTypeForEncodingErr(val, "*lnwire.RawFeatureVector") +} + // FeatureVector represents a set of enabled features. The set stores // information on enabled flags and metadata about the feature names. A feature // vector is serializable to a compact byte representation that is included in @@ -641,6 +678,50 @@ func EmptyFeatureVector() *FeatureVector { return NewFeatureVector(nil, Features) } +// Record implements the RecordProducer interface for FeatureVector. Note that +// it uses a zero-value type is used to produce the record, as we expect this +// type value to be overwritten when used in generic TLV record production. +// This allows a single Record function to serve in the many different contexts +// in which feature vectors are encoded. This record wraps the encoding/ +// decoding for our raw feature vectors so that we can directly parse fully +// formed feature vector types. +func (fv *FeatureVector) Record() tlv.Record { + return tlv.MakeDynamicRecord(0, fv, fv.sizeFunc, + func(w io.Writer, val interface{}, buf *[8]byte) error { + if f, ok := val.(*FeatureVector); ok { + return rawFeatureEncoder( + w, f.RawFeatureVector, buf, + ) + } + + return tlv.NewTypeForEncodingErr( + val, "*lnwire.FeatureVector", + ) + }, + func(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if f, ok := val.(*FeatureVector); ok { + features := NewFeatureVector(nil, Features) + err := rawFeatureDecoder( + r, features.RawFeatureVector, buf, l, + ) + if err != nil { + return err + } + + *f = *features + + return nil + } + + return tlv.NewTypeForDecodingErr( + val, "*lnwire.FeatureVector", l, l, + ) + }, + ) +} + // HasFeature returns whether a particular feature is included in the set. The // feature can be seen as set either if the bit is set directly OR the queried // bit has the same meaning as its corresponding even/odd bit, which is set @@ -678,6 +759,18 @@ func (fv *FeatureVector) UnknownRequiredFeatures() []FeatureBit { return unknown } +// UnknownFeatures returns a boolean if a feature vector contains *any* +// unknown features (even if they are odd). +func (fv *FeatureVector) UnknownFeatures() bool { + for feature := range fv.features { + if !fv.IsKnown(feature) { + return true + } + } + + return false +} + // Name returns a string identifier for the feature represented by this bit. If // the bit does not represent a known feature, this returns a string indicating // as such. diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 3a1d02c18f..e4c5c6baf5 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1337,6 +1337,42 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, + MsgUpdateAddHTLC: func(v []reflect.Value, r *rand.Rand) { + req := &UpdateAddHTLC{ + ID: r.Uint64(), + Amount: MilliSatoshi(r.Uint64()), + Expiry: r.Uint32(), + } + + _, err := r.Read(req.ChanID[:]) + require.NoError(t, err) + + _, err = r.Read(req.PaymentHash[:]) + require.NoError(t, err) + + _, err = r.Read(req.OnionBlob[:]) + require.NoError(t, err) + + // Generate a blinding point 50% of the time, since not + // all update adds will use route blinding. + if r.Int31()%2 == 0 { + pubkey, err := randPubKey() + if err != nil { + t.Fatalf("unable to generate key: %v", + err) + + return + } + + req.BlindingPoint = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType0]( + pubkey, + ), + ) + } + + v[0] = reflect.ValueOf(*req) + }, } // With the above types defined, we'll now generate a slice of diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index c9ba8d269f..951dc7f54c 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -3,6 +3,9 @@ package lnwire import ( "bytes" "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/tlv" ) // OnionPacketSize is the size of the serialized Sphinx onion packet included @@ -11,6 +14,17 @@ import ( // of per-hop data, and a 32-byte HMAC over the entire packet. const OnionPacketSize = 1366 +type ( + // BlindingPointTlvType is the type for ephemeral pubkeys used in + // route blinding. + BlindingPointTlvType = tlv.TlvType0 + + // BlindingPointRecord holds an optional blinding point on update add + // htlc. + //nolint:lll + BlindingPointRecord = tlv.OptionalRecordT[BlindingPointTlvType, *btcec.PublicKey] +) + // UpdateAddHTLC is the message sent by Alice to Bob when she wishes to add an // HTLC to his remote commitment transaction. In addition to information // detailing the value, the ID, expiry, and the onion blob is also included @@ -54,12 +68,29 @@ type UpdateAddHTLC struct { // used in the subsequent UpdateAddHTLC message. OnionBlob [OnionPacketSize]byte + // BlindingPoint is the ephemeral pubkey used to optionally blind the + // next hop for this htlc. + BlindingPoint BlindingPointRecord + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. ExtraData ExtraOpaqueData } +// BlingingPointOrNil returns the blinding point associated with the update, or +// nil. +func (c *UpdateAddHTLC) BlingingPointOrNil() *btcec.PublicKey { + var blindingPoint *btcec.PublicKey + c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType, + *btcec.PublicKey]) { + + blindingPoint = b.Val + }) + + return blindingPoint +} + // NewUpdateAddHTLC returns a new empty UpdateAddHTLC message. func NewUpdateAddHTLC() *UpdateAddHTLC { return &UpdateAddHTLC{} @@ -74,7 +105,7 @@ var _ Message = (*UpdateAddHTLC)(nil) // // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, + if err := ReadElements(r, &c.ChanID, &c.ID, &c.Amount, @@ -82,7 +113,27 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { &c.Expiry, c.OnionBlob[:], &c.ExtraData, - ) + ); err != nil { + return err + } + + blindingRecord := c.BlindingPoint.Zero() + tlvMap, err := c.ExtraData.ExtractRecords(&blindingRecord) + if err != nil { + return err + } + + if val, ok := tlvMap[c.BlindingPoint.TlvType()]; ok && val == nil { + c.BlindingPoint = tlv.SomeRecordT(blindingRecord) + } + + // Set extra data to nil if we didn't parse anything out of it so that + // we can use assert.Equal in tests. + if len(tlvMap) == 0 { + c.ExtraData = nil + } + + return nil } // Encode serializes the target UpdateAddHTLC into the passed io.Writer @@ -114,6 +165,20 @@ func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error { return err } + // Only include blinding point in extra data if present. + var records []tlv.RecordProducer + + c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType, + *btcec.PublicKey]) { + + records = append(records, &b) + }) + + err := EncodeMessageExtraData(&c.ExtraData, records...) + if err != nil { + return err + } + return WriteBytes(w, c.ExtraData) } diff --git a/peer/brontide.go b/peer/brontide.go index 07fb5731e7..989db3a8d4 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -45,6 +45,7 @@ import ( "github.com/lightningnetwork/lnd/queue" "github.com/lightningnetwork/lnd/subscribe" "github.com/lightningnetwork/lnd/ticker" + "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/watchtower/wtclient" ) @@ -1985,8 +1986,19 @@ func messageSummary(msg lnwire.Message) string { msg.FeeSatoshis) case *lnwire.UpdateAddHTLC: - return fmt.Sprintf("chan_id=%v, id=%v, amt=%v, expiry=%v, hash=%x", - msg.ChanID, msg.ID, msg.Amount, msg.Expiry, msg.PaymentHash[:]) + var blindingPoint []byte + msg.BlindingPoint.WhenSome( + func(b tlv.RecordT[lnwire.BlindingPointTlvType, + *btcec.PublicKey]) { + + blindingPoint = b.Val.SerializeCompressed() + }, + ) + + return fmt.Sprintf("chan_id=%v, id=%v, amt=%v, expiry=%v, "+ + "hash=%x, blinding_point=%x", msg.ChanID, msg.ID, + msg.Amount, msg.Expiry, msg.PaymentHash[:], + blindingPoint) case *lnwire.UpdateFailHTLC: return fmt.Sprintf("chan_id=%v, id=%v, reason=%x", msg.ChanID, diff --git a/record/blinded_data.go b/record/blinded_data.go new file mode 100644 index 0000000000..7990fa7388 --- /dev/null +++ b/record/blinded_data.go @@ -0,0 +1,304 @@ +package record + +import ( + "bytes" + "encoding/binary" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// BlindedRouteData contains the information that is included in a blinded +// route encrypted data blob that is created by the recipient to provide +// forwarding information. +type BlindedRouteData struct { + // ShortChannelID is the channel ID of the next hop. + ShortChannelID tlv.RecordT[tlv.TlvType2, lnwire.ShortChannelID] + + // NextBlindingOverride is a blinding point that should be switched + // in for the next hop. This is used to combine two blinded paths into + // one (which primarily is used in onion messaging, but in theory + // could be used for payments as well). + NextBlindingOverride tlv.OptionalRecordT[tlv.TlvType8, *btcec.PublicKey] + + // RelayInfo provides the relay parameters for the hop. + RelayInfo tlv.RecordT[tlv.TlvType10, PaymentRelayInfo] + + // Constraints provides the payment relay constraints for the hop. + Constraints tlv.OptionalRecordT[tlv.TlvType12, PaymentConstraints] + + // Features is the set of features the payment requires. + Features tlv.OptionalRecordT[tlv.TlvType14, lnwire.FeatureVector] +} + +// NewBlindedRouteData creates the data that's provided for hops within a +// blinded route. +func NewBlindedRouteData(chanID lnwire.ShortChannelID, + blindingOverride *btcec.PublicKey, relayInfo PaymentRelayInfo, + constraints *PaymentConstraints, + features *lnwire.FeatureVector) *BlindedRouteData { + + info := &BlindedRouteData{ + ShortChannelID: tlv.NewRecordT[tlv.TlvType2](chanID), + RelayInfo: tlv.NewRecordT[tlv.TlvType10](relayInfo), + } + + if blindingOverride != nil { + info.NextBlindingOverride = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType8](blindingOverride)) + } + + if constraints != nil { + info.Constraints = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType12](*constraints)) + } + + if features != nil { + info.Features = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType14](*features), + ) + } + + return info +} + +// DecodeBlindedRouteData decodes the data provided within a blinded route. +func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) { + var ( + d BlindedRouteData + + blindingOverride = d.NextBlindingOverride.Zero() + constraints = d.Constraints.Zero() + features = d.Features.Zero() + ) + + var tlvRecords lnwire.ExtraOpaqueData + if err := lnwire.ReadElements(r, &tlvRecords); err != nil { + return nil, err + } + + typeMap, err := tlvRecords.ExtractRecords( + &d.ShortChannelID, + &blindingOverride, &d.RelayInfo.Val, &constraints, + &features, + ) + if err != nil { + return nil, err + } + + val, ok := typeMap[d.NextBlindingOverride.TlvType()] + if ok && val == nil { + d.NextBlindingOverride = tlv.SomeRecordT(blindingOverride) + } + + if val, ok := typeMap[d.Constraints.TlvType()]; ok && val == nil { + d.Constraints = tlv.SomeRecordT(constraints) + } + + if val, ok := typeMap[d.Features.TlvType()]; ok && val == nil { + d.Features = tlv.SomeRecordT(features) + } + + return &d, nil +} + +// EncodeBlindedRouteData encodes the blinded route data provided. +func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) { + var ( + e lnwire.ExtraOpaqueData + recordProducers = make([]tlv.RecordProducer, 0, 5) + ) + + recordProducers = append(recordProducers, &data.ShortChannelID) + + data.NextBlindingOverride.WhenSome(func(pk tlv.RecordT[tlv.TlvType8, + *btcec.PublicKey]) { + + recordProducers = append(recordProducers, &pk) + }) + + recordProducers = append(recordProducers, &data.RelayInfo.Val) + + data.Constraints.WhenSome(func(cs tlv.RecordT[tlv.TlvType12, + PaymentConstraints]) { + + recordProducers = append(recordProducers, &cs) + }) + + data.Features.WhenSome(func(f tlv.RecordT[tlv.TlvType14, + lnwire.FeatureVector]) { + + recordProducers = append(recordProducers, &f) + }) + + if err := e.PackRecords(recordProducers...); err != nil { + return nil, err + } + + return e[:], nil +} + +// PaymentRelayInfo describes the relay policy for a blinded path. +type PaymentRelayInfo struct { + // CltvExpiryDelta is the expiry delta for the payment. + CltvExpiryDelta uint16 + + // FeeRate is the fee rate that will be charged per millionth of a + // satoshi. + FeeRate uint32 + + // BaseFee is the per-htlc fee charged. + BaseFee uint32 +} + +// newPaymentRelayRecord creates a tlv.Record that encodes the payment relay +// (type 10) type for an encrypted blob payload. +func (i *PaymentRelayInfo) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 10, &i, func() uint64 { + // uint16 + uint32 + tuint32 + return 2 + 4 + tlv.SizeTUint32(i.BaseFee) + }, encodePaymentRelay, decodePaymentRelay, + ) +} + +func encodePaymentRelay(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(**PaymentRelayInfo); ok { + relayInfo := *t + + // Just write our first 6 bytes directly. + binary.BigEndian.PutUint16(buf[:2], relayInfo.CltvExpiryDelta) + binary.BigEndian.PutUint32(buf[2:6], relayInfo.FeeRate) + if _, err := w.Write(buf[0:6]); err != nil { + return err + } + + // We can safely reuse buf here because we overwrite its + // contents. + return tlv.ETUint32(w, &relayInfo.BaseFee, buf) + } + + return tlv.NewTypeForEncodingErr(val, "**hop.PaymentRelayInfo") +} + +func decodePaymentRelay(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if t, ok := val.(**PaymentRelayInfo); ok && l <= 10 { + scratch := make([]byte, l) + + n, err := io.ReadFull(r, scratch) + if err != nil { + return err + } + + // We expect at least 6 bytes, because we have 2 bytes for + // cltv delta and 4 bytes for fee rate. + if n < 6 { + return tlv.NewTypeForDecodingErr(val, + "*hop.paymentRelayInfo", uint64(n), 6) + } + + relayInfo := *t + + relayInfo.CltvExpiryDelta = binary.BigEndian.Uint16( + scratch[0:2], + ) + relayInfo.FeeRate = binary.BigEndian.Uint32(scratch[2:6]) + + // To be able to re-use the DTUint32 function we create a + // buffer with just the bytes holding the variable length u32. + // If the base fee is zero, this will be an empty buffer, which + // is okay. + b := bytes.NewBuffer(scratch[6:]) + + return tlv.DTUint32(b, &relayInfo.BaseFee, buf, l-6) + } + + return tlv.NewTypeForDecodingErr(val, "*hop.paymentRelayInfo", l, 10) +} + +// PaymentConstraints is a set of restrictions on a payment. +type PaymentConstraints struct { + // MaxCltvExpiry is the maximum expiry height for the payment. + MaxCltvExpiry uint32 + + // HtlcMinimumMsat is the minimum htlc size for the payment. + HtlcMinimumMsat lnwire.MilliSatoshi +} + +func (p *PaymentConstraints) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 12, &p, func() uint64 { + // uint32 + tuint64. + return 4 + tlv.SizeTUint64(uint64( + p.HtlcMinimumMsat, + )) + }, + encodePaymentConstraints, decodePaymentConstraints, + ) +} + +func encodePaymentConstraints(w io.Writer, val interface{}, + buf *[8]byte) error { + + if c, ok := val.(**PaymentConstraints); ok { + constraints := *c + + binary.BigEndian.PutUint32(buf[:4], constraints.MaxCltvExpiry) + if _, err := w.Write(buf[:4]); err != nil { + return err + } + + // We can safely re-use buf here because we overwrite its + // contents. + htlcMsat := uint64(constraints.HtlcMinimumMsat) + + return tlv.ETUint64(w, &htlcMsat, buf) + } + + return tlv.NewTypeForEncodingErr(val, "**PaymentConstraints") +} + +func decodePaymentConstraints(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if c, ok := val.(**PaymentConstraints); ok && l <= 12 { + scratch := make([]byte, l) + + n, err := io.ReadFull(r, scratch) + if err != nil { + return err + } + + // We expect at least 4 bytes for our uint32. + if n < 4 { + return tlv.NewTypeForDecodingErr(val, + "*paymentConstraints", uint64(n), 4) + } + + payConstraints := *c + + payConstraints.MaxCltvExpiry = binary.BigEndian.Uint32( + scratch[:4], + ) + + // This could be empty if our minimum is zero, that's okay. + var ( + b = bytes.NewBuffer(scratch[4:]) + minHtlc uint64 + ) + + err = tlv.DTUint64(b, &minHtlc, buf, l-4) + if err != nil { + return err + } + payConstraints.HtlcMinimumMsat = lnwire.MilliSatoshi(minHtlc) + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "**PaymentConstraints", l, l) +} diff --git a/record/blinded_data_test.go b/record/blinded_data_test.go new file mode 100644 index 0000000000..f8e95cdcc0 --- /dev/null +++ b/record/blinded_data_test.go @@ -0,0 +1,197 @@ +package record + +import ( + "bytes" + "encoding/hex" + "fmt" + "math" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +//nolint:lll +const pubkeyStr = "02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619" + +func pubkey(t *testing.T) *btcec.PublicKey { + t.Helper() + + nodeBytes, err := hex.DecodeString(pubkeyStr) + require.NoError(t, err) + + nodePk, err := btcec.ParsePubKey(nodeBytes) + require.NoError(t, err) + + return nodePk +} + +// TestBlindedDataEncoding tests encoding and decoding of blinded data blobs. +// These tests specifically cover cases where the variable length encoded +// integers values have different numbers of leading zeros trimmed because +// these TLVs are the first composite records with variable length tlvs +// (previously, a variable length integer would take up the whole record). +func TestBlindedDataEncoding(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + baseFee uint32 + htlcMin lnwire.MilliSatoshi + features *lnwire.FeatureVector + constraints bool + }{ + { + name: "zero variable values", + baseFee: 0, + htlcMin: 0, + }, + { + name: "zeros trimmed", + baseFee: math.MaxUint32 / 2, + htlcMin: math.MaxUint64 / 2, + }, + { + name: "no zeros trimmed", + baseFee: math.MaxUint32, + htlcMin: math.MaxUint64, + }, + { + name: "nil feature vector", + features: nil, + }, + { + name: "non-nil, but empty feature vector", + features: lnwire.EmptyFeatureVector(), + }, + { + name: "populated feature vector", + features: lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(lnwire.AMPOptional), + lnwire.Features, + ), + }, + { + name: "no payment constraints", + constraints: true, + }, + } + + for _, testCase := range tests { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + // Create a standard set of blinded route data, using + // the values from our test case for the variable + // length encoded values. + channelID := lnwire.NewShortChanIDFromInt(1) + info := PaymentRelayInfo{ + FeeRate: 2, + CltvExpiryDelta: 3, + BaseFee: testCase.baseFee, + } + + var constraints *PaymentConstraints + if testCase.constraints { + constraints = &PaymentConstraints{ + MaxCltvExpiry: 4, + HtlcMinimumMsat: testCase.htlcMin, + } + } + + encodedData := NewBlindedRouteData( + channelID, pubkey(t), info, constraints, + testCase.features, + ) + + encoded, err := EncodeBlindedRouteData(encodedData) + require.NoError(t, err) + + b := bytes.NewBuffer(encoded) + decodedData, err := DecodeBlindedRouteData(b) + require.NoError(t, err) + + require.Equal(t, encodedData, decodedData) + }) + } +} + +// TestBlindedRouteVectors tests encoding/decoding of the test vectors for +// blinded route data provided in the specification. +// +//nolint:lll +func TestBlindingSpecTestVectors(t *testing.T) { + nextBlindingOverrideStr, err := hex.DecodeString("031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f") + require.NoError(t, err) + nextBlindingOverride, err := btcec.ParsePubKey(nextBlindingOverrideStr) + require.NoError(t, err) + + tests := []struct { + encoded string + expectedPaymentData *BlindedRouteData + }{ + { + encoded: "011a0000000000000000000000000000000000000000000000000000020800000000000006c10a0800240000009627100c06000b69e505dc0e00fd023103123456", + expectedPaymentData: NewBlindedRouteData( + lnwire.ShortChannelID{ + BlockHeight: 0, + TxIndex: 0, + TxPosition: 1729, + }, + nil, + PaymentRelayInfo{ + CltvExpiryDelta: 36, + FeeRate: 150, + BaseFee: 10000, + }, + &PaymentConstraints{ + MaxCltvExpiry: 748005, + HtlcMinimumMsat: 1500, + }, + lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(), + lnwire.Features, + ), + ), + }, + { + encoded: "020800000000000004510821031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f0a0800300000006401f40c06000b69c105dc0e00", + expectedPaymentData: NewBlindedRouteData( + lnwire.ShortChannelID{ + TxPosition: 1105, + }, + nextBlindingOverride, + PaymentRelayInfo{ + CltvExpiryDelta: 48, + FeeRate: 100, + BaseFee: 500, + }, + &PaymentConstraints{ + MaxCltvExpiry: 747969, + HtlcMinimumMsat: 1500, + }, + lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(), + lnwire.Features, + )), + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + route, err := hex.DecodeString(test.encoded) + require.NoError(t, err) + + buff := bytes.NewBuffer(route) + + decodedRoute, err := DecodeBlindedRouteData(buff) + require.NoError(t, err) + + require.Equal( + t, test.expectedPaymentData, decodedRoute, + ) + }) + } +} diff --git a/routing/blinding.go b/routing/blinding.go index 61d303a0c9..d2d64aa5dd 100644 --- a/routing/blinding.go +++ b/routing/blinding.go @@ -36,13 +36,13 @@ type BlindedPayment struct { // blinded path. BaseFee uint32 - // ProportionalFee is the aggregated proportional fee for payments - // made over the blinded path. - ProportionalFee uint32 + // ProportionalFeeRate is the aggregated proportional fee rate for + // payments made over the blinded path. + ProportionalFeeRate uint32 - // CltvExpiryDelta is the total expiry delta for the blinded path. Note - // this does not include the final cltv delta for the receiving node - // (which should be provided in an invoice). + // CltvExpiryDelta is the total expiry delta for the blinded path. This + // field includes the CLTV for the blinded hops *and* the final cltv + // delta for the receiver. CltvExpiryDelta uint16 // HtlcMinimum is the highest HLTC minimum supported along the blinded @@ -122,7 +122,7 @@ func (b *BlindedPayment) toRouteHints() RouteHints { MaxHTLC: lnwire.MilliSatoshi(b.HtlcMaximum), FeeBaseMSat: lnwire.MilliSatoshi(b.BaseFee), FeeProportionalMillionths: lnwire.MilliSatoshi( - b.ProportionalFee, + b.ProportionalFeeRate, ), ToNodePubKey: func() route.Vertex { return route.NewVertex( diff --git a/routing/blinding_test.go b/routing/blinding_test.go index 561ace6fc1..69b25c696f 100644 --- a/routing/blinding_test.go +++ b/routing/blinding_test.go @@ -121,12 +121,12 @@ func TestBlindedPaymentToHints(t *testing.T) { {}, }, }, - BaseFee: baseFee, - ProportionalFee: ppmFee, - CltvExpiryDelta: cltvDelta, - HtlcMinimum: htlcMin, - HtlcMaximum: htlcMax, - Features: features, + BaseFee: baseFee, + ProportionalFeeRate: ppmFee, + CltvExpiryDelta: cltvDelta, + HtlcMinimum: htlcMin, + HtlcMaximum: htlcMax, + Features: features, } require.Nil(t, blindedPayment.toRouteHints())