Skip to content

Commit

Permalink
Merge pull request #14560 from openmina/fixes/libp2p_helper_races
Browse files Browse the repository at this point in the history
Protect shared data accesses in `libp2p_helper`
  • Loading branch information
georgeee authored Nov 16, 2023
2 parents 7b91714 + 586f43e commit e64d3d5
Show file tree
Hide file tree
Showing 12 changed files with 239 additions and 135 deletions.
12 changes: 10 additions & 2 deletions src/app/libp2p_helper/src/codanet.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ type Helper struct {
ConnectionManager *CodaConnectionManager
BandwidthCounter *metrics.BandwidthCounter
MsgStats *MessageStats
Seeds []peer.AddrInfo
_seeds []peer.AddrInfo
seedsMutex sync.RWMutex
NodeStatus []byte
HeartbeatPeer func(peer.ID)
}
Expand Down Expand Up @@ -368,6 +369,13 @@ func (h *Helper) SetGatingState(gs *CodaGatingConfig) {
}
}

func (h *Helper) AddSeeds(infos ...peer.AddrInfo) {
// TODO: this "_seeds" field is never read anywhere, is it needed?
h.seedsMutex.Lock()
h._seeds = append(h._seeds, infos...)
h.seedsMutex.Unlock()
}

func (gs *CodaGatingState) TrustPeer(p peer.ID) {
gs.trustedPeersMutex.Lock()
gs.trustedPeers[p] = struct{}{}
Expand Down Expand Up @@ -751,7 +759,7 @@ func MakeHelper(ctx context.Context, listenOn []ma.Multiaddr, externalAddr ma.Mu
ConnectionManager: connManager,
BandwidthCounter: bandwidthCounter,
MsgStats: &MessageStats{min: math.MaxUint64},
Seeds: seeds,
_seeds: seeds,
HeartbeatPeer: func(p peer.ID) {
lanPatcher.Heartbeat(p)
wanPatcher.Heartbeat(p)
Expand Down
158 changes: 151 additions & 7 deletions src/app/libp2p_helper/src/libp2p_helper/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"math"
"os"
"strconv"
"sync"
"time"

ipc "libp2p_ipc"
Expand All @@ -29,14 +28,13 @@ func newApp() *app {
return &app{
P2p: nil,
Ctx: ctx,
Subs: make(map[uint64]subscription),
Topics: make(map[string]*pubsub.Topic),
ValidatorMutex: &sync.Mutex{},
Validators: make(map[uint64]*validationStatus),
Streams: make(map[uint64]net.Stream),
_subs: make(map[uint64]subscription),
_topics: make(map[string]*pubsub.Topic),
_validators: make(map[uint64]*validationStatus),
_streams: make(map[uint64]net.Stream),
OutChan: outChan,
Out: bufio.NewWriter(os.Stdout),
AddedPeers: []peer.AddrInfo{},
_addedPeers: []peer.AddrInfo{},
MetricsRefreshTime: time.Minute,
metricsCollectionStarted: false,
metricsServer: nil,
Expand Down Expand Up @@ -64,6 +62,151 @@ func (app *app) NextId() uint64 {
return app.counter
}

func (app *app) AddPeers(infos ...peer.AddrInfo) {
app.addedPeersMutex.Lock()
defer app.addedPeersMutex.Unlock()
app._addedPeers = append(app._addedPeers, infos...)
}

func (app *app) GetAddedPeers() []peer.AddrInfo {
app.addedPeersMutex.RLock()
defer app.addedPeersMutex.RUnlock()
copyOfAddedPeers := make([]peer.AddrInfo, len(app._addedPeers))
copy(copyOfAddedPeers, app._addedPeers)
return copyOfAddedPeers
}

func (app *app) ResetAddedPeers() {
app.addedPeersMutex.Lock()
defer app.addedPeersMutex.Unlock()
app._addedPeers = nil
}

func (app *app) AddStream(stream net.Stream) uint64 {
streamIdx := app.NextId()
app.streamsMutex.Lock()
defer app.streamsMutex.Unlock()
app._streams[streamIdx] = stream
return streamIdx
}

func (app *app) CloseStream(streamId uint64) error {
app.streamsMutex.Lock()
defer app.streamsMutex.Unlock()
if stream, ok := app._streams[streamId]; ok {
delete(app._streams, streamId)
err := stream.Close()
if err != nil {
return badp2p(err)
}
return nil
}
return badRPC(errors.New("unknown stream_idx"))
}

func (app *app) ResetStream(streamId uint64) error {
app.streamsMutex.Lock()
defer app.streamsMutex.Unlock()
if stream, ok := app._streams[streamId]; ok {
delete(app._streams, streamId)
err := stream.Reset()
if err != nil {
return badp2p(err)
}
return nil
}
return badRPC(errors.New("unknown stream_idx"))
}

func (app *app) StreamWrite(streamId uint64, data []byte) error {
// TODO Consider using a more fine-grained locking strategy,
// not using a global mutex to lock on a message sending
app.streamsMutex.Lock()
defer app.streamsMutex.Unlock()
if stream, ok := app._streams[streamId]; ok {
n, err := stream.Write(data)
if err != nil {
// TODO check that it's correct to error out, not repeat writing
delete(app._streams, streamId)
close_err := stream.Close()
if close_err != nil {
app.P2p.Logger.Errorf("failed to close stream %d after encountering write failure (%s): %s", streamId, err.Error(), close_err.Error())
}
return wrapError(badp2p(err), fmt.Sprintf("only wrote %d out of %d bytes", n, len(data)))
}
return nil
}
return badRPC(errors.New("unknown stream_idx"))
}

func (app *app) AddValidator() (uint64, chan pubsub.ValidationResult) {
seqno := app.NextId()
ch := make(chan pubsub.ValidationResult)
app.validatorMutex.Lock()
defer app.validatorMutex.Unlock()
app._validators[seqno] = new(validationStatus)
app._validators[seqno].Completion = ch
return seqno, ch
}

func (app *app) RemoveValidator(seqno uint64) {
app.validatorMutex.Lock()
defer app.validatorMutex.Unlock()
delete(app._validators, seqno)
}

func (app *app) TimeoutValidator(seqno uint64) {
now := time.Now()
app.validatorMutex.Lock()
defer app.validatorMutex.Unlock()
app._validators[seqno].TimedOutAt = &now
}

func (app *app) FinishValidator(seqno uint64, finish func(st *validationStatus)) bool {
app.validatorMutex.Lock()
defer app.validatorMutex.Unlock()
if st, ok := app._validators[seqno]; ok {
finish(st)
delete(app._validators, seqno)
return true
} else {
return false
}
}

func (app *app) AddTopic(topicName string, topic *pubsub.Topic) {
app.topicsMutex.Lock()
defer app.topicsMutex.Unlock()
app._topics[topicName] = topic
}

func (app *app) GetTopic(topicName string) (*pubsub.Topic, bool) {
app.topicsMutex.RLock()
defer app.topicsMutex.RUnlock()
topic, has := app._topics[topicName]
return topic, has
}

func (app *app) AddSubscription(subId uint64, sub subscription) {
app.subsMutex.Lock()
defer app.subsMutex.Unlock()
app._subs[subId] = sub
}

func (app *app) CancelSubscription(subId uint64) bool {
app.subsMutex.Lock()
defer app.subsMutex.Unlock()

if sub, ok := app._subs[subId]; ok {
sub.Sub.Cancel()
sub.Cancel()
delete(app._subs, subId)
return true
}

return false
}

func parseMultiaddrWithID(ma multiaddr.Multiaddr, id peer.ID) (*codaPeerInfo, error) {
ipComponent, tcpMaddr := multiaddr.SplitFirst(ma)
if !(ipComponent.Protocol().Code == multiaddr.P_IP4 || ipComponent.Protocol().Code == multiaddr.P_IP6) {
Expand Down Expand Up @@ -96,6 +239,7 @@ func addrInfoOfString(maddr string) (*peer.AddrInfo, error) {
return info, nil
}

// Writes a message back to the OCaml node
func (app *app) writeMsg(msg *capnp.Message) {
if app.NoUpcalls {
return
Expand Down
12 changes: 6 additions & 6 deletions src/app/libp2p_helper/src/libp2p_helper/config_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (msg BeginAdvertisingReq) handle(app *app, seqno uint64) *capnp.Message {
return mkRpcRespError(seqno, needsConfigure())
}
app.SetConnectionHandlers()
for _, info := range app.AddedPeers {
for _, info := range app.GetAddedPeers() {
app.P2p.Logger.Debug("Trying to connect to: ", info)
err := app.P2p.Host.Connect(app.Ctx, info)
if err != nil {
Expand Down Expand Up @@ -334,7 +334,7 @@ func (msg ConfigureReq) handle(app *app, seqno uint64) *capnp.Message {
return mkRpcRespError(seqno, badRPC(err))
}

app.AddedPeers = append(app.AddedPeers, seeds...)
app.AddPeers(seeds...)

directPeersMaList, err := m.DirectPeers()
if err != nil {
Expand Down Expand Up @@ -372,12 +372,12 @@ func (msg ConfigureReq) handle(app *app, seqno uint64) *capnp.Message {
if err != nil {
return mkRpcRespError(seqno, badRPC(err))
}
gatingConfig, err := readGatingConfig(gc, app.AddedPeers)
gatingConfig, err := readGatingConfig(gc, app.GetAddedPeers())
if err != nil {
return mkRpcRespError(seqno, badRPC(err))
}
if gc.CleanAddedPeers() {
app.AddedPeers = nil
app.ResetAddedPeers()
}

stateDir, err := m.Statedir()
Expand Down Expand Up @@ -593,13 +593,13 @@ func (m SetGatingConfigReq) handle(app *app, seqno uint64) *capnp.Message {
var gatingConfig *codanet.CodaGatingConfig
gc, err := SetGatingConfigReqT(m).GatingConfig()
if err == nil {
gatingConfig, err = readGatingConfig(gc, app.AddedPeers)
gatingConfig, err = readGatingConfig(gc, app.GetAddedPeers())
}
if err != nil {
return mkRpcRespError(seqno, badRPC(err))
}
if gc.CleanAddedPeers() {
app.AddedPeers = nil
app.ResetAddedPeers()
}
app.P2p.SetGatingState(gatingConfig)

Expand Down
2 changes: 1 addition & 1 deletion src/app/libp2p_helper/src/libp2p_helper/config_msg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestDHTDiscovery_TwoNodes(t *testing.T) {
require.NoError(t, err)

appB, _ := newTestApp(t, appAInfos, true)
appB.AddedPeers = appAInfos
appB.AddPeers(appAInfos...)
appB.NoMDNS = true

// begin appB and appA's DHT advertising
Expand Down
19 changes: 10 additions & 9 deletions src/app/libp2p_helper/src/libp2p_helper/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,19 @@ import (
type app struct {
P2p *codanet.Helper
Ctx context.Context
Subs map[uint64]subscription
Topics map[string]*pubsub.Topic
Validators map[uint64]*validationStatus
ValidatorMutex *sync.Mutex
Streams map[uint64]net.Stream
StreamsMutex sync.Mutex
_subs map[uint64]subscription
subsMutex sync.Mutex
_topics map[string]*pubsub.Topic
topicsMutex sync.RWMutex
_validators map[uint64]*validationStatus
validatorMutex sync.Mutex
_streams map[uint64]net.Stream
streamsMutex sync.Mutex
Out *bufio.Writer
OutChan chan *capnp.Message
Bootstrapper io.Closer
AddedPeers []peer.AddrInfo
addedPeersMutex sync.RWMutex
_addedPeers []peer.AddrInfo
UnsafeNoTrustIP bool
MetricsRefreshTime time.Duration
metricsCollectionStarted bool
Expand All @@ -54,8 +57,6 @@ type app struct {

type subscription struct {
Sub *pubsub.Subscription
Idx uint64
Ctx context.Context
Cancel context.CancelFunc
}

Expand Down
1 change: 1 addition & 0 deletions src/app/libp2p_helper/src/libp2p_helper/incoming_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ var pushMesssageExtractors = map[ipc.Libp2pHelperInterface_PushMessage_Which]ext
ipc.Libp2pHelperInterface_PushMessage_Which_heartbeatPeer: fromHeartbeatPeerPush,
}

// Handles messages coming from the OCaml process
func (app *app) handleIncomingMsg(msg *ipc.Libp2pHelperInterface_Message) {
if msg.HasRpcRequest() {
resp, err := func() (*capnp.Message, error) {
Expand Down
4 changes: 2 additions & 2 deletions src/app/libp2p_helper/src/libp2p_helper/peer_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (m AddPeerReq) handle(app *app, seqno uint64) *capnp.Message {
return mkRpcRespError(seqno, badRPC(err))
}

app.AddedPeers = append(app.AddedPeers, *info)
app.AddPeers(*info)
app.P2p.GatingState().TrustPeer(info.ID)

if app.Bootstrapper != nil {
Expand All @@ -50,7 +50,7 @@ func (m AddPeerReq) handle(app *app, seqno uint64) *capnp.Message {
app.P2p.Logger.Info("addPeer Trying to connect to: ", info)

if AddPeerReqT(m).IsSeed() {
app.P2p.Seeds = append(app.P2p.Seeds, *info)
app.P2p.AddSeeds(*info)
}

err = app.P2p.Host.Connect(app.Ctx, *info)
Expand Down
Loading

0 comments on commit e64d3d5

Please sign in to comment.