diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index bbc0e37106..0dde33b8e3 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -123,6 +123,9 @@ ## Breaking Changes ## Performance Improvements +* Watchtower client DB migration to massively [improve the start-up + performance](https://github.com/lightningnetwork/lnd/pull/8222) of a client. + # Technical and Architectural Updates ## BOLT Spec Updates diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 412412c1e3..f9036e87fc 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -17,6 +17,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -295,9 +296,8 @@ type TowerClient struct { closableSessionQueue *sessionCloseMinHeap - backupMu sync.Mutex - summaries wtdb.ChannelSummaries - chanCommitHeights map[lnwire.ChannelID]uint64 + backupMu sync.Mutex + chanInfos wtdb.ChannelInfos statTicker *time.Ticker stats *ClientStats @@ -339,9 +339,7 @@ func New(config *Config) (*TowerClient, error) { plog := build.NewPrefixLog(prefix, log) - // Load the sweep pkscripts that have been generated for all previously - // registered channels. - chanSummaries, err := cfg.DB.FetchChanSummaries() + chanInfos, err := cfg.DB.FetchChanInfos() if err != nil { return nil, err } @@ -358,9 +356,8 @@ func New(config *Config) (*TowerClient, error) { cfg: cfg, log: plog, pipeline: queue, - chanCommitHeights: make(map[lnwire.ChannelID]uint64), activeSessions: newSessionQueueSet(), - summaries: chanSummaries, + chanInfos: chanInfos, closableSessionQueue: newSessionCloseMinHeap(), statTicker: time.NewTicker(DefaultStatInterval), stats: new(ClientStats), @@ -369,44 +366,6 @@ func New(config *Config) (*TowerClient, error) { quit: make(chan struct{}), } - // perUpdate is a callback function that will be used to inspect the - // full set of candidate client sessions loaded from disk, and to - // determine the highest known commit height for each channel. This - // allows the client to reject backups that it has already processed for - // its active policy. - perUpdate := func(policy wtpolicy.Policy, chanID lnwire.ChannelID, - commitHeight uint64) { - - // We only want to consider accepted updates that have been - // accepted under an identical policy to the client's current - // policy. - if policy != c.cfg.Policy { - return - } - - c.backupMu.Lock() - defer c.backupMu.Unlock() - - // Take the highest commit height found in the session's acked - // updates. - height, ok := c.chanCommitHeights[chanID] - if !ok || commitHeight > height { - c.chanCommitHeights[chanID] = commitHeight - } - } - - perMaxHeight := func(s *wtdb.ClientSession, chanID lnwire.ChannelID, - height uint64) { - - perUpdate(s.Policy, chanID, height) - } - - perCommittedUpdate := func(s *wtdb.ClientSession, - u *wtdb.CommittedUpdate) { - - perUpdate(s.Policy, u.BackupID.ChanID, u.BackupID.CommitHeight) - } - candidateTowers := newTowerListIterator() perActiveTower := func(tower *Tower) { // If the tower has already been marked as active, then there is @@ -429,8 +388,6 @@ func New(config *Config) (*TowerClient, error) { candidateSessions, err := getTowerAndSessionCandidates( cfg.DB, cfg.SecretKeyRing, perActiveTower, wtdb.WithPreEvalFilterFn(c.genSessionFilter(true)), - wtdb.WithPerMaxHeight(perMaxHeight), - wtdb.WithPerCommittedUpdate(perCommittedUpdate), wtdb.WithPostEvalFilterFn(ExhaustedSessionFilter()), ) if err != nil { @@ -594,7 +551,7 @@ func (c *TowerClient) Start() error { // Iterate over the list of registered channels and check if // any of them can be marked as closed. - for id := range c.summaries { + for id := range c.chanInfos { isClosed, closedHeight, err := c.isChannelClosed(id) if err != nil { returnErr = err @@ -615,7 +572,7 @@ func (c *TowerClient) Start() error { // Since the channel has been marked as closed, we can // also remove it from the channel summaries map. - delete(c.summaries, id) + delete(c.chanInfos, id) } // Load all closable sessions. @@ -732,7 +689,7 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { // If a pkscript for this channel already exists, the channel has been // previously registered. - if _, ok := c.summaries[chanID]; ok { + if _, ok := c.chanInfos[chanID]; ok { return nil } @@ -752,8 +709,10 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { // Finally, cache the pkscript in our in-memory cache to avoid db // lookups for the remainder of the daemon's execution. - c.summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: pkScript, + c.chanInfos[chanID] = &wtdb.ChannelInfo{ + ClientChanSummary: wtdb.ClientChanSummary{ + SweepPkScript: pkScript, + }, } return nil @@ -770,16 +729,23 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, // Make sure that this channel is registered with the tower client. c.backupMu.Lock() - if _, ok := c.summaries[*chanID]; !ok { + info, ok := c.chanInfos[*chanID] + if !ok { c.backupMu.Unlock() return ErrUnregisteredChannel } // Ignore backups that have already been presented to the client. - height, ok := c.chanCommitHeights[*chanID] - if ok && stateNum <= height { + var duplicate bool + info.MaxHeight.WhenSome(func(maxHeight uint64) { + if stateNum <= maxHeight { + duplicate = true + } + }) + if duplicate { c.backupMu.Unlock() + c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+ "height=%d", chanID, stateNum) @@ -789,7 +755,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, // This backup has a higher commit height than any known backup for this // channel. We'll update our tip so that we won't accept it again if the // link flaps. - c.chanCommitHeights[*chanID] = stateNum + c.chanInfos[*chanID].MaxHeight = fn.Some(stateNum) c.backupMu.Unlock() id := &wtdb.BackupID{ @@ -899,7 +865,7 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, defer c.backupMu.Unlock() // We only care about channels registered with the tower client. - if _, ok := c.summaries[chanID]; !ok { + if _, ok := c.chanInfos[chanID]; !ok { return nil } @@ -924,8 +890,7 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, return fmt.Errorf("could not track closable sessions: %w", err) } - delete(c.summaries, chanID) - delete(c.chanCommitHeights, chanID) + delete(c.chanInfos, chanID) return nil } @@ -1332,7 +1297,7 @@ func (c *TowerClient) backupDispatcher() { // the prevTask, and should be reprocessed after obtaining a new sessionQueue. func (c *TowerClient) processTask(task *wtdb.BackupID) { c.backupMu.Lock() - summary, ok := c.summaries[task.ChanID] + summary, ok := c.chanInfos[task.ChanID] if !ok { c.backupMu.Unlock() diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 0f7f1b5391..e691552881 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -81,10 +81,10 @@ type DB interface { // successfully backed up using the given session. NumAckedUpdates(id *wtdb.SessionID) (uint64, error) - // FetchChanSummaries loads a mapping from all registered channels to - // their channel summaries. Only the channels that have not yet been + // FetchChanInfos loads a mapping from all registered channels to + // their wtdb.ChannelInfo. Only the channels that have not yet been // marked as closed will be loaded. - FetchChanSummaries() (wtdb.ChannelSummaries, error) + FetchChanInfos() (wtdb.ChannelInfos, error) // MarkChannelClosed will mark a registered channel as closed by setting // its closed-height as the given block height. It returns a list of diff --git a/watchtower/wtdb/client_chan_summary.go b/watchtower/wtdb/client_chan_summary.go index d4b3c3c388..6fec34c842 100644 --- a/watchtower/wtdb/client_chan_summary.go +++ b/watchtower/wtdb/client_chan_summary.go @@ -3,11 +3,29 @@ package wtdb import ( "io" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" ) -// ChannelSummaries is a map for a given channel id to it's ClientChanSummary. -type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary +// ChannelInfos is a map for a given channel id to it's ChannelInfo. +type ChannelInfos map[lnwire.ChannelID]*ChannelInfo + +// ChannelInfo contains various useful things about a registered channel. +// +// NOTE: the reason for adding this struct which wraps ClientChanSummary +// instead of extending ClientChanSummary is for faster look-up of added fields. +// If we were to extend ClientChanSummary instead then we would need to decode +// the entire struct each time we want to read the new fields and then re-encode +// the struct each time we want to write to a new field. +type ChannelInfo struct { + ClientChanSummary + + // MaxHeight is the highest commitment height that the tower has been + // handed for this channel. An Option type is used to store this since + // a commitment height of zero is valid, and we need a way of knowing if + // we have seen a new height yet or not. + MaxHeight fn.Option[uint64] +} // ClientChanSummary tracks channel-specific information. A new // ClientChanSummary is inserted in the database the first time the client diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 084f2dcfe0..635c6cfa8f 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" @@ -25,6 +26,7 @@ var ( // => cChanDBID -> db-assigned-id // => cChanSessions => db-session-id -> 1 // => cChanClosedHeight -> block-height + // => cChanMaxCommitmentHeight -> commitment-height cChanDetailsBkt = []byte("client-channel-detail-bucket") // cChanSessions is a sub-bucket of cChanDetailsBkt which stores: @@ -45,6 +47,13 @@ var ( // body of ClientChanSummary. cChannelSummary = []byte("client-channel-summary") + // cChanMaxCommitmentHeight is a key used in the cChanDetailsBkt used + // to store the highest commitment height for this channel that the + // tower has been handed. + cChanMaxCommitmentHeight = []byte( + "client-channel-max-commitment-height", + ) + // cSessionBkt is a top-level bucket storing: // session-id => cSessionBody -> encoded ClientSessionBody // => cSessionDBID -> db-assigned-id @@ -1300,11 +1309,11 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) { return numAcked, nil } -// FetchChanSummaries loads a mapping from all registered channels to their -// channel summaries. Only the channels that have not yet been marked as closed -// will be loaded. -func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { - var summaries map[lnwire.ChannelID]ClientChanSummary +// FetchChanInfos loads a mapping from all registered channels to their +// ChannelInfo. Only the channels that have not yet been marked as closed will +// be loaded. +func (c *ClientDB) FetchChanInfos() (ChannelInfos, error) { + var infos ChannelInfos err := kvdb.View(c.db, func(tx kvdb.RTx) error { chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) @@ -1317,34 +1326,47 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { if chanDetails == nil { return ErrCorruptChanDetails } - // If this channel has already been marked as closed, // then its summary does not need to be loaded. closedHeight := chanDetails.Get(cChanClosedHeight) if len(closedHeight) > 0 { return nil } - var chanID lnwire.ChannelID copy(chanID[:], k) - summary, err := getChanSummary(chanDetails) if err != nil { return err } - summaries[chanID] = *summary + info := &ChannelInfo{ + ClientChanSummary: *summary, + } + + maxHeightBytes := chanDetails.Get( + cChanMaxCommitmentHeight, + ) + if len(maxHeightBytes) != 0 { + height, err := readBigSize(maxHeightBytes) + if err != nil { + return err + } + + info.MaxHeight = fn.Some(height) + } + + infos[chanID] = info return nil }) }, func() { - summaries = make(map[lnwire.ChannelID]ClientChanSummary) + infos = make(ChannelInfos) }) if err != nil { return nil, err } - return summaries, nil + return infos, nil } // RegisterChannel registers a channel for use within the client database. For @@ -1963,6 +1985,12 @@ func (c *ClientDB) CommitUpdate(id *SessionID, return err } + // Update the channel's max commitment height if needed. + err = maybeUpdateMaxCommitHeight(tx, update.BackupID) + if err != nil { + return err + } + // Finally, capture the session's last applied value so it can // be sent in the next state update to the tower. lastApplied = session.TowerLastApplied @@ -2178,9 +2206,11 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, // GetDBQueue returns a BackupID Queue instance under the given namespace. func (c *ClientDB) GetDBQueue(namespace []byte) Queue[*BackupID] { - return NewQueueDB[*BackupID]( + return NewQueueDB( c.db, namespace, func() *BackupID { return &BackupID{} + }, func(tx kvdb.RwTx, item *BackupID) error { + return maybeUpdateMaxCommitHeight(tx, *item) }, ) } @@ -2720,6 +2750,58 @@ func getDBSessionID(sessionsBkt kvdb.RBucket, sessionID SessionID) (uint64, return id, idBytes, nil } +// maybeUpdateMaxCommitHeight updates the given channel details bucket with the +// given height if it is larger than the current max height stored for the +// channel. +func maybeUpdateMaxCommitHeight(tx kvdb.RwTx, backupID BackupID) error { + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + // If an entry for this channel does not exist in the channel details + // bucket then we exit here as this means that the channel has been + // closed. + chanDetails := chanDetailsBkt.NestedReadWriteBucket(backupID.ChanID[:]) + if chanDetails == nil { + return nil + } + + putHeight := func() error { + b, err := writeBigSize(backupID.CommitHeight) + if err != nil { + return err + } + + return chanDetails.Put( + cChanMaxCommitmentHeight, b, + ) + } + + // Get current height. + heightBytes := chanDetails.Get(cChanMaxCommitmentHeight) + + // The height might have not been set yet, in which case + // we can just write the new height. + if len(heightBytes) == 0 { + return putHeight() + } + + // Otherwise, read in the current max commitment height for the channel. + currentHeight, err := readBigSize(heightBytes) + if err != nil { + return err + } + + // If the new height is not larger than the current persisted height, + // then there is nothing left for us to do. + if backupID.CommitHeight <= currentHeight { + return nil + } + + return putHeight() +} + func getRealSessionID(sessIDIndexBkt kvdb.RBucket, dbID uint64) (*SessionID, error) { diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 6d11a69728..36cc049a9c 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -156,13 +156,13 @@ func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, return tower } -func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientChanSummary { +func (h *clientDBHarness) fetchChanInfos() wtdb.ChannelInfos { h.t.Helper() - summaries, err := h.db.FetchChanSummaries() + infos, err := h.db.FetchChanInfos() require.NoError(h.t, err) - return summaries + return infos } func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID, @@ -552,7 +552,7 @@ func testRemoveTower(h *clientDBHarness) { func testChanSummaries(h *clientDBHarness) { // First, assert that this channel is not already registered. var chanID lnwire.ChannelID - _, ok := h.fetchChanSummaries()[chanID] + _, ok := h.fetchChanInfos()[chanID] require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet", chanID) @@ -565,7 +565,7 @@ func testChanSummaries(h *clientDBHarness) { // Assert that the channel exists and that its sweep pkscript matches // the one we registered. - summary, ok := h.fetchChanSummaries()[chanID] + summary, ok := h.fetchChanInfos()[chanID] require.Truef(h.t, ok, "pkscript for channel %x should not exist yet", chanID) require.Equal(h.t, expPkScript, summary.SweepPkScript) @@ -767,6 +767,58 @@ func testRogueUpdates(h *clientDBHarness) { require.Len(h.t, closableSessionsMap, 1) } +// testMaxCommitmentHeights tests that the max known commitment height of a +// channel is properly persisted. +func testMaxCommitmentHeights(h *clientDBHarness) { + const maxUpdates = 5 + t := h.t + + // Initially, we expect no channels. + infos := h.fetchChanInfos() + require.Empty(t, infos) + + // Create a new tower. + tower := h.newTower() + + // Create and insert a new session. + session1 := h.randSession(t, tower.ID, maxUpdates) + h.insertSession(session1, nil) + + // Create a new channel and register it. + chanID1 := randChannelID(t) + h.registerChan(chanID1, nil, nil) + + // At this point, we expect one channel to be returned from + // fetchChanInfos but with an unset max height. + infos = h.fetchChanInfos() + require.Len(t, infos, 1) + + info, ok := infos[chanID1] + require.True(t, ok) + require.True(t, info.MaxHeight.IsNone()) + + // Commit and ACK some updates for this channel. + for i := 1; i <= maxUpdates; i++ { + update := randCommittedUpdateForChanWithHeight( + t, chanID1, uint16(i), uint64(i-1), + ) + lastApplied := h.commitUpdate(&session1.ID, update, nil) + h.ackUpdate(&session1.ID, uint16(i), lastApplied, nil) + } + + // Assert that the max height has now been set accordingly for this + // channel. + infos = h.fetchChanInfos() + require.Len(t, infos, 1) + + info, ok = infos[chanID1] + require.True(t, ok) + require.True(t, info.MaxHeight.IsSome()) + info.MaxHeight.WhenSome(func(u uint64) { + require.EqualValues(t, maxUpdates-1, u) + }) +} + // testMarkChannelClosed asserts the behaviour of MarkChannelClosed. func testMarkChannelClosed(h *clientDBHarness) { tower := h.newTower() @@ -1097,6 +1149,10 @@ func TestClientDB(t *testing.T) { name: "rogue updates", run: testRogueUpdates, }, + { + name: "max commitment heights", + run: testMaxCommitmentHeights, + }, } for _, database := range dbs { diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go index 639030631f..ed453383f9 100644 --- a/watchtower/wtdb/log.go +++ b/watchtower/wtdb/log.go @@ -10,6 +10,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration7" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration8" ) // log is a logger that is initialized with no output filters. This @@ -40,6 +41,7 @@ func UseLogger(logger btclog.Logger) { migration5.UseLogger(logger) migration6.UseLogger(logger) migration7.UseLogger(logger) + migration8.UseLogger(logger) } // logClosure is used to provide a closure over expensive logging operations so diff --git a/watchtower/wtdb/migration8/codec.go b/watchtower/wtdb/migration8/codec.go new file mode 100644 index 0000000000..9c8dca1a36 --- /dev/null +++ b/watchtower/wtdb/migration8/codec.go @@ -0,0 +1,234 @@ +package migration8 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/tlv" +) + +// BreachHintSize is the length of the identifier used to detect remote +// commitment broadcasts. +const BreachHintSize = 16 + +// BreachHint is the first 16-bytes of SHA256(txid), which is used to identify +// the breach transaction. +type BreachHint [BreachHintSize]byte + +// ChannelID is a series of 32-bytes that uniquely identifies all channels +// within the network. The ChannelID is computed using the outpoint of the +// funding transaction (the txid, and output index). Given a funding output the +// ChannelID can be calculated by XOR'ing the big-endian serialization of the +// txid and the big-endian serialization of the output index, truncated to +// 2 bytes. +type ChannelID [32]byte + +// writeBigSize will encode the given uint64 as a BigSize byte slice. +func writeBigSize(i uint64) ([]byte, error) { + var b bytes.Buffer + err := tlv.WriteVarInt(&b, i, &[8]byte{}) + if err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// readBigSize converts the given byte slice into a uint64 and assumes that the +// bytes slice is using BigSize encoding. +func readBigSize(b []byte) (uint64, error) { + r := bytes.NewReader(b) + i, err := tlv.ReadVarInt(r, &[8]byte{}) + if err != nil { + return 0, err + } + + return i, nil +} + +// CommittedUpdate holds a state update sent by a client along with its +// allocated sequence number and the exact remote commitment the encrypted +// justice transaction can rectify. +type CommittedUpdate struct { + // SeqNum is the unique sequence number allocated by the session to this + // update. + SeqNum uint16 + + CommittedUpdateBody +} + +// BackupID identifies a particular revoked, remote commitment by channel id and +// commitment height. +type BackupID struct { + // ChanID is the channel id of the revoked commitment. + ChanID ChannelID + + // CommitHeight is the commitment height of the revoked commitment. + CommitHeight uint64 +} + +// Encode writes the BackupID from the passed io.Writer. +func (b *BackupID) Encode(w io.Writer) error { + return WriteElements(w, + b.ChanID, + b.CommitHeight, + ) +} + +// Decode reads a BackupID from the passed io.Reader. +func (b *BackupID) Decode(r io.Reader) error { + return ReadElements(r, + &b.ChanID, + &b.CommitHeight, + ) +} + +// String returns a human-readable encoding of a BackupID. +func (b BackupID) String() string { + return fmt.Sprintf("backup(%v, %d)", b.ChanID, b.CommitHeight) +} + +// WriteElements serializes a variadic list of elements into the given +// io.Writer. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + if err := WriteElement(w, element); err != nil { + return err + } + } + + return nil +} + +// ReadElements deserializes the provided io.Reader into a variadic list of +// target elements. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + if err := ReadElement(r, element); err != nil { + return err + } + } + + return nil +} + +// WriteElement serializes a single element into the provided io.Writer. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case BreachHint: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + default: + return fmt.Errorf("unexpected type") + } + + return nil +} + +// ReadElement deserializes a single element from the provided io.Reader. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *BreachHint: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + default: + return fmt.Errorf("unexpected type") + } + + return nil +} + +// CommittedUpdateBody represents the primary components of a CommittedUpdate. +// On disk, this is stored under the sequence number, which acts as its key. +type CommittedUpdateBody struct { + // BackupID identifies the breached commitment that the encrypted blob + // can spend from. + BackupID BackupID + + // Hint is the 16-byte prefix of the revoked commitment transaction ID. + Hint BreachHint + + // EncryptedBlob is a ciphertext containing the sweep information for + // exacting justice if the commitment transaction matching the breach + // hint is broadcast. + EncryptedBlob []byte +} + +// Encode writes the CommittedUpdateBody to the passed io.Writer. +func (u *CommittedUpdateBody) Encode(w io.Writer) error { + err := u.BackupID.Encode(w) + if err != nil { + return err + } + + return WriteElements(w, + u.Hint, + u.EncryptedBlob, + ) +} + +// Decode reads a CommittedUpdateBody from the passed io.Reader. +func (u *CommittedUpdateBody) Decode(r io.Reader) error { + err := u.BackupID.Decode(r) + if err != nil { + return err + } + + return ReadElements(r, + &u.Hint, + &u.EncryptedBlob, + ) +} + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// String returns a hex encoding of the session id. +func (s SessionID) String() string { + return hex.EncodeToString(s[:]) +} diff --git a/watchtower/wtdb/migration8/log.go b/watchtower/wtdb/migration8/log.go new file mode 100644 index 0000000000..ab35682c5a --- /dev/null +++ b/watchtower/wtdb/migration8/log.go @@ -0,0 +1,14 @@ +package migration8 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/migration8/migration.go b/watchtower/wtdb/migration8/migration.go new file mode 100644 index 0000000000..2e9d041e39 --- /dev/null +++ b/watchtower/wtdb/migration8/migration.go @@ -0,0 +1,223 @@ +package migration8 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionDBID -> db-assigned-id + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAckRangeIndex => db-chan-id => start -> end + // => cSessionRogueUpdateCount -> count + cSessionBkt = []byte("client-session-bucket") + + // cChanIDIndexBkt is a top-level bucket storing: + // db-assigned-id -> channel-ID + cChanIDIndexBkt = []byte("client-channel-id-index") + + // cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing + // chan-id => start -> end + cSessionAckRangeIndex = []byte("client-session-ack-range-index") + + // cSessionBody is a sub-bucket of cSessionBkt storing: + // seqnum -> encoded CommittedUpdate. + cSessionCommits = []byte("client-session-commits") + + // cChanDetailsBkt is a top-level bucket storing: + // channel-id => cChannelSummary -> encoded ClientChanSummary. + // => cChanDBID -> db-assigned-id + // => cChanSessions => db-session-id -> 1 + // => cChanClosedHeight -> block-height + // => cChanMaxCommitmentHeight -> commitment-height + cChanDetailsBkt = []byte("client-channel-detail-bucket") + + cChanMaxCommitmentHeight = []byte( + "client-channel-max-commitment-height", + ) + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + byteOrder = binary.BigEndian +) + +// MigrateChannelMaxHeights migrates the tower client db by collecting all the +// max commitment heights that have been backed up for each channel and then +// storing those heights alongside the channel info. +func MigrateChannelMaxHeights(tx kvdb.RwTx) error { + log.Infof("Migrating the tower client DB for quick channel max " + + "commitment height lookup") + + heights, err := collectChanMaxHeights(tx) + if err != nil { + return err + } + + return writeChanMaxHeights(tx, heights) +} + +// writeChanMaxHeights iterates over the given channel ID to height map and +// writes an entry under the cChanMaxCommitmentHeight key for each channel. +func writeChanMaxHeights(tx kvdb.RwTx, heights map[ChannelID]uint64) error { + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + for chanID, maxHeight := range heights { + chanDetails := chanDetailsBkt.NestedReadWriteBucket(chanID[:]) + + // If the details bucket for this channel ID does not exist, + // it is probably a channel that has been closed and deleted + // already. So we can skip this height. + if chanDetails == nil { + continue + } + + b, err := writeBigSize(maxHeight) + if err != nil { + return err + } + + err = chanDetails.Put(cChanMaxCommitmentHeight, b) + if err != nil { + return err + } + } + + return nil +} + +// collectChanMaxHeights iterates over all the sessions in the DB. For each +// session, it iterates over all the Acked updates and the committed updates +// to collect the maximum commitment height for each channel. +func collectChanMaxHeights(tx kvdb.RwTx) (map[ChannelID]uint64, error) { + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return nil, ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return nil, ErrUninitializedDB + } + + heights := make(map[ChannelID]uint64) + + // For each update we consider, we will only update the heights map if + // the commitment height for the channel is larger than the current + // max height stored for the channel. + cb := func(chanID ChannelID, commitHeight uint64) { + if commitHeight > heights[chanID] { + heights[chanID] = commitHeight + } + } + + err := sessionsBkt.ForEach(func(sessIDBytes, _ []byte) error { + sessBkt := sessionsBkt.NestedReadBucket(sessIDBytes) + if sessBkt == nil { + return fmt.Errorf("bucket not found for session %x", + sessIDBytes) + } + + err := forEachCommittedUpdate(sessBkt, cb) + if err != nil { + return err + } + + return forEachAckedUpdate(sessBkt, chanIDIndexBkt, cb) + }) + if err != nil { + return nil, err + } + + return heights, nil +} + +// forEachCommittedUpdate iterates over all the given session's committed +// updates and calls the call-back for each. +func forEachCommittedUpdate(sessBkt kvdb.RBucket, + cb func(chanID ChannelID, commitHeight uint64)) error { + + sessionCommits := sessBkt.NestedReadBucket(cSessionCommits) + if sessionCommits == nil { + return nil + } + + return sessionCommits.ForEach(func(k, v []byte) error { + var update CommittedUpdate + err := update.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + + cb(update.BackupID.ChanID, update.BackupID.CommitHeight) + + return nil + }) +} + +// forEachAckedUpdate iterates over all the given session's acked update range +// indices and calls the call-back for each. +func forEachAckedUpdate(sessBkt, chanIDIndexBkt kvdb.RBucket, + cb func(chanID ChannelID, commitHeight uint64)) error { + + sessionAcksRanges := sessBkt.NestedReadBucket(cSessionAckRangeIndex) + if sessionAcksRanges == nil { + return nil + } + + return sessionAcksRanges.ForEach(func(dbChanID, _ []byte) error { + rangeBkt := sessionAcksRanges.NestedReadBucket(dbChanID) + if rangeBkt == nil { + return nil + } + + index, err := readRangeIndex(rangeBkt) + if err != nil { + return err + } + + chanIDBytes := chanIDIndexBkt.Get(dbChanID) + var chanID ChannelID + copy(chanID[:], chanIDBytes) + + cb(chanID, index.MaxHeight()) + + return nil + }) +} + +// readRangeIndex reads a persisted RangeIndex from the passed bucket and into +// a new in-memory RangeIndex. +func readRangeIndex(rangesBkt kvdb.RBucket) (*RangeIndex, error) { + ranges := make(map[uint64]uint64) + err := rangesBkt.ForEach(func(k, v []byte) error { + start, err := readBigSize(k) + if err != nil { + return err + } + + end, err := readBigSize(v) + if err != nil { + return err + } + + ranges[start] = end + + return nil + }) + if err != nil { + return nil, err + } + + return NewRangeIndex(ranges, WithSerializeUint64Fn(writeBigSize)) +} diff --git a/watchtower/wtdb/migration8/migration_test.go b/watchtower/wtdb/migration8/migration_test.go new file mode 100644 index 0000000000..336069bfd6 --- /dev/null +++ b/watchtower/wtdb/migration8/migration_test.go @@ -0,0 +1,214 @@ +package migration8 + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/stretchr/testify/require" +) + +const ( + chan1ID = 10 + chan2ID = 20 + chan3ID = 30 + chan4ID = 40 + + chan1DBID = 111 + chan2DBID = 222 + chan3DBID = 333 +) + +var ( + // preDetails is the expected data of the channel details bucket before + // the migration. + preDetails = map[string]interface{}{ + channelIDString(chan1ID): map[string]interface{}{}, + channelIDString(chan2ID): map[string]interface{}{}, + channelIDString(chan3ID): map[string]interface{}{}, + } + + // channelIDIndex is the data in the channelID index that is used to + // find the mapping between the db-assigned channel ID and the real + // channel ID. + channelIDIndex = map[string]interface{}{ + uint64ToStr(chan1DBID): channelIDString(chan1ID), + uint64ToStr(chan2DBID): channelIDString(chan2ID), + uint64ToStr(chan3DBID): channelIDString(chan3ID), + } + + // postDetails is the expected data in the channel details bucket after + // the migration. + postDetails = map[string]interface{}{ + channelIDString(chan1ID): map[string]interface{}{ + string(cChanMaxCommitmentHeight): uint64ToStr(105), + }, + channelIDString(chan2ID): map[string]interface{}{ + string(cChanMaxCommitmentHeight): uint64ToStr(205), + }, + channelIDString(chan3ID): map[string]interface{}{ + string(cChanMaxCommitmentHeight): uint64ToStr(304), + }, + } +) + +// TestMigrateChannelToSessionIndex tests that the MigrateChannelToSessionIndex +// function correctly builds the new channel-to-sessionID index to the tower +// client DB. +func TestMigrateChannelToSessionIndex(t *testing.T) { + t.Parallel() + + update1 := &CommittedUpdate{ + SeqNum: 1, + CommittedUpdateBody: CommittedUpdateBody{ + BackupID: BackupID{ + ChanID: intToChannelID(chan1ID), + CommitHeight: 105, + }, + }, + } + var update1B bytes.Buffer + require.NoError(t, update1.Encode(&update1B)) + + update3 := &CommittedUpdate{ + SeqNum: 1, + CommittedUpdateBody: CommittedUpdateBody{ + BackupID: BackupID{ + ChanID: intToChannelID(chan3ID), + CommitHeight: 304, + }, + }, + } + var update3B bytes.Buffer + require.NoError(t, update3.Encode(&update3B)) + + update4 := &CommittedUpdate{ + SeqNum: 1, + CommittedUpdateBody: CommittedUpdateBody{ + BackupID: BackupID{ + ChanID: intToChannelID(chan4ID), + CommitHeight: 400, + }, + }, + } + var update4B bytes.Buffer + require.NoError(t, update4.Encode(&update4B)) + + // sessions is the expected data in the sessions bucket before and + // after the migration. + sessions := map[string]interface{}{ + // A session with both acked and committed updates. + sessionIDString("1"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + // This range index gives channel 1 a max height + // of 104. + uint64ToStr(chan1DBID): map[string]interface{}{ + uint64ToStr(100): uint64ToStr(101), + uint64ToStr(104): uint64ToStr(104), + }, + // This range index gives channel 2 a max height + // of 200. + uint64ToStr(chan2DBID): map[string]interface{}{ + uint64ToStr(200): uint64ToStr(200), + }, + }, + string(cSessionCommits): map[string]interface{}{ + // This committed update gives channel 1 a max + // height of 105 and so it overrides the heights + // from the range index. + uint64ToStr(1): update1B.String(), + }, + }, + // A session with only acked updates. + sessionIDString("2"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + // This range index gives channel 2 a max height + // of 205. + uint64ToStr(chan2DBID): map[string]interface{}{ + uint64ToStr(201): uint64ToStr(205), + }, + }, + }, + // A session with only committed updates. + sessionIDString("3"): map[string]interface{}{ + string(cSessionCommits): map[string]interface{}{ + // This committed update gives channel 3 a max + // height of 304. + uint64ToStr(1): update3B.String(), + }, + }, + // This session only contains heights for channel 4 which has + // been closed and so this should have no effect. + sessionIDString("4"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(444): map[string]interface{}{ + uint64ToStr(400): uint64ToStr(402), + uint64ToStr(403): uint64ToStr(405), + }, + }, + string(cSessionCommits): map[string]interface{}{ + uint64ToStr(1): update4B.String(), + }, + }, + // A session with no updates. + sessionIDString("5"): map[string]interface{}{}, + } + + // Before the migration we have a channel details + // bucket, a sessions bucket, a session ID index bucket + // and a channel ID index bucket. + before := func(tx kvdb.RwTx) error { + err := migtest.RestoreDB(tx, cChanDetailsBkt, preDetails) + if err != nil { + return err + } + + err = migtest.RestoreDB(tx, cSessionBkt, sessions) + if err != nil { + return err + } + + return migtest.RestoreDB(tx, cChanIDIndexBkt, channelIDIndex) + } + + after := func(tx kvdb.RwTx) error { + err := migtest.VerifyDB(tx, cSessionBkt, sessions) + if err != nil { + return err + } + + return migtest.VerifyDB(tx, cChanDetailsBkt, postDetails) + } + + migtest.ApplyMigration( + t, before, after, MigrateChannelMaxHeights, false, + ) +} + +func sessionIDString(id string) string { + var sessID SessionID + copy(sessID[:], id) + return sessID.String() +} + +func channelIDString(id uint64) string { + var chanID ChannelID + byteOrder.PutUint64(chanID[:], id) + return string(chanID[:]) +} + +func uint64ToStr(id uint64) string { + b, err := writeBigSize(id) + if err != nil { + panic(err) + } + + return string(b) +} + +func intToChannelID(id uint64) ChannelID { + var chanID ChannelID + byteOrder.PutUint64(chanID[:], id) + return chanID +} diff --git a/watchtower/wtdb/migration8/range_index.go b/watchtower/wtdb/migration8/range_index.go new file mode 100644 index 0000000000..94f0e20300 --- /dev/null +++ b/watchtower/wtdb/migration8/range_index.go @@ -0,0 +1,619 @@ +package migration8 + +import ( + "fmt" + "sync" +) + +// rangeItem represents the start and end values of a range. +type rangeItem struct { + start uint64 + end uint64 +} + +// RangeIndexOption describes the signature of a functional option that can be +// used to modify the behaviour of a RangeIndex. +type RangeIndexOption func(*RangeIndex) + +// WithSerializeUint64Fn is a functional option that can be used to set the +// function to be used to do the serialization of a uint64 into a byte slice. +func WithSerializeUint64Fn(fn func(uint64) ([]byte, error)) RangeIndexOption { + return func(index *RangeIndex) { + index.serializeUint64 = fn + } +} + +// RangeIndex can be used to keep track of which numbers have been added to a +// set. It does so by keeping track of a sorted list of rangeItems. Each +// rangeItem has a start and end value of a range where all values in-between +// have been added to the set. It works well in situations where it is expected +// numbers in the set are not sparse. +type RangeIndex struct { + // set is a sorted list of rangeItem. + set []rangeItem + + // mu is used to ensure safe access to set. + mu sync.Mutex + + // serializeUint64 is the function that can be used to convert a uint64 + // to a byte slice. + serializeUint64 func(uint64) ([]byte, error) +} + +// NewRangeIndex constructs a new RangeIndex. An initial set of ranges may be +// passed to the function in the form of a map. +func NewRangeIndex(ranges map[uint64]uint64, + opts ...RangeIndexOption) (*RangeIndex, error) { + + index := &RangeIndex{ + serializeUint64: defaultSerializeUint64, + set: make([]rangeItem, 0), + } + + // Apply any functional options. + for _, o := range opts { + o(index) + } + + for s, e := range ranges { + if err := index.addRange(s, e); err != nil { + return nil, err + } + } + + return index, nil +} + +// addRange can be used to add an entire new range to the set. This method +// should only ever be called by NewRangeIndex to initialise the in-memory +// structure and so the RangeIndex mutex is not held during this method. +func (a *RangeIndex) addRange(start, end uint64) error { + // Check that the given range is valid. + if start > end { + return fmt.Errorf("invalid range. Start height %d is larger "+ + "than end height %d", start, end) + } + + // min is a helper closure that will return the minimum of two uint64s. + min := func(a, b uint64) uint64 { + if a < b { + return a + } + + return b + } + + // max is a helper closure that will return the maximum of two uint64s. + max := func(a, b uint64) uint64 { + if a > b { + return a + } + + return b + } + + // Collect the ranges that fall before and after the new range along + // with the start and end values of the new range. + var before, after []rangeItem + for _, x := range a.set { + // If the new start value can't extend the current ranges end + // value, then the two cannot be merged. The range is added to + // the group of ranges that fall before the new range. + if x.end+1 < start { + before = append(before, x) + continue + } + + // If the current ranges start value does not follow on directly + // from the new end value, then the two cannot be merged. The + // range is added to the group of ranges that fall after the new + // range. + if end+1 < x.start { + after = append(after, x) + continue + } + + // Otherwise, there is an overlap and so the two can be merged. + start = min(start, x.start) + end = max(end, x.end) + } + + // Re-construct the range index set. + a.set = append(append(before, rangeItem{ + start: start, + end: end, + }), after...) + + return nil +} + +// IsInIndex returns true if the given number is in the range set. +func (a *RangeIndex) IsInIndex(n uint64) bool { + a.mu.Lock() + defer a.mu.Unlock() + + _, isCovered := a.lowerBoundIndex(n) + + return isCovered +} + +// NumInSet returns the number of items covered by the range set. +func (a *RangeIndex) NumInSet() uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + var numItems uint64 + for _, r := range a.set { + numItems += r.end - r.start + 1 + } + + return numItems +} + +// MaxHeight returns the highest number covered in the range. +func (a *RangeIndex) MaxHeight() uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + if len(a.set) == 0 { + return 0 + } + + return a.set[len(a.set)-1].end +} + +// GetAllRanges returns a copy of the range set in the form of a map. +func (a *RangeIndex) GetAllRanges() map[uint64]uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + cp := make(map[uint64]uint64, len(a.set)) + for _, item := range a.set { + cp[item.start] = item.end + } + + return cp +} + +// lowerBoundIndex returns the index of the RangeIndex that is most appropriate +// for the new value, n. In other words, it returns the index of the rangeItem +// set of the range where the start value is the highest start value in the set +// that is still lower than or equal to the given number, n. The returned +// boolean is true if the given number is already covered in the RangeIndex. +// A returned index of -1 indicates that no lower bound range exists in the set. +// Since the most likely case is that the new number will just extend the +// highest range, a check is first done to see if this is the case which will +// make the methods' computational complexity O(1). Otherwise, a binary search +// is done which brings the computational complexity to O(log N). +func (a *RangeIndex) lowerBoundIndex(n uint64) (int, bool) { + // If the set is empty, then there is no such index and the value + // definitely is not in the set. + if len(a.set) == 0 { + return -1, false + } + + // In most cases, the last index item will be the one we want. So just + // do a quick check on that index first to avoid doing the binary + // search. + lastIndex := len(a.set) - 1 + lastRange := a.set[lastIndex] + if lastRange.start <= n { + return lastIndex, lastRange.end >= n + } + + // Otherwise, do a binary search to find the index of interest. + var ( + low = 0 + high = len(a.set) - 1 + rangeIndex = -1 + ) + for { + mid := (low + high) / 2 + currentRange := a.set[mid] + + switch { + case currentRange.start > n: + // If the start of the range is greater than n, we can + // completely cut out that entire part of the array. + high = mid + + case currentRange.start < n: + // If the range already includes the given height, we + // can stop searching now. + if currentRange.end >= n { + return mid, true + } + + // If the start of the range is smaller than n, we can + // store this as the new best index to return. + rangeIndex = mid + + // If low and mid are already equal, then increment low + // by 1. Exit if this means that low is now greater than + // high. + if low == mid { + low = mid + 1 + if low > high { + return rangeIndex, false + } + } else { + low = mid + } + + continue + + default: + // If the height is equal to the start value of the + // current range that mid is pointing to, then the + // height is already covered. + return mid, true + } + + // Exit if we have checked all the ranges. + if low == high { + break + } + } + + return rangeIndex, false +} + +// KVStore is an interface representing a key-value store. +type KVStore interface { + // Put saves the specified key/value pair to the store. Keys that do not + // already exist are added and keys that already exist are overwritten. + Put(key, value []byte) error + + // Delete removes the specified key from the bucket. Deleting a key that + // does not exist does not return an error. + Delete(key []byte) error +} + +// Add adds a single number to the range set. It first attempts to apply the +// necessary changes to the passed KV store and then only if this succeeds, will +// the changes be applied to the in-memory structure. +func (a *RangeIndex) Add(newHeight uint64, kv KVStore) error { + a.mu.Lock() + defer a.mu.Unlock() + + // Compute the changes that will need to be applied to both the sorted + // rangeItem array representation and the key-value store representation + // of the range index. + arrayChanges, kvStoreChanges := a.getChanges(newHeight) + + // First attempt to apply the KV store changes. Only if this succeeds + // will we apply the changes to our in-memory range index structure. + err := a.applyKVChanges(kv, kvStoreChanges) + if err != nil { + return err + } + + // Since the DB changes were successful, we can now commit the + // changes to our in-memory representation of the range set. + a.applyArrayChanges(arrayChanges) + + return nil +} + +// applyKVChanges applies the given set of kvChanges to a KV store. It is +// assumed that a transaction is being held on the kv store so that if any +// of the actions of the function fails, the changes will be reverted. +func (a *RangeIndex) applyKVChanges(kv KVStore, changes *kvChanges) error { + // Exit early if there are no changes to apply. + if kv == nil || changes == nil { + return nil + } + + // Check if any range pair needs to be deleted. + if changes.deleteKVKey != nil { + del, err := a.serializeUint64(*changes.deleteKVKey) + if err != nil { + return err + } + + if err := kv.Delete(del); err != nil { + return err + } + } + + start, err := a.serializeUint64(changes.key) + if err != nil { + return err + } + + end, err := a.serializeUint64(changes.value) + if err != nil { + return err + } + + return kv.Put(start, end) +} + +// applyArrayChanges applies the given arrayChanges to the in-memory RangeIndex +// itself. This should only be done once the persisted kv store changes have +// already been applied. +func (a *RangeIndex) applyArrayChanges(changes *arrayChanges) { + if changes == nil { + return + } + + if changes.indexToDelete != nil { + a.set = append( + a.set[:*changes.indexToDelete], + a.set[*changes.indexToDelete+1:]..., + ) + } + + if changes.newIndex != nil { + switch { + case *changes.newIndex == 0: + a.set = append([]rangeItem{{ + start: changes.start, + end: changes.end, + }}, a.set...) + + case *changes.newIndex == len(a.set): + a.set = append(a.set, rangeItem{ + start: changes.start, + end: changes.end, + }) + + default: + a.set = append( + a.set[:*changes.newIndex+1], + a.set[*changes.newIndex:]..., + ) + a.set[*changes.newIndex] = rangeItem{ + start: changes.start, + end: changes.end, + } + } + + return + } + + if changes.indexToEdit != nil { + a.set[*changes.indexToEdit] = rangeItem{ + start: changes.start, + end: changes.end, + } + } +} + +// arrayChanges encompasses the diff to apply to the sorted rangeItem array +// representation of a range index. Such a diff will either include adding a +// new range or editing an existing range. If an existing range is edited, then +// the diff might also include deleting an index (this will be the case if the +// editing of the one range results in the merge of another range). +type arrayChanges struct { + start uint64 + end uint64 + + // newIndex, if set, is the index of the in-memory range array where a + // new range, [start:end], should be added. newIndex should never be + // set at the same time as indexToEdit or indexToDelete. + newIndex *int + + // indexToDelete, if set, is the index of the sorted rangeItem array + // that should be deleted. This should be applied before reading the + // index value of indexToEdit. This should not be set at the same time + // as newIndex. + indexToDelete *int + + // indexToEdit is the index of the in-memory range array that should be + // edited. The range at this index will be changed to [start:end]. This + // should only be read after indexToDelete index has been deleted. + indexToEdit *int +} + +// kvChanges encompasses the diff to apply to a KV-store representation of a +// range index. A kv-store diff for the addition of a single number to the range +// index will include either a brand new key-value pair or the altering of the +// value of an existing key. Optionally, the diff may also include the deletion +// of an existing key. A deletion will be required if the addition of the new +// number results in the merge of two ranges. +type kvChanges struct { + key uint64 + value uint64 + + // deleteKVKey, if set, is the key of the kv store representation that + // should be deleted. + deleteKVKey *uint64 +} + +// getChanges will calculate and return the changes that need to be applied to +// both the sorted-rangeItem-array representation and the key-value store +// representation of the range index. +func (a *RangeIndex) getChanges(n uint64) (*arrayChanges, *kvChanges) { + // If the set is empty then a new range item is added. + if len(a.set) == 0 { + // For the array representation, a new range [n:n] is added to + // the first index of the array. + firstIndex := 0 + ac := &arrayChanges{ + newIndex: &firstIndex, + start: n, + end: n, + } + + // For the KV representation, a new [n:n] pair is added. + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + } + + // Find the index of the lower bound range to the new number. + indexOfRangeBelow, alreadyCovered := a.lowerBoundIndex(n) + + switch { + // The new number is already covered by the range index. No changes are + // required. + case alreadyCovered: + return nil, nil + + // No lower bound index exists. + case indexOfRangeBelow < 0: + // Check if the very first range can be merged into this new + // one. + if n+1 == a.set[0].start { + // If so, the two ranges can be merged and so the start + // value of the range is n and the end value is the end + // of the existing first range. + start := n + end := a.set[0].end + + // For the array representation, we can just edit the + // first entry of the array + editIndex := 0 + ac := &arrayChanges{ + indexToEdit: &editIndex, + start: start, + end: end, + } + + // For the KV store representation, we add a new kv pair + // and delete the range with the key equal to the start + // value of the range we are merging. + kvKeyToDelete := a.set[0].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &kvKeyToDelete, + } + + return ac, kvc + } + + // Otherwise, we add a new index. + + // For the array representation, a new range [n:n] is added to + // the first index of the array. + newIndex := 0 + ac := &arrayChanges{ + newIndex: &newIndex, + start: n, + end: n, + } + + // For the KV representation, a new [n:n] pair is added. + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + + // A lower range does exist, and it can be extended to include this new + // number. + case a.set[indexOfRangeBelow].end+1 == n: + start := a.set[indexOfRangeBelow].start + end := n + indexToChange := indexOfRangeBelow + + // If there are no intervals above this one or if there are, but + // they can't be merged into this one then we just need to edit + // this interval. + if indexOfRangeBelow == len(a.set)-1 || + a.set[indexOfRangeBelow+1].start != n+1 { + + // For the array representation, we just edit the index. + ac := &arrayChanges{ + indexToEdit: &indexToChange, + start: start, + end: end, + } + + // For the key-value representation, we just overwrite + // the end value at the existing start key. + kvc := &kvChanges{ + key: start, + value: end, + } + + return ac, kvc + } + + // There is a range above this one that we need to merge into + // this one. + delIndex := indexOfRangeBelow + 1 + end = a.set[delIndex].end + + // For the array representation, we delete the range above this + // one and edit this range to include the end value of the range + // above. + ac := &arrayChanges{ + indexToDelete: &delIndex, + indexToEdit: &indexToChange, + start: start, + end: end, + } + + // For the kv representation, we tweak the end value of an + // existing key and delete the key of the range we are deleting. + deleteKey := a.set[delIndex].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &deleteKey, + } + + return ac, kvc + + // A lower range does exist, but it can't be extended to include this + // new number, and so we need to add a new range after the lower bound + // range. + default: + newIndex := indexOfRangeBelow + 1 + + // If there are no ranges above this new one or if there are, + // but they can't be merged into this new one, then we can just + // add the new one as is. + if newIndex == len(a.set) || a.set[newIndex].start != n+1 { + ac := &arrayChanges{ + newIndex: &newIndex, + start: n, + end: n, + } + + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + } + + // Else, we merge the above index. + start := n + end := a.set[newIndex].end + toEdit := newIndex + + // For the array representation, we edit the range above to + // include the new start value. + ac := &arrayChanges{ + indexToEdit: &toEdit, + start: start, + end: end, + } + + // For the kv representation, we insert the new start-end key + // value pair and delete the key using the old start value. + delKey := a.set[newIndex].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &delKey, + } + + return ac, kvc + } +} + +func defaultSerializeUint64(i uint64) ([]byte, error) { + var b [8]byte + byteOrder.PutUint64(b[:], i) + return b[:], nil +} diff --git a/watchtower/wtdb/queue.go b/watchtower/wtdb/queue.go index 372765e7cd..674743a1e7 100644 --- a/watchtower/wtdb/queue.go +++ b/watchtower/wtdb/queue.go @@ -80,6 +80,7 @@ type DiskQueueDB[T Serializable] struct { db kvdb.Backend topLevelBkt []byte constructor func() T + onItemWrite func(tx kvdb.RwTx, item T) error } // A compile-time check to ensure that DiskQueueDB implements the Queue @@ -89,12 +90,14 @@ var _ Queue[Serializable] = (*DiskQueueDB[Serializable])(nil) // NewQueueDB constructs a new DiskQueueDB. A queueBktName must be provided so // that the DiskQueueDB can create its own namespace in the bolt db. func NewQueueDB[T Serializable](db kvdb.Backend, queueBktName []byte, - constructor func() T) Queue[T] { + constructor func() T, + onItemWrite func(tx kvdb.RwTx, item T) error) Queue[T] { return &DiskQueueDB[T]{ db: db, topLevelBkt: queueBktName, constructor: constructor, + onItemWrite: onItemWrite, } } @@ -279,6 +282,13 @@ func (d *DiskQueueDB[T]) addItem(tx kvdb.RwTx, queueName []byte, item T) error { return err } + if d.onItemWrite != nil { + err = d.onItemWrite(tx, item) + if err != nil { + return err + } + } + // Find the index to use for placing this new item at the back of the // queue. var nextIndex uint64 diff --git a/watchtower/wtdb/queue_test.go b/watchtower/wtdb/queue_test.go index 02c7b272cb..ff2c5a0da8 100644 --- a/watchtower/wtdb/queue_test.go +++ b/watchtower/wtdb/queue_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/stretchr/testify/require" ) @@ -31,6 +32,14 @@ func TestDiskQueue(t *testing.T) { require.NoError(t, err) }) + // In order to test that the queue's `onItemWrite` call back (which in + // this case will be set to maybeUpdateMaxCommitHeight) is executed as + // expected, we need to register a channel so that we can later assert + // that it's max height field was updated properly. + var chanID lnwire.ChannelID + err = db.RegisterChannel(chanID, []byte{}) + require.NoError(t, err) + namespace := []byte("test-namespace") queue := db.GetDBQueue(namespace) @@ -110,4 +119,19 @@ func TestDiskQueue(t *testing.T) { // This should not have changed the order of the tasks, they should // still appear in the correct order. popAndAssert(task1, task2, task3, task4, task5, task6) + + // Finally, we check that the `onItemWrite` call back was executed by + // the queue. We do this by checking that the channel's recorded max + // commitment height was set correctly. It should be equal to the height + // recorded in task6. + infos, err := db.FetchChanInfos() + require.NoError(t, err) + require.Len(t, infos, 1) + + info, ok := infos[chanID] + require.True(t, ok) + require.True(t, info.MaxHeight.IsSome()) + info.MaxHeight.WhenSome(func(height uint64) { + require.EqualValues(t, task6.CommitHeight, height) + }) } diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index b44ed80eb3..dd9c554723 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration7" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration8" ) // txMigration is a function which takes a prior outdated version of the @@ -67,6 +68,9 @@ var clientDBVersions = []version{ { txMigration: migration7.MigrateChannelToSessionIndex, }, + { + txMigration: migration8.MigrateChannelMaxHeights, + }, } // getLatestDBVersion returns the last known database version.