diff --git a/pubsub.go b/pubsub.go index 8dc9219f..25bfa5e5 100644 --- a/pubsub.go +++ b/pubsub.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "fmt" "math/rand" - "runtime" "sync" "sync/atomic" "time" @@ -21,11 +20,6 @@ import ( timecache "github.com/whyrusleeping/timecache" ) -const ( - defaultValidateConcurrency = 1024 - defaultValidateThrottle = 8192 -) - var ( TimeCacheDuration = 120 * time.Second ) @@ -45,6 +39,8 @@ type PubSub struct { rt PubSubRouter + val *validation + // incoming messages from other peers incoming chan *RPC @@ -90,18 +86,6 @@ type PubSub struct { // rmVal handles validator unregistration requests rmVal chan *rmValReq - // topicVals tracks per topic validators - topicVals map[string]*topicVal - - // validateQ is the front-end to the validation pipeline - validateQ chan *validateReq - - // validateThrottle limits the number of active validation goroutines - validateThrottle chan struct{} - - // this is the number of synchronous validation workers - validateWorkers int - // eval thunk in event loop eval chan func() @@ -168,36 +152,33 @@ type Option func(*PubSub) error // NewPubSub returns a new PubSub management object. func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option) (*PubSub, error) { ps := &PubSub{ - host: h, - ctx: ctx, - rt: rt, - signID: h.ID(), - signKey: h.Peerstore().PrivKey(h.ID()), - incoming: make(chan *RPC, 32), - publish: make(chan *Message), - newPeers: make(chan peer.ID), - newPeerStream: make(chan inet.Stream), - newPeerError: make(chan peer.ID), - peerDead: make(chan peer.ID), - cancelCh: make(chan *Subscription), - getPeers: make(chan *listPeerReq), - addSub: make(chan *addSubReq), - getTopics: make(chan *topicReq), - sendMsg: make(chan *sendReq, 32), - addVal: make(chan *addValReq), - rmVal: make(chan *rmValReq), - validateThrottle: make(chan struct{}, defaultValidateThrottle), - eval: make(chan func()), - myTopics: make(map[string]map[*Subscription]struct{}), - topics: make(map[string]map[peer.ID]struct{}), - peers: make(map[peer.ID]chan *RPC), - topicVals: make(map[string]*topicVal), - validateQ: make(chan *validateReq, 32), - blacklist: NewMapBlacklist(), - blacklistPeer: make(chan peer.ID), - seenMessages: timecache.NewTimeCache(TimeCacheDuration), - counter: uint64(time.Now().UnixNano()), - validateWorkers: runtime.NumCPU(), + host: h, + ctx: ctx, + rt: rt, + val: newValidation(), + signID: h.ID(), + signKey: h.Peerstore().PrivKey(h.ID()), + incoming: make(chan *RPC, 32), + publish: make(chan *Message), + newPeers: make(chan peer.ID), + newPeerStream: make(chan inet.Stream), + newPeerError: make(chan peer.ID), + peerDead: make(chan peer.ID), + cancelCh: make(chan *Subscription), + getPeers: make(chan *listPeerReq), + addSub: make(chan *addSubReq), + getTopics: make(chan *topicReq), + sendMsg: make(chan *sendReq, 32), + addVal: make(chan *addValReq), + rmVal: make(chan *rmValReq), + eval: make(chan func()), + myTopics: make(map[string]map[*Subscription]struct{}), + topics: make(map[string]map[peer.ID]struct{}), + peers: make(map[peer.ID]chan *RPC), + blacklist: NewMapBlacklist(), + blacklistPeer: make(chan peer.ID), + seenMessages: timecache.NewTimeCache(TimeCacheDuration), + counter: uint64(time.Now().UnixNano()), } for _, opt := range opts { @@ -218,36 +199,13 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option } h.Network().Notify((*PubSubNotif)(ps)) - go ps.processLoop(ctx) + ps.val.Start(ps) - for i := 0; i < ps.validateWorkers; i++ { - go ps.validateWorker() - } + go ps.processLoop(ctx) return ps, nil } -// WithValidateThrottle sets the upper bound on the number of active validation -// goroutines across all topics. The default is 8192. -func WithValidateThrottle(n int) Option { - return func(ps *PubSub) error { - ps.validateThrottle = make(chan struct{}, n) - return nil - } -} - -// WithValidateWorkers sets the number of synchronous validation worker goroutines. -// Defaults to NumCPU. -func WithValidateWorkers(n int) Option { - return func(ps *PubSub) error { - if n > 0 { - ps.validateWorkers = n - return nil - } - return fmt.Errorf("number of validation workers must be > 0") - } -} - // WithMessageSigning enables or disables message signing (enabled by default). func WithMessageSigning(enabled bool) Option { return func(p *PubSub) error { @@ -412,17 +370,16 @@ func (p *PubSub) processLoop(ctx context.Context) { p.handleIncomingRPC(rpc) case msg := <-p.publish: - vals := p.getValidators(msg) - p.pushMsg(vals, p.host.ID(), msg) + p.pushMsg(p.host.ID(), msg) case req := <-p.sendMsg: p.publishMessage(req.from, req.msg.Message) case req := <-p.addVal: - p.addValidator(req) + p.val.AddValidator(req) case req := <-p.rmVal: - p.rmValidator(req) + p.val.RemoveValidator(req) case thunk := <-p.eval: thunk() @@ -630,8 +587,7 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) { } msg := &Message{pmsg} - vals := p.getValidators(msg) - p.pushMsg(vals, rpc.from, msg) + p.pushMsg(rpc.from, msg) } p.rt.HandleRPC(rpc) @@ -643,7 +599,7 @@ func msgID(pmsg *pb.Message) string { } // pushMsg pushes a message performing validation as necessary -func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { +func (p *PubSub) pushMsg(src peer.ID, msg *Message) { // reject messages from blacklisted peers if p.blacklist.Contains(src) { log.Warningf("dropping message from blacklisted peer %s", src) @@ -668,12 +624,7 @@ func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { return } - if len(vals) > 0 || msg.Signature != nil { - select { - case p.validateQ <- &validateReq{vals, src, msg}: - default: - log.Warningf("message validation throttled; dropping message from %s", src) - } + if !p.val.Push(src, msg) { return } @@ -682,178 +633,11 @@ func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { } } -func (p *PubSub) validateWorker() { - for { - select { - case req := <-p.validateQ: - p.validate(req.vals, req.src, req.msg) - case <-p.ctx.Done(): - return - } - } -} - -// validate performs validation and only sends the message if all validators succeed -// signature validation is performed synchronously, while user validators are invoked -// asynchronously, throttled by the global validation throttle. -func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) { - if msg.Signature != nil { - if !p.validateSignature(msg) { - log.Warningf("message signature validation failed; dropping message from %s", src) - return - } - } - - // we can mark the message as seen now that we have verified the signature - // and avoid invoking user validators more than once - id := msgID(msg.Message) - if !p.markSeen(id) { - return - } - - var inline, async []*topicVal - for _, val := range vals { - if val.validateInline { - inline = append(inline, val) - } else { - async = append(async, val) - } - } - - // apply inline (synchronous) validators - for _, val := range inline { - if !val.validateMsg(p.ctx, src, msg) { - log.Debugf("message validation failed; dropping message from %s", src) - return - } - } - - // apply async validators - if len(async) > 0 { - select { - case p.validateThrottle <- struct{}{}: - go func() { - p.doValidateTopic(async, src, msg) - <-p.validateThrottle - }() - default: - log.Warningf("message validation throttled; dropping message from %s", src) - } - return - } - - // no async validators, send the message - p.sendMsg <- &sendReq{ - from: src, - msg: msg, - } -} - -func (p *PubSub) validateSignature(msg *Message) bool { - err := verifyMessageSignature(msg.Message) - if err != nil { - log.Debugf("signature verification error: %s", err.Error()) - return false - } - - return true -} - -func (p *PubSub) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message) { - if !p.validateTopic(vals, src, msg) { - log.Warningf("message validation failed; dropping message from %s", src) - return - } - - p.sendMsg <- &sendReq{ - from: src, - msg: msg, - } -} - -func (p *PubSub) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool { - if len(vals) == 1 { - return p.validateSingleTopic(vals[0], src, msg) - } - - ctx, cancel := context.WithCancel(p.ctx) - defer cancel() - - rch := make(chan bool, len(vals)) - rcount := 0 - throttle := false - -loop: - for _, val := range vals { - rcount++ - - select { - case val.validateThrottle <- struct{}{}: - go func(val *topicVal) { - rch <- val.validateMsg(ctx, src, msg) - <-val.validateThrottle - }(val) - - default: - log.Debugf("validation throttled for topic %s", val.topic) - throttle = true - break loop - } - } - - if throttle { - return false - } - - for i := 0; i < rcount; i++ { - valid := <-rch - if !valid { - return false - } - } - - return true -} - -// fast path for single topic validation that avoids the extra goroutine -func (p *PubSub) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) bool { - select { - case val.validateThrottle <- struct{}{}: - ctx, cancel := context.WithCancel(p.ctx) - defer cancel() - - res := val.validateMsg(ctx, src, msg) - <-val.validateThrottle - - return res - - default: - log.Debugf("validation throttled for topic %s", val.topic) - return false - } -} - func (p *PubSub) publishMessage(from peer.ID, pmsg *pb.Message) { p.notifySubs(pmsg) p.rt.Publish(from, pmsg) } -// getValidators returns all validators that apply to a given message -func (p *PubSub) getValidators(msg *Message) []*topicVal { - var vals []*topicVal - - for _, topic := range msg.GetTopicIDs() { - val, ok := p.topicVals[topic] - if !ok { - continue - } - - vals = append(vals, val) - } - - return vals -} - type addSubReq struct { sub *Subscription resp chan *Subscription @@ -965,70 +749,6 @@ func (p *PubSub) BlacklistPeer(pid peer.ID) { p.blacklistPeer <- pid } -// validation requests -type validateReq struct { - vals []*topicVal - src peer.ID - msg *Message -} - -// per topic validators -type addValReq struct { - topic string - validate Validator - timeout time.Duration - throttle int - inline bool - resp chan error -} - -type rmValReq struct { - topic string - resp chan error -} - -type topicVal struct { - topic string - validate Validator - validateTimeout time.Duration - validateThrottle chan struct{} - validateInline bool -} - -// Validator is a function that validates a message. -type Validator func(context.Context, peer.ID, *Message) bool - -// ValidatorOpt is an option for RegisterTopicValidator. -type ValidatorOpt func(addVal *addValReq) error - -// WithValidatorTimeout is an option that sets a timeout for an (asynchronous) topic validator. -// By default there is no timeout in asynchronous validators. -func WithValidatorTimeout(timeout time.Duration) ValidatorOpt { - return func(addVal *addValReq) error { - addVal.timeout = timeout - return nil - } -} - -// WithValidatorConcurrency is an option that sets the topic validator throttle. -// This controls the number of active validation goroutines for the topic; the default is 1024. -func WithValidatorConcurrency(n int) ValidatorOpt { - return func(addVal *addValReq) error { - addVal.throttle = n - return nil - } -} - -// WithValidatorInline is an option that sets the validation disposition to synchronous: -// it will be executed inline in validation front-end, without spawning a new goroutine. -// This is suitable for simple or cpu-bound validators that do not block. -func WithValidatorInline(inline bool) ValidatorOpt { - return func(addVal *addValReq) error { - addVal.inline = inline - return nil - } -} - // RegisterTopicValidator registers a validator for topic. // By default validators are asynchronous, which means they will run in a separate goroutine. // The number of active goroutines is controlled by global and per topic validator @@ -1051,35 +771,6 @@ func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...Val return <-addVal.resp } -func (ps *PubSub) addValidator(req *addValReq) { - topic := req.topic - - _, ok := ps.topicVals[topic] - if ok { - req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic) - return - } - - val := &topicVal{ - topic: topic, - validate: req.validate, - validateTimeout: 0, - validateThrottle: make(chan struct{}, defaultValidateConcurrency), - validateInline: req.inline, - } - - if req.timeout > 0 { - val.validateTimeout = req.timeout - } - - if req.throttle > 0 { - val.validateThrottle = make(chan struct{}, req.throttle) - } - - ps.topicVals[topic] = val - req.resp <- nil -} - // UnregisterTopicValidator removes a validator from a topic. // Returns an error if there was no validator registered with the topic. func (p *PubSub) UnregisterTopicValidator(topic string) error { @@ -1091,30 +782,3 @@ func (p *PubSub) UnregisterTopicValidator(topic string) error { p.rmVal <- rmVal return <-rmVal.resp } - -func (ps *PubSub) rmValidator(req *rmValReq) { - topic := req.topic - - _, ok := ps.topicVals[topic] - if ok { - delete(ps.topicVals, topic) - req.resp <- nil - } else { - req.resp <- fmt.Errorf("No validator for topic %s", topic) - } -} - -func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) bool { - if val.validateTimeout > 0 { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, val.validateTimeout) - defer cancel() - } - - valid := val.validate(ctx, src, msg) - if !valid { - log.Debugf("validation failed for topic %s", val.topic) - } - - return valid -} diff --git a/validation.go b/validation.go new file mode 100644 index 00000000..3c2ed57b --- /dev/null +++ b/validation.go @@ -0,0 +1,383 @@ +package pubsub + +import ( + "context" + "fmt" + "runtime" + "time" + + peer "github.com/libp2p/go-libp2p-peer" +) + +const ( + defaultValidateConcurrency = 1024 + defaultValidateThrottle = 8192 +) + +// Validator is a function that validates a message. +type Validator func(context.Context, peer.ID, *Message) bool + +// ValidatorOpt is an option for RegisterTopicValidator. +type ValidatorOpt func(addVal *addValReq) error + +// validation represents the validator pipeline +type validation struct { + p *PubSub + + // topicVals tracks per topic validators + topicVals map[string]*topicVal + + // validateQ is the front-end to the validation pipeline + validateQ chan *validateReq + + // validateThrottle limits the number of active validation goroutines + validateThrottle chan struct{} + + // this is the number of synchronous validation workers + validateWorkers int +} + +// validation requests +type validateReq struct { + vals []*topicVal + src peer.ID + msg *Message +} + +// representation of topic validators +type topicVal struct { + topic string + validate Validator + validateTimeout time.Duration + validateThrottle chan struct{} + validateInline bool +} + +// async request to add a topic validators +type addValReq struct { + topic string + validate Validator + timeout time.Duration + throttle int + inline bool + resp chan error +} + +// async request to remove a topic validator +type rmValReq struct { + topic string + resp chan error +} + +// newValidation creates a new validation pipeline +func newValidation() *validation { + return &validation{ + topicVals: make(map[string]*topicVal), + validateQ: make(chan *validateReq, 32), + validateThrottle: make(chan struct{}, defaultValidateThrottle), + validateWorkers: runtime.NumCPU(), + } +} + +// Start attaches the validation pipeline to a pubsub instance and starts background +// workers +func (v *validation) Start(p *PubSub) { + v.p = p + for i := 0; i < v.validateWorkers; i++ { + go v.validateWorker() + } +} + +// AddValidator adds a new validator +func (v *validation) AddValidator(req *addValReq) { + topic := req.topic + + _, ok := v.topicVals[topic] + if ok { + req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic) + return + } + + val := &topicVal{ + topic: topic, + validate: req.validate, + validateTimeout: 0, + validateThrottle: make(chan struct{}, defaultValidateConcurrency), + validateInline: req.inline, + } + + if req.timeout > 0 { + val.validateTimeout = req.timeout + } + + if req.throttle > 0 { + val.validateThrottle = make(chan struct{}, req.throttle) + } + + v.topicVals[topic] = val + req.resp <- nil +} + +// RemoveValidator removes an existing validator +func (v *validation) RemoveValidator(req *rmValReq) { + topic := req.topic + + _, ok := v.topicVals[topic] + if ok { + delete(v.topicVals, topic) + req.resp <- nil + } else { + req.resp <- fmt.Errorf("No validator for topic %s", topic) + } +} + +// Push pushes a message into the validation pipeline. +// It returns true if the message can be forwarded immediately without validation. +func (v *validation) Push(src peer.ID, msg *Message) bool { + vals := v.getValidators(msg) + + if len(vals) > 0 || msg.Signature != nil { + select { + case v.validateQ <- &validateReq{vals, src, msg}: + default: + log.Warningf("message validation throttled; dropping message from %s", src) + } + return false + } + + return true +} + +// getValidators returns all validators that apply to a given message +func (v *validation) getValidators(msg *Message) []*topicVal { + var vals []*topicVal + + for _, topic := range msg.GetTopicIDs() { + val, ok := v.topicVals[topic] + if !ok { + continue + } + + vals = append(vals, val) + } + + return vals +} + +// validateWorker is an active goroutine performing inline validation +func (v *validation) validateWorker() { + for { + select { + case req := <-v.validateQ: + v.validate(req.vals, req.src, req.msg) + case <-v.p.ctx.Done(): + return + } + } +} + +// validate performs validation and only sends the message if all validators succeed +// signature validation is performed synchronously, while user validators are invoked +// asynchronously, throttled by the global validation throttle. +func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message) { + if msg.Signature != nil { + if !v.validateSignature(msg) { + log.Warningf("message signature validation failed; dropping message from %s", src) + return + } + } + + // we can mark the message as seen now that we have verified the signature + // and avoid invoking user validators more than once + id := msgID(msg.Message) + if !v.p.markSeen(id) { + return + } + + var inline, async []*topicVal + for _, val := range vals { + if val.validateInline { + inline = append(inline, val) + } else { + async = append(async, val) + } + } + + // apply inline (synchronous) validators + for _, val := range inline { + if !val.validateMsg(v.p.ctx, src, msg) { + log.Debugf("message validation failed; dropping message from %s", src) + return + } + } + + // apply async validators + if len(async) > 0 { + select { + case v.validateThrottle <- struct{}{}: + go func() { + v.doValidateTopic(async, src, msg) + <-v.validateThrottle + }() + default: + log.Warningf("message validation throttled; dropping message from %s", src) + } + return + } + + // no async validators, send the message + v.p.sendMsg <- &sendReq{ + from: src, + msg: msg, + } +} + +func (v *validation) validateSignature(msg *Message) bool { + err := verifyMessageSignature(msg.Message) + if err != nil { + log.Debugf("signature verification error: %s", err.Error()) + return false + } + + return true +} + +func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message) { + if !v.validateTopic(vals, src, msg) { + log.Warningf("message validation failed; dropping message from %s", src) + return + } + + v.p.sendMsg <- &sendReq{ + from: src, + msg: msg, + } +} + +func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool { + if len(vals) == 1 { + return v.validateSingleTopic(vals[0], src, msg) + } + + ctx, cancel := context.WithCancel(v.p.ctx) + defer cancel() + + rch := make(chan bool, len(vals)) + rcount := 0 + throttle := false + +loop: + for _, val := range vals { + rcount++ + + select { + case val.validateThrottle <- struct{}{}: + go func(val *topicVal) { + rch <- val.validateMsg(ctx, src, msg) + <-val.validateThrottle + }(val) + + default: + log.Debugf("validation throttled for topic %s", val.topic) + throttle = true + break loop + } + } + + if throttle { + return false + } + + for i := 0; i < rcount; i++ { + valid := <-rch + if !valid { + return false + } + } + + return true +} + +// fast path for single topic validation that avoids the extra goroutine +func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) bool { + select { + case val.validateThrottle <- struct{}{}: + ctx, cancel := context.WithCancel(v.p.ctx) + defer cancel() + + res := val.validateMsg(ctx, src, msg) + <-val.validateThrottle + + return res + + default: + log.Debugf("validation throttled for topic %s", val.topic) + return false + } +} + +func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) bool { + if val.validateTimeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, val.validateTimeout) + defer cancel() + } + + valid := val.validate(ctx, src, msg) + if !valid { + log.Debugf("validation failed for topic %s", val.topic) + } + + return valid +} + +/// Options + +// WithValidateThrottle sets the upper bound on the number of active validation +// goroutines across all topics. The default is 8192. +func WithValidateThrottle(n int) Option { + return func(ps *PubSub) error { + ps.val.validateThrottle = make(chan struct{}, n) + return nil + } +} + +// WithValidateWorkers sets the number of synchronous validation worker goroutines. +// Defaults to NumCPU. +func WithValidateWorkers(n int) Option { + return func(ps *PubSub) error { + if n > 0 { + ps.val.validateWorkers = n + return nil + } + return fmt.Errorf("number of validation workers must be > 0") + } +} + +// WithValidatorTimeout is an option that sets a timeout for an (asynchronous) topic validator. +// By default there is no timeout in asynchronous validators. +func WithValidatorTimeout(timeout time.Duration) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.timeout = timeout + return nil + } +} + +// WithValidatorConcurrency is an option that sets the topic validator throttle. +// This controls the number of active validation goroutines for the topic; the default is 1024. +func WithValidatorConcurrency(n int) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.throttle = n + return nil + } +} + +// WithValidatorInline is an option that sets the validation disposition to synchronous: +// it will be executed inline in validation front-end, without spawning a new goroutine. +// This is suitable for simple or cpu-bound validators that do not block. +func WithValidatorInline(inline bool) ValidatorOpt { + return func(addVal *addValReq) error { + addVal.inline = inline + return nil + } +}