diff --git a/blacklist.go b/blacklist.go new file mode 100644 index 00000000..84f2efb5 --- /dev/null +++ b/blacklist.go @@ -0,0 +1,53 @@ +package pubsub + +import ( + lru "github.com/hashicorp/golang-lru" + peer "github.com/libp2p/go-libp2p-peer" +) + +// Blacklist is an interface for peer blacklisting. +type Blacklist interface { + Add(peer.ID) + Contains(peer.ID) bool +} + +// MapBlacklist is a blacklist implementation using a perfect map +type MapBlacklist map[peer.ID]struct{} + +// NewMapBlacklist creates a new MapBlacklist +func NewMapBlacklist() Blacklist { + return MapBlacklist(make(map[peer.ID]struct{})) +} + +func (b MapBlacklist) Add(p peer.ID) { + b[p] = struct{}{} +} + +func (b MapBlacklist) Contains(p peer.ID) bool { + _, ok := b[p] + return ok +} + +// LRUBlacklist is a blacklist implementation using an LRU cache +type LRUBlacklist struct { + lru *lru.Cache +} + +// NewLRUBlacklist creates a new LRUBlacklist with capacity cap +func NewLRUBlacklist(cap int) (Blacklist, error) { + c, err := lru.New(cap) + if err != nil { + return nil, err + } + + b := &LRUBlacklist{lru: c} + return b, nil +} + +func (b LRUBlacklist) Add(p peer.ID) { + b.lru.Add(p, nil) +} + +func (b LRUBlacklist) Contains(p peer.ID) bool { + return b.lru.Contains(p) +} diff --git a/blacklist_test.go b/blacklist_test.go new file mode 100644 index 00000000..514d9fa0 --- /dev/null +++ b/blacklist_test.go @@ -0,0 +1,126 @@ +package pubsub + +import ( + "context" + "testing" + "time" + + peer "github.com/libp2p/go-libp2p-peer" +) + +func TestMapBlacklist(t *testing.T) { + b := NewMapBlacklist() + + p := peer.ID("test") + + b.Add(p) + if !b.Contains(p) { + t.Fatal("peer not in the blacklist") + } + +} + +func TestLRUBlacklist(t *testing.T) { + b, err := NewLRUBlacklist(10) + if err != nil { + t.Fatal(err) + } + + p := peer.ID("test") + + b.Add(p) + if !b.Contains(p) { + t.Fatal("peer not in the blacklist") + } + +} + +func TestBlacklist(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + connect(t, hosts[0], hosts[1]) + + sub, err := psubs[1].Subscribe("test") + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 100) + psubs[1].BlacklistPeer(hosts[0].ID()) + time.Sleep(time.Millisecond * 100) + + psubs[0].Publish("test", []byte("message")) + + wctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + _, err = sub.Next(wctx) + + if err == nil { + t.Fatal("got message from blacklisted peer") + } +} + +func TestBlacklist2(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + connect(t, hosts[0], hosts[1]) + + _, err := psubs[0].Subscribe("test") + if err != nil { + t.Fatal(err) + } + + sub1, err := psubs[1].Subscribe("test") + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 100) + psubs[1].BlacklistPeer(hosts[0].ID()) + time.Sleep(time.Millisecond * 100) + + psubs[0].Publish("test", []byte("message")) + + wctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + _, err = sub1.Next(wctx) + + if err == nil { + t.Fatal("got message from blacklisted peer") + } +} + +func TestBlacklist3(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + psubs[1].BlacklistPeer(hosts[0].ID()) + time.Sleep(time.Millisecond * 100) + connect(t, hosts[0], hosts[1]) + + sub, err := psubs[1].Subscribe("test") + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 100) + + psubs[0].Publish("test", []byte("message")) + + wctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + _, err = sub.Next(wctx) + + if err == nil { + t.Fatal("got message from blacklisted peer") + } +} diff --git a/package.json b/package.json index dcd396c1..92678ea4 100644 --- a/package.json +++ b/package.json @@ -77,6 +77,12 @@ "hash": "QmabLh8TrJ3emfAoQk5AbqbLTbMyj7XqumMFmAFxa9epo8", "name": "go-multistream", "version": "0.3.9" + }, + { + "author": "hashicorp", + "hash": "QmQjMHF8ptRgx4E57UFMiT4YM6kqaJeYxZ1MCDX23aw4rK", + "name": "golang-lru", + "version": "2017.10.18" } ], "gxVersion": "0.9.0", @@ -86,3 +92,4 @@ "releaseCmd": "git commit -a -m \"gx publish $VERSION\"", "version": "0.11.10" } + diff --git a/pubsub.go b/pubsub.go index 11824e8b..3c5cebe2 100644 --- a/pubsub.go +++ b/pubsub.go @@ -98,6 +98,10 @@ type PubSub struct { // eval thunk in event loop eval chan func() + // peer blacklist + blacklist Blacklist + blacklistPeer chan peer.ID + peers map[peer.ID]chan *RPC seenMessages *timecache.TimeCache @@ -179,6 +183,8 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option topics: make(map[string]map[peer.ID]struct{}), peers: make(map[peer.ID]chan *RPC), topicVals: make(map[string]*topicVal), + blacklist: NewMapBlacklist(), + blacklistPeer: make(chan peer.ID), seenMessages: timecache.NewTimeCache(TimeCacheDuration), counter: uint64(time.Now().UnixNano()), } @@ -262,6 +268,15 @@ func WithStrictSignatureVerification(required bool) Option { } } +// WithBlacklist provides an implementation of the blacklist; the default is a +// MapBlacklist +func WithBlacklist(b Blacklist) Option { + return func(p *PubSub) error { + p.blacklist = b + return nil + } +} + // processLoop handles all inputs arriving on the channels func (p *PubSub) processLoop(ctx context.Context) { defer func() { @@ -276,12 +291,16 @@ func (p *PubSub) processLoop(ctx context.Context) { for { select { case pid := <-p.newPeers: - _, ok := p.peers[pid] - if ok { + if p.blacklist.Contains(pid) { log.Warning("already have connection to peer: ", pid) continue } + if p.blacklist.Contains(pid) { + log.Warning("ignoring connection from blacklisted peer: ", pid) + continue + } + messages := make(chan *RPC, 32) messages <- p.getHelloPacket() go p.handleNewPeer(ctx, pid, messages) @@ -290,13 +309,20 @@ func (p *PubSub) processLoop(ctx context.Context) { case s := <-p.newPeerStream: pid := s.Conn().RemotePeer() - _, ok := p.peers[pid] + ch, ok := p.peers[pid] if !ok { log.Warning("new stream for unknown peer: ", pid) s.Reset() continue } + if p.blacklist.Contains(pid) { + log.Warning("closing stream for blacklisted peer: ", pid) + close(ch) + s.Reset() + continue + } + p.rt.AddPeer(pid, s.Protocol()) case pid := <-p.newPeerError: @@ -374,6 +400,20 @@ func (p *PubSub) processLoop(ctx context.Context) { case thunk := <-p.eval: thunk() + case pid := <-p.blacklistPeer: + log.Infof("Blacklisting peer %s", pid) + p.blacklist.Add(pid) + + ch, ok := p.peers[pid] + if ok { + close(ch) + delete(p.peers, pid) + for _, t := range p.topics { + delete(t, pid) + } + p.rt.RemovePeer(pid) + } + case <-ctx.Done(): log.Info("pubsub processloop shutting down") return @@ -567,6 +607,18 @@ func msgID(pmsg *pb.Message) string { // pushMsg pushes a message performing validation as necessary func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { + // reject messages from blacklisted peers + if p.blacklist.Contains(src) { + log.Warningf("dropping message from blacklisted peer %s", src) + return + } + + // even if they are forwarded by good peers + if p.blacklist.Contains(msg.GetFrom()) { + log.Warningf("dropping message from blacklisted source %s", src) + return + } + // reject unsigned messages when strict before we even process the id if p.signStrict && msg.Signature == nil { log.Debugf("dropping unsigned message from %s", src) @@ -821,6 +873,11 @@ func (p *PubSub) ListPeers(topic string) []peer.ID { return <-out } +// BlacklistPeer blacklists a peer; all messages from this peer will be unconditionally dropped. +func (p *PubSub) BlacklistPeer(pid peer.ID) { + p.blacklistPeer <- pid +} + // per topic validators type addValReq struct { topic string