From 41d1fe8fa0fd380b284130e628c821bb45f55b26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20Zdyba=C5=82?= Date: Tue, 31 Jan 2023 11:04:17 +0100 Subject: [PATCH] fix(libs/header/sync): Make ranges.Add thread-safe (#1649) --- libs/header/sync/ranges.go | 12 ++++++++---- libs/header/sync/ranges_test.go | 34 +++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 libs/header/sync/ranges_test.go diff --git a/libs/header/sync/ranges.go b/libs/header/sync/ranges.go index 99047c4093..8a05b728c3 100644 --- a/libs/header/sync/ranges.go +++ b/libs/header/sync/ranges.go @@ -19,6 +19,10 @@ func (rs *ranges[H]) Head() H { rs.lk.RLock() defer rs.lk.RUnlock() + return rs.head() +} + +func (rs *ranges[H]) head() H { ln := len(rs.ranges) if ln == 0 { var zero H @@ -32,7 +36,10 @@ func (rs *ranges[H]) Head() H { // Add appends the new Header to existing range or starts a new one. // It starts a new one if the new Header is not adjacent to any of existing ranges. func (rs *ranges[H]) Add(h H) { - head := rs.Head() + rs.lk.Lock() + defer rs.lk.Unlock() + + head := rs.head() // short-circuit if header is from the past if !head.IsZero() && head.Height() >= h.Height() { @@ -47,9 +54,6 @@ func (rs *ranges[H]) Add(h H) { return } - rs.lk.Lock() - defer rs.lk.Unlock() - // if the new header is adjacent to head if !head.IsZero() && h.Height() == head.Height()+1 { // append it to the last known range diff --git a/libs/header/sync/ranges_test.go b/libs/header/sync/ranges_test.go new file mode 100644 index 0000000000..fea81378ac --- /dev/null +++ b/libs/header/sync/ranges_test.go @@ -0,0 +1,34 @@ +package sync + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/celestiaorg/celestia-node/libs/header/test" +) + +func TestAddParallel(t *testing.T) { + var pending ranges[*test.DummyHeader] + + n := 500 + suite := test.NewTestSuite(t) + headers := suite.GenDummyHeaders(n) + + wg := &sync.WaitGroup{} + wg.Add(n) + for i := 0; i < n; i++ { + go func(i int) { + pending.Add(headers[i]) + wg.Done() + }(i) + } + wg.Wait() + + last := uint64(0) + for _, r := range pending.ranges { + assert.Greater(t, r.start, last) + last = r.start + } +}