From 79206f9d8bc28e7a964e7b64424e65c5f1122cba Mon Sep 17 00:00:00 2001 From: Aravind Srinivasan Date: Tue, 27 Dec 2016 08:51:54 -0800 Subject: [PATCH] Initial commit --- .gitignore | 14 + LICENSE | 20 + Makefile | 36 ++ README.md | 35 ++ client/cherami/basePublisher.go | 161 +++++ client/cherami/client.go | 382 ++++++++++++ client/cherami/client_test.go | 140 +++++ client/cherami/connection.go | 349 +++++++++++ client/cherami/connection_test.go | 349 +++++++++++ client/cherami/consumer.go | 413 +++++++++++++ client/cherami/delivery.go | 91 +++ client/cherami/delivery_test.go | 96 +++ client/cherami/interfaces.go | 249 ++++++++ client/cherami/mockClient_test.go | 370 +++++++++++ client/cherami/outputhostconnection.go | 360 +++++++++++ client/cherami/outputhostconnection_test.go | 243 ++++++++ client/cherami/publisher.go | 347 +++++++++++ client/cherami/reconfigurable.go | 103 ++++ client/cherami/task.go | 52 ++ client/cherami/taskexecutor.go | 182 ++++++ client/cherami/taskscheduler.go | 108 ++++ client/cherami/tchanPublisher.go | 427 +++++++++++++ client/cherami/tchanPublisher_test.go | 233 +++++++ client/cherami/wsconnector.go | 156 +++++ common/backoff/retry.go | 69 +++ common/backoff/retry_test.go | 141 +++++ common/backoff/retrypolicy.go | 197 ++++++ common/backoff/retrypolicy_test.go | 231 +++++++ common/constants.go | 77 +++ common/convert.go | 98 +++ common/log.go | 235 +++++++ common/metrics/interfaces.go | 80 +++ common/metrics/names.go | 99 +++ common/metrics/nullreporter.go | 94 +++ common/thrift_util.go | 117 ++++ common/util.go | 294 +++++++++ common/websocket/base_test.go | 260 ++++++++ common/websocket/client.go | 105 ++++ common/websocket/conn.go | 51 ++ common/websocket/hub.go | 73 +++ common/websocket/server.go | 103 ++++ common/websocket/serverclient_test.go | 180 ++++++ common/websocket/stream.go | 401 ++++++++++++ common/websocket/stream_test.go | 576 ++++++++++++++++++ example.go | 197 ++++++ glide.lock | 68 +++ glide.yaml | 28 + 
mocks/README.md | 9 + .../MockBInOpenPublisherStreamOutCall.go | 73 +++ .../MockBOutOpenConsumerStreamOutCall.go | 73 +++ mocks/clients/cherami/MockTChanBInClient.go | 56 ++ mocks/clients/cherami/MockTChanBOutClient.go | 79 +++ mocks/clients/cherami/MockWSConnector.go | 70 +++ mocks/common/websocket/MockWebsocketConn.go | 114 ++++ stream/stream.go | 91 +++ 55 files changed, 9255 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 client/cherami/basePublisher.go create mode 100644 client/cherami/client.go create mode 100644 client/cherami/client_test.go create mode 100644 client/cherami/connection.go create mode 100644 client/cherami/connection_test.go create mode 100644 client/cherami/consumer.go create mode 100644 client/cherami/delivery.go create mode 100644 client/cherami/delivery_test.go create mode 100644 client/cherami/interfaces.go create mode 100644 client/cherami/mockClient_test.go create mode 100644 client/cherami/outputhostconnection.go create mode 100644 client/cherami/outputhostconnection_test.go create mode 100644 client/cherami/publisher.go create mode 100644 client/cherami/reconfigurable.go create mode 100644 client/cherami/task.go create mode 100644 client/cherami/taskexecutor.go create mode 100644 client/cherami/taskscheduler.go create mode 100644 client/cherami/tchanPublisher.go create mode 100644 client/cherami/tchanPublisher_test.go create mode 100644 client/cherami/wsconnector.go create mode 100644 common/backoff/retry.go create mode 100644 common/backoff/retry_test.go create mode 100644 common/backoff/retrypolicy.go create mode 100644 common/backoff/retrypolicy_test.go create mode 100644 common/constants.go create mode 100644 common/convert.go create mode 100644 common/log.go create mode 100644 common/metrics/interfaces.go create mode 100644 common/metrics/names.go create mode 100644 common/metrics/nullreporter.go create mode 100644 
common/thrift_util.go create mode 100644 common/util.go create mode 100644 common/websocket/base_test.go create mode 100644 common/websocket/client.go create mode 100644 common/websocket/conn.go create mode 100644 common/websocket/hub.go create mode 100644 common/websocket/server.go create mode 100644 common/websocket/serverclient_test.go create mode 100644 common/websocket/stream.go create mode 100644 common/websocket/stream_test.go create mode 100644 example.go create mode 100644 glide.lock create mode 100644 glide.yaml create mode 100644 mocks/README.md create mode 100644 mocks/clients/cherami/MockBInOpenPublisherStreamOutCall.go create mode 100644 mocks/clients/cherami/MockBOutOpenConsumerStreamOutCall.go create mode 100644 mocks/clients/cherami/MockTChanBInClient.go create mode 100644 mocks/clients/cherami/MockTChanBOutClient.go create mode 100644 mocks/clients/cherami/MockWSConnector.go create mode 100644 mocks/common/websocket/MockWebsocketConn.go create mode 100644 stream/stream.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8fd2e70 --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +*.out +*.test +*.xml +*.swp +*.cov +*.html +*.tmp +test +test.log +.rewrite +vendor/ + +# produced executable(s) +example diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a4c8d11 --- /dev/null +++ b/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2016 Uber Technologies, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..f5fdce3 --- /dev/null +++ b/Makefile @@ -0,0 +1,36 @@ +.PHONY: bins test clean +PROJECT_ROOT=github.com/uber/cherami-client-go + +export PATH := $(GOPATH)/bin:$(PATH) + +PROGS = example +export GO15VENDOREXPERIMENT=1 +NOVENDOR = $(shell GO15VENDOREXPERIMENT=1 glide novendor) + +export PATH := $(GOPATH)/bin:$(PATH) + +# Automatically gather all srcs +ALL_SRC := $(shell find . -name "*.go" | grep -v -e Godeps -e vendor \ + -e ".*/\..*" \ + -e ".*/_.*" \ + -e ".*/mocks.*") + +# all directories with *_test.go files in them +TEST_DIRS := $(sort $(dir $(filter %_test.go,$(ALL_SRC)))) + +bins: + glide install + go build -i -o example example.go +clean: + rm -rf example +test: + @rm -f test + @rm -f test.log + @for dir in $(TEST_DIRS); do \ + go test -coverprofile=$@ "$$dir" | tee -a test.log; \ + done; + +test-race: + @for dir in $(TEST_DIRS); do \ + go test -race "$$dir" | tee -a "$$dir"_test.log; \ + done; diff --git a/README.md b/README.md new file mode 100644 index 0000000..7d1a48c --- /dev/null +++ b/README.md @@ -0,0 +1,35 @@ +Go client library for Cherami +============================= + +Cherami is a distributed, scalable, durable, and highly available message queue system we developed at Uber Engineering to transport asynchronous tasks. + +cherami-client-go is the go client library for Cherami. + +Clone this repo +--------------- +Make sure you clone this repo into the correct location. 
+ +`git clone git@github.com:uber/cherami-client-go.git $GOPATH/src/github.com/uber/cherami-client-go` +`pushd $GOPATH/src/github.com/uber/cherami-client-go` + + +Development +----------- +The cherami-client-go repo specifically holds the client library for Cherami, whose thrift APIs are defined in the cherami-thrift repo. This repo can be used to talk to the Cherami server once that server is up and running. + +The repo also holds an `example` which can be executed against a Cherami server running locally. + +In order to use the example in this repo, the following dependencies need to be addressed: +1. Make certain that `thrift` (OSX: `brew install thrift`) and `glide` are in your path (above). +2. Make sure that the Cherami server is up and running by cloning the `cherami-server` repo and following the instructions in that repo. + +Once the aforementioned steps are complete, one can build the `example` by running: +`make bins` + +In order to use `cherami-client-go` as a library in an application which wants to talk to Cherami, in the consuming repo, take in the updated go-client (`github.com/uber/cherami-client-go`) as a package in `glide.yaml`. + +Documentation +-------------- + +Interested in learning more about Cherami? Read the blog post: +[eng.uber.com/cherami](https://eng.uber.com/cherami/) diff --git a/client/cherami/basePublisher.go b/client/cherami/basePublisher.go new file mode 100644 index 0000000..32eb938 --- /dev/null +++ b/client/cherami/basePublisher.go @@ -0,0 +1,161 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package cherami + +import ( + "crypto/md5" + "errors" + "hash/crc32" + "strconv" + "sync/atomic" + "time" + + "github.com/uber-common/bark" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/common" + "github.com/uber/cherami-client-go/common/backoff" + "github.com/uber/cherami-client-go/common/metrics" +) + +type ( + // basePublisher contains the data/code + // common to all types of Publisher + // implementations + basePublisher struct { + idCounter int64 + path string + client Client + logger bark.Logger + reporter metrics.Reporter + retryPolicy backoff.RetryPolicy + checksumOption cherami.ChecksumOption + } + + // publishError represents a message publishing error + publishError struct { + msg string + } +) + +// ErrMessageTimedout is returned by Publish when no ack is received within timeout interval +var ErrMessageTimedout = errors.New("Timed out.") + +// Publisher default retry policy +const ( + publisherRetryInterval = 50 * time.Millisecond + publisherRetryMaxInterval = 10 * time.Second + publisherRetryExpirationInterval = 2 * time.Minute +) + +// choosePublishEndpoints selects the list of publish endpoints +// from the set of all possible (protocol -> endpoints) returned +// from the server +func (bp *basePublisher) choosePublishEndpoints(publisherOptions *cherami.ReadPublisherOptionsResult_) (cherami.Protocol, []*cherami.HostAddress) { + // pick best protocol from what server suggested + hostProtocols := publisherOptions.GetHostProtocols() + chosenIdx, err := bp.chooseProcotol(hostProtocols) + chosenProtocol := cherami.Protocol_WS + chosenHostAddresses := publisherOptions.GetHostAddresses() + if err == nil { + chosenProtocol = hostProtocols[chosenIdx].GetProtocol() + chosenHostAddresses = hostProtocols[chosenIdx].GetHostAddresses() + } + return chosenProtocol, chosenHostAddresses +} + +// chooseProtocol selects a preferred protocol from the list of +// available protocols returned from the server +func (bp 
*basePublisher) chooseProcotol(hostProtocols []*cherami.HostProtocol) (int, error) { + clientSupportedProtocol := map[cherami.Protocol]bool{cherami.Protocol_WS: true} + clientSupportButDeprecated := -1 + serverSupportedProtocol := make([]cherami.Protocol, 0, len(hostProtocols)) + + for idx, hostProtocol := range hostProtocols { + serverSupportedProtocol = append(serverSupportedProtocol, hostProtocol.GetProtocol()) + if _, found := clientSupportedProtocol[hostProtocol.GetProtocol()]; found { + if !hostProtocol.GetDeprecated() { + // found first supported and non-deprecated one, done + return idx, nil + } else if clientSupportButDeprecated == -1 { + // found first supported but deprecated one, keep looking + clientSupportButDeprecated = idx + } + } + } + + if clientSupportButDeprecated == -1 { + bp.logger.WithField(`protocols`, serverSupportedProtocol).Error("No protocol is supported by client") + return clientSupportButDeprecated, &cherami.BadRequestError{Message: `No protocol is supported by client`} + } + + bp.logger.WithField(`protocol`, hostProtocols[clientSupportButDeprecated].GetProtocol()).Warn("Client using deprecated protocol") + return clientSupportButDeprecated, nil +} + +func (bp *basePublisher) readPublisherOptions() (*cherami.ReadPublisherOptionsResult_, error) { + return bp.client.ReadPublisherOptions(bp.path) +} + +func (bp *basePublisher) addChecksum(msg *cherami.PutMessage) { + switch bp.checksumOption { + case cherami.ChecksumOption_CRC32IEEE: + msg.Crc32IEEEDataChecksum = common.Int64Ptr(int64(crc32.ChecksumIEEE(msg.GetData()))) + case cherami.ChecksumOption_MD5: + md5Checksum := md5.Sum(msg.GetData()) + msg.Md5DataChecksum = md5Checksum[:] + } +} + +// toPutMessage converts a PublisherMessage to cherami.PutMessage +func (bp *basePublisher) toPutMessage(pubMessage *PublisherMessage) *cherami.PutMessage { + + msgID := atomic.AddInt64(&bp.idCounter, 1) + idStr := strconv.FormatInt(msgID, 10) + delay := int32(pubMessage.Delay.Seconds()) + + msg := 
&cherami.PutMessage{ + ID: common.StringPtr(idStr), + Data: pubMessage.Data, + DelayMessageInSeconds: &delay, + UserContext: pubMessage.UserContext, + } + + bp.addChecksum(msg) + return msg +} + +func createDefaultPublisherRetryPolicy() backoff.RetryPolicy { + policy := backoff.NewExponentialRetryPolicy(publisherRetryInterval) + policy.SetMaximumInterval(publisherRetryMaxInterval) + policy.SetExpirationInterval(publisherRetryExpirationInterval) + return policy +} + +func newPublishError(status cherami.Status) error { + return &publishError{ + msg: "Publish failed with error:" + status.String(), + } +} + +func (e *publishError) Error() string { + return e.msg +} diff --git a/client/cherami/client.go b/client/cherami/client.go new file mode 100644 index 0000000..532312d --- /dev/null +++ b/client/cherami/client.go @@ -0,0 +1,382 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package cherami + +import ( + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/common" + "github.com/uber/cherami-client-go/common/backoff" + "github.com/uber/cherami-client-go/common/metrics" + + log "github.com/Sirupsen/logrus" + "github.com/pborman/uuid" + "github.com/uber-common/bark" + "github.com/uber/tchannel-go" + "github.com/uber/tchannel-go/thrift" + + "golang.org/x/net/context" +) + +type ( + clientImpl struct { + connection *tchannel.Channel + client cherami.TChanBFrontend + options *ClientOptions + retryPolicy backoff.RetryPolicy + + sync.Mutex + hostPort string + } +) + +const ( + clientRetryInterval = 50 * time.Millisecond + clientRetryMaxInterval = 10 * time.Second + clientRetryExpirationInterval = 1 * time.Minute +) + +var envUserName = os.Getenv("USER") +var envHostName, _ = os.Hostname() + +// NewClient returns the singleton Cherami client used for communicating with the service at given port +func NewClient(serviceName string, host string, port int, options *ClientOptions) (Client, error) { + ch, err := tchannel.NewChannel(serviceName, nil) + if err != nil { + return nil, err + } + + ch.Peers().Add(fmt.Sprintf("%s:%d", host, port)) + return newClientWithTChannel(ch, options) +} + +// NewHyperbahnClient returns the singleton Cherami client used for communicating with the service via hyperbahn. Streaming methods will probably not work. +func NewHyperbahnClient(serviceName string, bootstrapFile string, options *ClientOptions) (Client, error) { + ch, err := tchannel.NewChannel(serviceName, nil) + if err != nil { + return nil, err + } + + common.CreateHyperbahnClient(ch, bootstrapFile) + return newClientWithTChannel(ch, options) +} + +// NewClientWithFE is used by Frontend to create a Cherami client for itself. +// It is used by non-streaming publish/consume APIs. 
+func NewClientWithFE(feClient cherami.TChanBFrontend, options *ClientOptions) Client { + if options == nil { + options = getDefaultOptions() + } + common.ValidateTimeout(options.Timeout) + + verifyOptions(options) + + return &clientImpl{ + client: feClient, + options: options, + retryPolicy: createDefaultRetryPolicy(), + } +} + +func newClientWithTChannel(ch *tchannel.Channel, options *ClientOptions) (Client, error) { + if options == nil { + options = getDefaultOptions() + } + common.ValidateTimeout(options.Timeout) + + verifyOptions(options) + + tClient := thrift.NewClient(ch, getFrontEndServiceName(options.DeploymentStr), nil) + + client := &clientImpl{ + connection: ch, + client: cherami.NewTChanBFrontendClient(tClient), + options: options, + retryPolicy: createDefaultRetryPolicy(), + } + return client, nil +} + +func (c *clientImpl) createStreamingClient() (cherami.TChanBFrontend, error) { + // create a streaming client directly connecting to the frontend + c.Lock() + defer c.Unlock() + // if hostPort is not known (e.g. 
hyperbahn client), need to query the + // frontened to find its IP + if c.hostPort == "" { + ctx, cancel := c.createContext() + defer cancel() + hostport, err := c.client.HostPort(ctx) + if err != nil { + return nil, err + } + + c.hostPort = hostport + } + + ch, err := tchannel.NewChannel(uuid.New(), nil) + if err != nil { + return nil, err + } + + tClient := thrift.NewClient(ch, getFrontEndServiceName(c.options.DeploymentStr), &thrift.ClientOptions{ + HostPort: c.hostPort, + }) + + streamingClient := cherami.NewTChanBFrontendClient(tClient) + + return streamingClient, nil +} + +// Close shuts down the connection to Cherami frontend +func (c *clientImpl) Close() { + c.Lock() + defer c.Unlock() + + if c.connection != nil { + c.connection.Close() + } +} + +func (c *clientImpl) CreateDestination(request *cherami.CreateDestinationRequest) (*cherami.DestinationDescription, error) { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.CreateDestination(ctx, request) +} + +func (c *clientImpl) ReadDestination(request *cherami.ReadDestinationRequest) (*cherami.DestinationDescription, error) { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.ReadDestination(ctx, request) +} + +func (c *clientImpl) UpdateDestination(request *cherami.UpdateDestinationRequest) (*cherami.DestinationDescription, error) { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.UpdateDestination(ctx, request) +} + +func (c *clientImpl) DeleteDestination(request *cherami.DeleteDestinationRequest) error { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.DeleteDestination(ctx, request) +} + +func (c *clientImpl) ListDestinations(request *cherami.ListDestinationsRequest) (*cherami.ListDestinationsResult_, error) { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.ListDestinations(ctx, request) +} + +func (c *clientImpl) CreateConsumerGroup(request *cherami.CreateConsumerGroupRequest) 
(*cherami.ConsumerGroupDescription, error) { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.CreateConsumerGroup(ctx, request) +} + +func (c *clientImpl) ReadConsumerGroup(request *cherami.ReadConsumerGroupRequest) (*cherami.ConsumerGroupDescription, error) { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.ReadConsumerGroup(ctx, request) +} + +func (c *clientImpl) UpdateConsumerGroup(request *cherami.UpdateConsumerGroupRequest) (*cherami.ConsumerGroupDescription, error) { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.UpdateConsumerGroup(ctx, request) +} + +func (c *clientImpl) MergeDLQForConsumerGroup(request *cherami.MergeDLQForConsumerGroupRequest) error { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.MergeDLQForConsumerGroup(ctx, request) +} + +func (c *clientImpl) PurgeDLQForConsumerGroup(request *cherami.PurgeDLQForConsumerGroupRequest) error { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.PurgeDLQForConsumerGroup(ctx, request) +} + +func (c *clientImpl) DeleteConsumerGroup(request *cherami.DeleteConsumerGroupRequest) error { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.DeleteConsumerGroup(ctx, request) +} + +func (c *clientImpl) ListConsumerGroups(request *cherami.ListConsumerGroupRequest) (*cherami.ListConsumerGroupResult_, error) { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.ListConsumerGroups(ctx, request) +} + +func (c *clientImpl) GetQueueDepthInfo(request *cherami.GetQueueDepthInfoRequest) (*cherami.GetQueueDepthInfoResult_, error) { + ctx, cancel := c.createContext() + defer cancel() + + return c.client.GetQueueDepthInfo(ctx, request) +} + +func (c *clientImpl) CreatePublisher(request *CreatePublisherRequest) Publisher { + switch request.PublisherType { + case PublisherTypeStreaming: + return NewPublisher(c, request.Path, request.MaxInflightMessagesPerConnection) + case 
PublisherTypeNonStreaming: + return newTChannelBatchPublisher(c, request.Path, c.options.Logger, c.options.MetricsReporter) + } + return nil +} + +func (c *clientImpl) CreateConsumer(request *CreateConsumerRequest) Consumer { + if request.Options == nil { + return nil + } + return newConsumer(c, request.Path, request.ConsumerGroupName, request.ConsumerName, request.PrefetchCount, request.Options) +} + +func (c *clientImpl) createContext() (thrift.Context, context.CancelFunc) { + ctx, cancel := thrift.NewContext(c.options.Timeout) + return thrift.WithHeaders(ctx, map[string]string{ + common.HeaderClientVersion: common.ClientVersion, + common.HeaderUserName: envUserName, + common.HeaderHostName: envHostName, + }), cancel +} + +func (c *clientImpl) ReadPublisherOptions(path string) (*cherami.ReadPublisherOptionsResult_, error) { + ctx, cancel := c.createContext() + defer cancel() + + var result *cherami.ReadPublisherOptionsResult_ + readOp := func() error { + var e error + request := &cherami.ReadPublisherOptionsRequest{ + Path: common.StringPtr(path), + } + + result, e = c.client.ReadPublisherOptions(ctx, request) + + return e + } + + err := backoff.Retry(readOp, c.retryPolicy, isTransientError) + if err != nil { + return nil, err + } + + return result, nil +} + +func (c *clientImpl) ReadConsumerGroupHosts(path string, consumerGroupName string) (*cherami.ReadConsumerGroupHostsResult_, error) { + ctx, cancel := c.createContext() + defer cancel() + + var result *cherami.ReadConsumerGroupHostsResult_ + readOp := func() error { + var e error + request := &cherami.ReadConsumerGroupHostsRequest{ + DestinationPath: common.StringPtr(path), + ConsumerGroupName: common.StringPtr(consumerGroupName), + } + + result, e = c.client.ReadConsumerGroupHosts(ctx, request) + + return e + } + + err := backoff.Retry(readOp, c.retryPolicy, isTransientError) + if err != nil { + return nil, err + } + + return result, nil +} + +func getFrontEndServiceName(deploymentStr string) string { + if 
len(deploymentStr) == 0 || strings.HasPrefix(strings.ToLower(deploymentStr), `prod`) || strings.HasPrefix(strings.ToLower(deploymentStr), `dev`) { + return common.FrontendServiceName + } + return fmt.Sprintf("%v_%v", common.FrontendServiceName, deploymentStr) +} + +func getDefaultLogger() bark.Logger { + return bark.NewLoggerFromLogrus(log.StandardLogger()) +} + +func getDefaultOptions() *ClientOptions { + return &ClientOptions{ + Timeout: time.Minute, + Logger: getDefaultLogger(), + MetricsReporter: metrics.NewNullReporter(), + } +} + +// verifyOptions is used to verify if we have a metrics reporter and +// a logger. If not, just setup a default logger and a null reporter +func verifyOptions(opts *ClientOptions) { + if opts.Logger == nil { + opts.Logger = getDefaultLogger() + } + + if opts.MetricsReporter == nil { + opts.MetricsReporter = metrics.NewNullReporter() + } + + // Now make sure we init the default metrics as well + opts.MetricsReporter.InitMetrics(metrics.MetricDefs) +} + +func createDefaultRetryPolicy() backoff.RetryPolicy { + policy := backoff.NewExponentialRetryPolicy(clientRetryInterval) + policy.SetMaximumInterval(clientRetryMaxInterval) + policy.SetExpirationInterval(clientRetryExpirationInterval) + + return policy +} + +func isTransientError(err error) bool { + // Only EntityNotExistsError/EntityDisabledError from Cherami is treated as non-transient error today + switch err.(type) { + case *cherami.EntityNotExistsError: + return false + case *cherami.EntityDisabledError: + return false + default: + return true + } +} diff --git a/client/cherami/client_test.go b/client/cherami/client_test.go new file mode 100644 index 0000000..ee418f1 --- /dev/null +++ b/client/cherami/client_test.go @@ -0,0 +1,140 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cherami + +import ( + "testing" + "time" + + "github.com/uber/cherami-client-go/common/metrics" + + log "github.com/Sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/uber-common/bark" + "github.com/uber/tchannel-go" +) + +type ClientSuite struct { + *require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error + suite.Suite +} + +func TestClientSuite(t *testing.T) { + suite.Run(t, new(ClientSuite)) +} + +func (s *ClientSuite) SetupTest() { + s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. 
If we did it earlier, s.T() will return nil +} + +// TestClientNoOptions tests to make sure we setup a default logger +// and a null reporter even if the options are nil +func (s *ClientSuite) TestClientNoOptions() { + ip, err := tchannel.ListenIP() + s.Nil(err) + listenIP := ip.String() + + // pass nil options + cheramiClient, err := NewClient("cherami-client-test-nil", listenIP, 0, nil) + s.Nil(err) + + cheramiImpl := cheramiClient.(*clientImpl) + + // Make sure the reporter is null reporter + assert.IsType(s.T(), &metrics.NullReporter{}, cheramiImpl.options.MetricsReporter) + + // Make sure logger is not nil as well + s.NotNil(cheramiImpl.options.Logger) + cheramiClient.Close() +} + +// TestClientOptionsOnlyTimeout tests to make sure we setup the logger +// and reporter even if the options has only the timeout set. +func (s *ClientSuite) TestClientOptionsOnlyTimeout() { + ip, err := tchannel.ListenIP() + s.Nil(err) + listenIP := ip.String() + + // setup options with just the timeout set + options := &ClientOptions{ + Timeout: 1 * time.Minute, + } + + cheramiClient, err := NewClient("cherami-client-test-both", listenIP, 0, options) + s.Nil(err) + + cheramiImpl := cheramiClient.(*clientImpl) + + // Make sure the reporter is null reporter + assert.IsType(s.T(), &metrics.NullReporter{}, cheramiImpl.options.MetricsReporter) + + // Make sure logger is not nil as well + s.NotNil(cheramiImpl.options.Logger) + cheramiClient.Close() +} + +// TestClientOptionsNoReporter tests to make sure we setup the null reporter +// even if the options has a logger set +func (s *ClientSuite) TestClientOptionsNoReporter() { + ip, err := tchannel.ListenIP() + s.Nil(err) + listenIP := ip.String() + + // setup options with just the logger + options := &ClientOptions{ + Timeout: 1 * time.Minute, + Logger: bark.NewLoggerFromLogrus(log.StandardLogger()), + } + + cheramiClient, err := NewClient("cherami-client-test-reporter", listenIP, 0, options) + s.Nil(err) + + cheramiImpl := 
cheramiClient.(*clientImpl) + + // Make sure the reporter is null reporter + assert.IsType(s.T(), &metrics.NullReporter{}, cheramiImpl.options.MetricsReporter) + cheramiClient.Close() +} + +// TestClientOptionsNoLogger tests to make sure we setup the default logger even +// if the options is valid and has just the reporter set +func (s *ClientSuite) TestClientOptionsNoLogger() { + ip, err := tchannel.ListenIP() + s.Nil(err) + listenIP := ip.String() + + // setup options with just the reporter + options := &ClientOptions{ + Timeout: 1 * time.Minute, + MetricsReporter: metrics.NewNullReporter(), + } + + cheramiClient, err := NewClient("cherami-client-test-logger", listenIP, 0, options) + s.Nil(err) + + cheramiImpl := cheramiClient.(*clientImpl) + + // Make sure logger is not nil as well + s.NotNil(cheramiImpl.options.Logger) + cheramiClient.Close() +} diff --git a/client/cherami/connection.go b/client/cherami/connection.go new file mode 100644 index 0000000..df0c295 --- /dev/null +++ b/client/cherami/connection.go @@ -0,0 +1,349 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cherami + +import ( + "errors" + "net/http" + "sync" + "sync/atomic" + + "github.com/uber-common/bark" + "golang.org/x/net/context" + + "time" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/common" + "github.com/uber/cherami-client-go/common/metrics" + "github.com/uber/cherami-client-go/stream" +) + +type ( + connection struct { + inputHostStream stream.BInOpenPublisherStreamOutCall + path string + inputHostClient cherami.TChanBIn + wsConnector WSConnector + connKey string + protocol cherami.Protocol + messagesCh <-chan putMessageRequest + reconfigureCh chan<- reconfigureInfo + replyCh chan *messageResponse + shuttingDownCh chan struct{} + closeCh chan struct{} + writeMsgPumpWG sync.WaitGroup + readAckPumpWG sync.WaitGroup + cancel context.CancelFunc + logger bark.Logger + reporter metrics.Reporter + + lk sync.Mutex + opened int32 + closed int32 + } + + // This struct is created by writePump after writing message to stream. 
+ // readAcksPump reads it from replyCh to populate the inflightMessages + messageResponse struct { + // This is unique identifier of message + ackID string + // This is the channel created by Publisher for each Publish request + // It is written to by readAcksPump when the ack is received from InputHost on the AckStream + completion chan<- *PublisherReceipt + // This is user specified context to pass through + userContext map[string]string + } + + ackChannelClosedError struct{} +) + +const ( + defaultMaxInflightMessages = 1000 + defaultWGTimeout = time.Minute +) + +func newConnection(client cherami.TChanBIn, wsConnector WSConnector, path string, messages <-chan putMessageRequest, + reconfigureCh chan<- reconfigureInfo, connKey string, protocol cherami.Protocol, + maxInflightMessages int, logger bark.Logger, reporter metrics.Reporter) *connection { + if maxInflightMessages <= 0 { + maxInflightMessages = defaultMaxInflightMessages + } + + conn := &connection{ + inputHostClient: client, + wsConnector: wsConnector, + path: path, + connKey: connKey, + protocol: protocol, + messagesCh: messages, + reconfigureCh: reconfigureCh, + replyCh: make(chan *messageResponse, maxInflightMessages), + shuttingDownCh: make(chan struct{}), + closeCh: make(chan struct{}), + logger: logger, + reporter: reporter, + } + + return conn +} + +func (conn *connection) open() error { + conn.lk.Lock() + defer conn.lk.Unlock() + + if atomic.LoadInt32(&conn.opened) == 0 { + switch conn.protocol { + case cherami.Protocol_WS: + conn.logger.Infof("Using websocket to connect to input host %s", conn.connKey) + wsStream, err := conn.wsConnector.OpenPublisherStream(conn.connKey, http.Header{ + "path": {conn.path}, + }) + if err != nil { + conn.logger.Infof("Error opening websocket connection to input host %s: %v", conn.connKey, err) + return err + } + + conn.inputHostStream = wsStream + conn.cancel = nil + default: + return &cherami.BadRequestError{Message: `Protocol not supported`} + } + + 
conn.readAckPumpWG.Add(1) + go conn.readAcksPump() + conn.writeMsgPumpWG.Add(1) + go conn.writeMessagesPump() + + atomic.StoreInt32(&conn.opened, 1) + conn.logger.Info("Input host connection opened.") + } + + return nil +} + +func (conn *connection) close() { + conn.lk.Lock() + defer conn.lk.Unlock() + + if atomic.LoadInt32(&conn.closed) == 0 { + // First shutdown the write pump to make sure we don't leave any message without ack + close(conn.shuttingDownCh) + if ok := common.AwaitWaitGroup(&conn.writeMsgPumpWG, defaultWGTimeout); !ok { + conn.logger.Warn("writeMsgPumpWG timed out") + } + + // Now shutdown the read pump and drain all inflight messages + close(conn.closeCh) + if ok := common.AwaitWaitGroup(&conn.readAckPumpWG, defaultWGTimeout); !ok { + conn.logger.Warn("readAckPumpWG timed out") + } + + // Both pumps are shutdown. Close the underlying stream. + if conn.cancel != nil { + conn.cancel() + } + + if conn.inputHostStream != nil { + conn.inputHostStream.Done() + } + + atomic.StoreInt32(&conn.closed, 1) + conn.logger.Info("Input host connection closed.") + + // trigger a reconfiguration due to connection closed + select { + case conn.reconfigureCh <- reconfigureInfo{eventType: connClosedReconfigureType, reconfigureID: conn.connKey}: + default: + conn.logger.Info("Reconfigure channel is full. Drop reconfigure command due to connection close.") + } + } +} + +func (conn *connection) writeMessagesPump() { + defer conn.writeMsgPumpWG.Done() + + for { + select { + case pr := <-conn.messagesCh: + sw := conn.reporter.StartTimer(metrics.PublishMessageLatency, nil) + conn.reporter.IncCounter(metrics.PublishMessageRate, nil, 1) + + err := conn.inputHostStream.Write(pr.message) + if err == nil { + // TODO: We don't need to flush on every message. 
Rewrite the method to flush once messagesCh is empty + err = conn.inputHostStream.Flush() + } + sw.Stop() + + if err == nil { + conn.replyCh <- &messageResponse{pr.message.GetID(), pr.messageAck, pr.message.GetUserContext()} + } else { + conn.reporter.IncCounter(metrics.PublishMessageFailedRate, nil, 1) + conn.logger.WithField(common.TagMsgID, common.FmtMsgID(pr.message.GetID())).Infof("Error writing message to stream: %v", err) + + pr.messageAck <- &PublisherReceipt{ + ID: pr.message.GetID(), + Error: err, + UserContext: pr.message.GetUserContext(), + } + + // Write failed, rebuild connection + go conn.close() + } + case <-conn.shuttingDownCh: + // Connection is closed. Bail out of the pump and close stream + return + } + } +} + +func (conn *connection) readAcksPump() { + defer conn.readAckPumpWG.Done() + + inflightMessages := make(map[string]*messageResponse) + // This map is needed when we receive a reply out of order before the inflightMessages is populated + earlyReplyAcks := make(map[string]*PublisherReceipt) + defer failInflightMessages(inflightMessages) + + // Flag which is set when AckStream is closed by InputHost + isEOF := false + for { + conn.reporter.UpdateGauge(metrics.PublishNumInflightMessagess, nil, int64(len(inflightMessages))) + + if isEOF || len(inflightMessages) == 0 { + select { + // We want to make sure that ackId is in the inflightMessages before we read a response for it + case resCh := <-conn.replyCh: + populateInflightMapUtil(inflightMessages, earlyReplyAcks, resCh) + // Connection is closed just fail all inflightMessages + case <-conn.closeCh: + // First drain the replyCh to get all inflight messages which are in the + // buffer. This is essential because we could have some messages in the channel buffer + // which we need to fail as well so that the clients can retry. If not, + // we will get unnecessary timeouts. 
+ DrainLoop: + for { + select { + case resCh, ok := <-conn.replyCh: + if !ok { + break DrainLoop + } + populateInflightMapUtil(inflightMessages, earlyReplyAcks, resCh) + default: + break DrainLoop + } + } + + return + } + } else { + cmd, err := conn.inputHostStream.Read() + if err != nil { + // Error reading from stream. Time to close and bail out. + conn.logger.Infof("Error reading Ack Stream: %v", err) + // Ack stream is closed. Also close the Connection in case this is triggered by InputHost + // Any inflight messages at this point need to be failed + go conn.close() + // AckStream is closed by InputHost. There is no point in calling Read again. + // This flag is used to prevent readAcksPump from calling Read when we know ack stream is closed. + // We still need the pump going to drain the conn.replyCh to populate any inflight messages + isEOF = true + } else { + if cmd.GetType() == cherami.InputHostCommandType_ACK { + conn.reporter.IncCounter(metrics.PublishAckRate, nil, 1) + ack := cmd.Ack + messageResp, exists := inflightMessages[ack.GetID()] + if !exists { + // Probably we received the ack even before the inflightMessages map is populated. + // Let's put it in the earlyReplyAcks map so we can immediately complete the response when seen by this pump.
+ //conn.logger.WithField(common.TagAckID, common.FmtAckID(ack.GetID())).Debug("Received Ack before populating inflight messages.") + earlyReplyAcks[ack.GetID()] = processMessageAck(ack) + } else { + delete(inflightMessages, ack.GetID()) + messageResp.completion <- processMessageAck(ack) + } + } else if cmd.GetType() == cherami.InputHostCommandType_RECONFIGURE { + conn.reporter.IncCounter(metrics.PublishReconfigureRate, nil, 1) + reconfigInfo := cmd.Reconfigure + conn.logger.WithField(common.TagReconfigureID, common.FmtReconfigureID(reconfigInfo.GetUpdateUUID())).Info("Reconfigure command received from InputHost.") + select { + case conn.reconfigureCh <- reconfigureInfo{eventType: reconfigureCmdReconfigureType, reconfigureID: reconfigInfo.GetUpdateUUID()}: + default: + conn.logger.WithField(common.TagReconfigureID, common.FmtReconfigureID(reconfigInfo.GetUpdateUUID())).Warn("Reconfigure channel is full. Drop reconfigure command.") + } + } + } + } + } +} + +func (conn *connection) isOpened() bool { + return atomic.LoadInt32(&conn.opened) != 0 +} + +func (conn *connection) isClosed() bool { + return atomic.LoadInt32(&conn.closed) != 0 +} + +func (e *ackChannelClosedError) Error() string { + return "Ack channel closed." +} + +func processMessageAck(messageAck *cherami.PutMessageAck) *PublisherReceipt { + ret := &PublisherReceipt{ + ID: messageAck.GetID(), + UserContext: messageAck.GetUserContext(), + } + + if messageAck.GetStatus() != cherami.Status_OK { + ret.Error = errors.New(messageAck.GetMessage()) + } else { + ret.Receipt = messageAck.GetReceipt() + } + + return ret +} + +// populateInflightMapUtil is used to populate the inflightMessages Map, +// based on the acks we received as well. 
+func populateInflightMapUtil(inflightMessages map[string]*messageResponse, earlyReplyAcks map[string]*PublisherReceipt, resCh *messageResponse) { + // First check if we have already seen the ack for this ID + if ack, ok := earlyReplyAcks[resCh.ackID]; ok { + // We already received the ack for this msgID. Complete the request immediately. + //conn.logger.WithField(common.TagAckID, common.FmtAckID(resCh.ackID)).Debug("Found ack for this response in earlyReplyAcks map. Completing immediately.") + delete(earlyReplyAcks, resCh.ackID) + resCh.completion <- ack + } else { + // Populate the inflightMessages map so we can complete the response after reading the ack from Cherami + inflightMessages[resCh.ackID] = resCh + } +} + +func failInflightMessages(inflightMessages map[string]*messageResponse) { + for id, messageResp := range inflightMessages { + messageResp.completion <- &PublisherReceipt{ + ID: id, + Error: &ackChannelClosedError{}, + UserContext: messageResp.userContext, + } + } +} diff --git a/client/cherami/connection_test.go b/client/cherami/connection_test.go new file mode 100644 index 0000000..4200a86 --- /dev/null +++ b/client/cherami/connection_test.go @@ -0,0 +1,349 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cherami + +import ( + "errors" + "io" + "testing" + "time" + + _ "fmt" + _ "strconv" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/common" + "github.com/uber/cherami-client-go/common/metrics" + mc "github.com/uber/cherami-client-go/mocks/clients/cherami" + + log "github.com/Sirupsen/logrus" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/uber-common/bark" +) + +type ConnectionSuite struct { + *require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error + suite.Suite +} + +func TestConnectionSuite(t *testing.T) { + suite.Run(t, new(ConnectionSuite)) +} + +func (s *ConnectionSuite) SetupTest() { + s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. 
If we did it earlier, s.T() will return nil +} + +func (s *ConnectionSuite) TestSuccess() { + conn, inputHostClient, messagesCh := createConnection() + + // Setup inputHostClient mock for successful message + inputHostClient.On("Write", mock.Anything).Return(nil) + inputHostClient.On("Flush").Return(nil) + inputHostClient.On("Read").Return(wrapAckInCommand(&cherami.PutMessageAck{ + ID: common.StringPtr("1"), + Status: common.CheramiStatusPtr(cherami.Status_OK), + }), nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + message := &cherami.PutMessage{ + ID: common.StringPtr("1"), + Data: []byte("test"), + } + requestDone := make(chan *PublisherReceipt, 1) + + messagesCh <- putMessageRequest{message, requestDone} + receipt := <-requestDone + + s.Equal(*message.ID, receipt.ID) + s.Nil(receipt.Error, "Expected the message to be written successfully.") + // TODO: mock assert expectations is not working for some reason. + //inputHostClient.AssertExpectations(s.T()) +} + +func (s *ConnectionSuite) TestFailedResponse() { + conn, inputHostClient, messagesCh := createConnection() + + // Setup inputHostClient mock for successful message + inputHostClient.On("Write", mock.Anything).Return(nil) + inputHostClient.On("Flush").Return(nil) + inputHostClient.On("Read").Return(wrapAckInCommand(&cherami.PutMessageAck{ + ID: common.StringPtr("1"), + Status: common.CheramiStatusPtr(cherami.Status_FAILED), + Message: common.StringPtr("Failed"), + }), nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + message := &cherami.PutMessage{ + ID: common.StringPtr("1"), + Data: []byte("test"), + } + requestDone := make(chan *PublisherReceipt, 1) + + messagesCh <- putMessageRequest{message, requestDone} + receipt := <-requestDone + + s.Equal(*message.ID, receipt.ID) + s.NotNil(receipt.Error) + s.Equal("Failed", receipt.Error.Error()) + // TODO: mock assert expectations is not working for some reason. 
+ //inputHostClient.AssertExpectations(s.T()) +} + +func (s *ConnectionSuite) TestWriteFailed() { + conn, inputHostClient, messagesCh := createConnection() + + // Setup inputHostClient mock for successful message + inputHostClient.On("Write", mock.Anything).Return(errors.New("Failed")) + inputHostClient.On("Done").Return(nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + message := &cherami.PutMessage{ + ID: common.StringPtr("1"), + Data: []byte("test"), + } + requestDone := make(chan *PublisherReceipt, 1) + + messagesCh <- putMessageRequest{message, requestDone} + receipt := <-requestDone + + s.Equal(*message.ID, receipt.ID) + s.NotNil(receipt.Error) + s.Equal("Failed", receipt.Error.Error()) + // TODO: mock assert expectations is not working for some reason. + //inputHostClient.AssertExpectations(s.T()) +} + +func (s *ConnectionSuite) TestFlushFailed() { + conn, inputHostClient, messagesCh := createConnection() + + // Setup inputHostClient mock for successful message + inputHostClient.On("Write", mock.Anything).Return(nil) + inputHostClient.On("Flush").Return(errors.New("Failed")) + inputHostClient.On("Done", mock.Anything).Return(nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + message := &cherami.PutMessage{ + ID: common.StringPtr("1"), + Data: []byte("test"), + } + requestDone := make(chan *PublisherReceipt, 1) + + messagesCh <- putMessageRequest{message, requestDone} + receipt := <-requestDone + + s.Equal(*message.ID, receipt.ID) + s.NotNil(receipt.Error) + s.Equal("Failed", receipt.Error.Error()) + // TODO: mock assert expectations is not working for some reason. 
+ //inputHostClient.AssertExpectations(s.T()) +} + +func (s *ConnectionSuite) TestAckClosedByInputHost() { + conn, inputHostClient, messagesCh := createConnection() + + // Setup inputHostClient mock for successful message + inputHostClient.On("Write", mock.Anything).Return(nil) + inputHostClient.On("Flush").Return(nil) + inputHostClient.On("Read").Return(nil, io.EOF) + inputHostClient.On("Done").Return(nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + message := &cherami.PutMessage{ + ID: common.StringPtr("1"), + Data: []byte("test"), + } + requestDone := make(chan *PublisherReceipt, 1) + + messagesCh <- putMessageRequest{message, requestDone} + receipt := <-requestDone + + s.Equal(*message.ID, receipt.ID) + s.NotNil(receipt.Error) + s.IsType(&ackChannelClosedError{}, receipt.Error) + // TODO: mock assert expectations is not working for some reason. + //inputHostClient.AssertExpectations(s.T()) +} + +func (s *ConnectionSuite) TestClientClosed() { + conn, inputHostClient, messagesCh := createConnection() + + // Setup inputHostClient mock for successful message + inputHostClient.On("Write", mock.Anything).Return(nil) + inputHostClient.On("Flush").Return(nil) + inputHostClient.On("Read").Return(wrapAckInCommand(&cherami.PutMessageAck{ + ID: common.StringPtr("1"), + Status: common.CheramiStatusPtr(cherami.Status_OK), + }), nil).After(10 * time.Millisecond).WaitUntil(time.After(100 * time.Millisecond)) + inputHostClient.On("Done").Return(nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + message := &cherami.PutMessage{ + ID: common.StringPtr("1"), + Data: []byte("test"), + } + requestDone := make(chan *PublisherReceipt, 1) + + messagesCh <- putMessageRequest{message, requestDone} + <-time.After(10 * time.Millisecond) + conn.close() + receipt := <-requestDone + + s.Equal(*message.ID, receipt.ID) + s.Nil(receipt.Error, "Expected the message to be written successfully.") + s.True(conn.isClosed()) + // TODO: mock assert 
expectations is not working for some reason. + //inputHostClient.AssertExpectations(s.T()) +} + +func (s *ConnectionSuite) TestOutOfOrderAcks() { + conn, inputHostClient, messagesCh := createConnection() + + // Setup inputHostClient mock for successful message + inputHostClient.On("Write", mock.Anything).Return(nil) + inputHostClient.On("Flush").Return(nil) + inputHostClient.On("Read").Return(wrapAckInCommand(&cherami.PutMessageAck{ + ID: common.StringPtr("2"), + Status: common.CheramiStatusPtr(cherami.Status_OK), + }), nil).Once() + inputHostClient.On("Read").Return(wrapAckInCommand(&cherami.PutMessageAck{ + ID: common.StringPtr("1"), + Status: common.CheramiStatusPtr(cherami.Status_OK), + }), nil).Once() + inputHostClient.On("Done").Return(nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + message1 := &cherami.PutMessage{ + ID: common.StringPtr("1"), + Data: []byte("test"), + } + requestDone1 := make(chan *PublisherReceipt, 1) + messagesCh <- putMessageRequest{message1, requestDone1} + + message2 := &cherami.PutMessage{ + ID: common.StringPtr("2"), + Data: []byte("test"), + } + requestDone2 := make(chan *PublisherReceipt, 1) + messagesCh <- putMessageRequest{message2, requestDone2} + + receipt2 := <-requestDone2 + s.Equal(*message2.ID, receipt2.ID) + + receipt1 := <-requestDone1 + s.Equal(*message1.ID, receipt1.ID) + + conn.close() + s.True(conn.isClosed()) + // TODO: mock assert expectations is not working for some reason. 
+ //inputHostClient.AssertExpectations(s.T()) +} + +// TODO: Figure out a way to test multiple messages with mocks +/*func (s *ConnectionSuite) TestManySuccess() { + numberOfMessages := 10 + conn, inputHostClient, messagesCh := createConnection() + + w := make(chan time.Time) + // Setup inputHostClient mock for successful message + inputHostClient.On("Write", mock.Anything).Return(nil).Times(numberOfMessages) + inputHostClient.On("Flush").Return(nil).Times(numberOfMessages) + for i := 0; i < numberOfMessages; i++ { + inputHostClient.On("Read").Return(&cherami.PutMessageAck{ + ID: common.StringPtr(strconv.Itoa(i)), + Status: common.CheramiStatusPtr(cherami.Status_OK), + }, nil).WaitUntil(w).Once() +//.After(3 * time.Second).Once() + } + + conn.Open() + s.True(conn.opened, "Connection not opened.") + + done := make(chan bool, numberOfMessages) + for i := 0; i < numberOfMessages; i++ { + go func(id int) { + fmt.Printf("Writing message with id: %d\n", id) + message := &cherami.PutMessage{ + ID: strconv.Itoa(id), + Data: []byte("test"), + } + + requestDone := make(chan error, 1) + + messagesCh <- PutMessageRequest{message, requestDone} + w <- time.Now() + + err := <-requestDone + fmt.Printf("Got response for message id: %d\n", id) + s.Nil(err, "Expected the message to be written successfully.") + + fmt.Printf("Go Routine for id '%d' setting done.\n", id) + done<- true + }(i) + } + + for i := 0; i < 10; i++ { + fmt.Printf("Waiting for message: %d\n", i) + <-done + fmt.Printf("Received message: %d\n", i) + } + + // TODO: mock assert expectations is not working for some reason. 
+ //inputHostClient.AssertExpectations(s.T()) +}*/ + +func createConnection() (*connection, *mc.MockBInOpenPublisherStreamOutCall, chan putMessageRequest) { + host := "testHost" + messagesCh := make(chan putMessageRequest) + reconfigureCh := make(chan reconfigureInfo, 10) + inputHostClient := new(mc.MockTChanBInClient) + wsConnector := new(mc.MockWSConnector) + stream := new(mc.MockBInOpenPublisherStreamOutCall) + + inputHostClient.On("OpenPublisherStream", mock.Anything).Return(stream, nil) + wsConnector.On("OpenPublisherStream", mock.Anything, mock.Anything).Return(stream, nil) + + return newConnection(inputHostClient, wsConnector, "/test/inputhostconnection", messagesCh, reconfigureCh, host, cherami.Protocol_WS, 0, bark.NewLoggerFromLogrus(log.StandardLogger()), metrics.NewNullReporter()), stream, messagesCh +} + +func wrapAckInCommand(ack *cherami.PutMessageAck) *cherami.InputHostCommand { + cmd := cherami.NewInputHostCommand() + cmd.Type = common.CheramiInputHostCommandTypePtr(cherami.InputHostCommandType_ACK) + cmd.Ack = ack + + return cmd +} diff --git a/client/cherami/consumer.go b/client/cherami/consumer.go new file mode 100644 index 0000000..0dedaf6 --- /dev/null +++ b/client/cherami/consumer.go @@ -0,0 +1,413 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cherami + +import ( + "fmt" + "strings" + "sync" + "sync/atomic" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/common" + "github.com/uber/cherami-client-go/common/metrics" + "github.com/uber/tchannel-go" + "github.com/uber/tchannel-go/thrift" + + "github.com/pborman/uuid" + "github.com/uber-common/bark" +) + +type ( + consumerImpl struct { + path string + consumerGroupName string + consumerName string + streamConnection *tchannel.Channel + ackConnection *tchannel.Channel + prefetchSize int + options *ClientOptions + client *clientImpl + deliveryCh chan Delivery + reconfigureCh chan reconfigureInfo + closingCh chan struct{} + isClosing int32 + logger bark.Logger + reporter metrics.Reporter + + lk sync.Mutex + opened bool + connections map[string]*outputHostConnection + wsConnector WSConnector + reconfigurable *reconfigurable + } + + deliveryID struct { + // AcknowledgerID is the identifier for underlying acknowledger which can be used to Ack/Nack this delivery + AcknowledgerID string + // MessageAckID is the Ack identifier for this delivery + MessageAckID string + } +) + +func newConsumer(client *clientImpl, path, consumerGroupName, consumerName string, prefetchSize int, options *ClientOptions) Consumer { + consumer := &consumerImpl{ + client: client, + options: options, + path: path, + consumerGroupName: consumerGroupName, + consumerName: consumerName, + prefetchSize: prefetchSize, + reconfigureCh: make(chan
reconfigureInfo, reconfigureChBufferSize), + closingCh: make(chan struct{}), + isClosing: 0, + logger: client.options.Logger.WithFields(bark.Fields{common.TagDstPth: common.FmtDstPth(path), common.TagCnsPth: common.FmtCnsPth(consumerGroupName)}), + reporter: client.options.MetricsReporter, + wsConnector: NewWSConnector(), + } + + return consumer +} + +// This token is generated in delivery.GetDeliveryToken. Make sure to keep both implementations in sync to +// serialize/deserialize these tokens +func newDeliveryID(token string) (deliveryID, error) { + parts := strings.Split(token, deliveryTokenSplitter) + if len(parts) != 2 { + return deliveryID{}, fmt.Errorf("Invalid delivery token: %v", token) + } + + return deliveryID{ + AcknowledgerID: parts[0], + MessageAckID: parts[1], + }, nil +} + +func (c *consumerImpl) Open(deliveryCh chan Delivery) (chan Delivery, error) { + c.lk.Lock() + defer c.lk.Unlock() + + if !c.opened { + c.deliveryCh = deliveryCh + consumerOptions, err := c.client.ReadConsumerGroupHosts(c.path, c.consumerGroupName) + if err != nil { + c.logger.Errorf("Error resolving output hosts: %v", err) + c.deliveryCh = nil + return nil, err + } + + streamCh, err := tchannel.NewChannel(uuid.New(), nil) + if err != nil { + return nil, err + } + c.streamConnection = streamCh + + // Create a separate connection for ack messages call + ackCh, err := tchannel.NewChannel(uuid.New(), nil) + if err != nil { + return nil, err + } + c.ackConnection = ackCh + + // pick best protocol from what server suggested + hostProtocols := consumerOptions.GetHostProtocols() + chosenIdx, tchanProtocolIdx, err := c.chooseProcotol(hostProtocols) + chosenProtocol := cherami.Protocol_WS + chosenHostAddresses := consumerOptions.GetHostAddresses() + tchanHostAddresses := consumerOptions.GetHostAddresses() + if err == nil { + chosenProtocol = hostProtocols[chosenIdx].GetProtocol() + chosenHostAddresses = hostProtocols[chosenIdx].GetHostAddresses() + tchanHostAddresses = 
hostProtocols[tchanProtocolIdx].GetHostAddresses() + } + + c.connections = make(map[string]*outputHostConnection) + for idx, host := range chosenHostAddresses { + connKey := common.GetConnectionKey(host) + // ReadConsumerGroupHosts can return duplicates, so we need to dedupe to make sure we create a single connection for each host + if _, ok := c.connections[connKey]; !ok { + conn, err := c.createOutputHostConnection(common.GetConnectionKey(tchanHostAddresses[idx]), connKey, chosenProtocol) + if err != nil { + if conn != nil { + closeConnection(conn) + } + + c.logger.Errorf("Error opening outputhost connection on %v:%v: %v", host.GetHost(), host.GetPort(), err) + // TODO: We should be returning failure only when no connections could be opened to OutputHost + c.deliveryCh = nil + return nil, err + } + + // Add the connection to map used by Ack/Nack API to find the connection using DeliveryId + c.connections[conn.GetAcknowledgerID()] = conn + } + } + c.reporter.UpdateGauge(metrics.ConsumeNumConnections, nil, int64(len(c.connections))) + + c.reconfigurable = newReconfigurable(c.reconfigureCh, c.closingCh, c.reconfigureConsumer, c.logger) + go c.reconfigurable.reconfigurePump() + + c.opened = true + } + + return c.deliveryCh, nil +} + +func (c *consumerImpl) Close() { + // TODO: ideally this should be synchronized, i.e. wait until all + // connections are properly shutdown, so that we make sure that + // nothing gets written to c.deliveryCh afterwards, because owner of that + // channel would likely close it after Close() returns.
+ if atomic.CompareAndSwapInt32(&c.isClosing, 0, 1) { + close(c.closingCh) + } else { + return + } + + c.lk.Lock() + defer c.lk.Unlock() + if c.connections != nil { + for _, outputHostConn := range c.connections { + closeConnection(outputHostConn) + } + c.reporter.UpdateGauge(metrics.ConsumeNumConnections, nil, 0) + } + + if c.streamConnection != nil { + c.streamConnection.Close() + } + if c.ackConnection != nil { + c.ackConnection.Close() + } + + c.opened = false +} + +func (c *consumerImpl) AckDelivery(token string) error { + acknowledger, id, err := c.getAcknowledger(token) + if err != nil { + return err + } + + return acknowledger.Ack([]string{id.MessageAckID}) +} + +func (c *consumerImpl) NackDelivery(token string) error { + acknowledger, id, err := c.getAcknowledger(token) + if err != nil { + return err + } + + return acknowledger.Nack([]string{id.MessageAckID}) +} + +func (c *consumerImpl) reconfigureConsumer() { + c.lk.Lock() + defer c.lk.Unlock() + + select { + case <-c.closingCh: + c.logger.Info("Consumer is closing. Ignore reconfiguration.") + default: + var conn *outputHostConnection + + consumerOptions, err := c.client.ReadConsumerGroupHosts(c.path, c.consumerGroupName) + if err != nil { + c.logger.Warnf("Error resolving output hosts: %v", err) + if _, ok := err.(*cherami.EntityNotExistsError); ok { + // ConsumerGroup is deleted. Continue with reconfigure and close all connections + consumerOptions = &cherami.ReadConsumerGroupHostsResult_{} + } else { + // This is potentially a transient error.
+ // Retry on next reconfigure + return + } + } + + // pick best protocol from what server suggested + // note: tchannel is a must-have here since AckMessage is still using tchannel (non-streaming) + // tchanProtocol stores tchannel hosts, same order as other protocols + hostProtocols := consumerOptions.GetHostProtocols() + chosenIdx, tchanProtocolIdx, err := c.chooseProcotol(hostProtocols) + chosenProtocol := cherami.Protocol_WS + chosenHostAddresses := consumerOptions.GetHostAddresses() + tchanHostAddresses := consumerOptions.GetHostAddresses() + if err == nil { + chosenProtocol = hostProtocols[chosenIdx].GetProtocol() + chosenHostAddresses = hostProtocols[chosenIdx].GetHostAddresses() + tchanHostAddresses = hostProtocols[tchanProtocolIdx].GetHostAddresses() + } + + // First remove any closed connections from the connections map + for existingConnKey, existingConn := range c.connections { + if existingConn.isClosed() { + c.logger.WithField(common.TagHostIP, common.FmtHostIP(existingConnKey)).Info("Removing closed connection from cache.") + closeConnection(existingConn) + delete(c.connections, existingConnKey) + c.logger.WithField(common.TagHostIP, common.FmtHostIP(existingConn.connKey)).Info("Removed connection from cache.") + } + } + + currentHosts := make(map[string]*outputHostConnection) + for idx, host := range chosenHostAddresses { + connKey := common.GetConnectionKey(host) + conn = c.connections[connKey] + if conn == nil || conn.isClosed() { + // Newly assigned host, create a connection + connLogger := c.logger.WithField(common.TagHostIP, common.FmtHostIP(connKey)) + connLogger.Info("Discovered new OutputHost during reconfiguration.") + conn, err = c.createOutputHostConnection(common.GetConnectionKey(tchanHostAddresses[idx]), connKey, chosenProtocol) + if err != nil { + connLogger.Info("Error creating connection to OutputHost after reconfiguration.") + if conn != nil { + closeConnection(conn) + } + } else { + connLogger.Info("Successfully created connection
to OutputHost after reconfiguration.") + // Successfully created a connection to new host. Add it to current list of output hosts + currentHosts[conn.GetAcknowledgerID()] = conn + } + } else { + // Existing output host connection, copy it over to current collection of output hosts + currentHosts[conn.GetAcknowledgerID()] = conn + } + } + + // Now close all remaining output host connections that are not in the current list + for host, outputHostConn := range c.connections { + if _, ok := currentHosts[host]; !ok { + connLogger := c.logger.WithField(common.TagHostIP, common.FmtHostIP(outputHostConn.connKey)) + connLogger.Info("Closing connection to OutputHost after reconfiguration.") + closeConnection(outputHostConn) + } + } + + connectedHosts := make([]string, 0, len(currentHosts)) // zero length with reserved capacity: append must not leave empty leading entries + for k := range currentHosts { + connectedHosts = append(connectedHosts, k) + } + c.logger.WithField(common.TagHosts, connectedHosts).Debug("List of connected output hosts.") + // Now assign the list of current output hosts to list of connections + c.connections = currentHosts + c.reporter.UpdateGauge(metrics.ConsumeNumConnections, nil, int64(len(c.connections))) + } +} + +func (c *consumerImpl) createOutputHostConnection(tchanHostPort string, connKey string, protocol cherami.Protocol) (*outputHostConnection, error) { + connLogger := c.logger.WithField(common.TagHostIP, common.FmtHostIP(connKey)) + + // TODO [ljj] to be removed once moved to websocket + client, err := createOutputHostClient(c.streamConnection, connKey) + if err != nil { + connLogger.Infof("Error creating OutputHost client: %v", err) + return nil, err + } + + // We use a separate connection for acks to make sure response for acks won't get blocked behind streaming messages + ackClient, err := createOutputHostClient(c.ackConnection, tchanHostPort) + if err != nil { + connLogger.Infof("Error creating AckClient: %v", err) + return nil, err + } + + conn := newOutputHostConnection(client, ackClient, c.wsConnector, c.path, c.consumerGroupName, c.options,
c.deliveryCh, + c.reconfigureCh, connKey, protocol, int32(c.prefetchSize), connLogger, c.reporter) + + // Now open the connection + err = conn.open() + if err != nil { + connLogger.Infof("Error opening OutputHost connection: %v", err) + return nil, err + } + + return conn, nil +} + +func (c *consumerImpl) getAcknowledger(token string) (*outputHostConnection, deliveryID, error) { + id, err := newDeliveryID(token) + if err != nil { + return nil, deliveryID{}, err + } + + c.lk.Lock() + acknowledger, ok := c.connections[id.AcknowledgerID] + c.lk.Unlock() + if !ok { + return nil, id, fmt.Errorf("Cannot Ack/Nack message '%s'. Acknowledger Id '%s' not found.", + id.MessageAckID, id.AcknowledgerID) + } + + return acknowledger, id, nil +} + +func (c *consumerImpl) chooseProcotol(hostProtocols []*cherami.HostProtocol) (int, int, error) { + clientSupportedProtocol := map[cherami.Protocol]bool{cherami.Protocol_WS: true} + clientSupportButDeprecated := -1 + serverSupportedProtocol := make([]cherami.Protocol, 0, len(hostProtocols)) + + // tchannel is needed for ack message + tchanProtocolIdx := -1 + for idx, hostProtocol := range hostProtocols { + if hostProtocol.GetProtocol() == cherami.Protocol_TCHANNEL { + tchanProtocolIdx = idx + break + } + } + if tchanProtocolIdx == -1 { + return -1, -1, &cherami.BadRequestError{Message: `TChannel is needed for client to ack message`} + } + + for idx, hostProtocol := range hostProtocols { + serverSupportedProtocol = append(serverSupportedProtocol, hostProtocol.GetProtocol()) + if _, found := clientSupportedProtocol[hostProtocol.GetProtocol()]; found { + if !hostProtocol.GetDeprecated() { + // found first supported and non-deprecated one, done + return idx, tchanProtocolIdx, nil + } else if clientSupportButDeprecated == -1 { + // found first supported but deprecated one, keep looking + clientSupportButDeprecated = idx + } + } + } + + if clientSupportButDeprecated == -1 { + c.logger.WithField(`protocols`, serverSupportedProtocol).Error("No 
protocol is supported by client") + return -1, -1, &cherami.BadRequestError{Message: `No protocol is supported by client`} + } + + c.logger.WithField(`protocol`, hostProtocols[clientSupportButDeprecated].GetProtocol()).Warn("Client using deprecated protocol") + return clientSupportButDeprecated, tchanProtocolIdx, nil +} + +func createOutputHostClient(ch *tchannel.Channel, hostPort string) (cherami.TChanBOut, error) { + tClient := thrift.NewClient(ch, common.OutputServiceName, &thrift.ClientOptions{ + HostPort: hostPort, + }) + client := cherami.NewTChanBOutClient(tClient) + + return client, nil +} + +func closeConnection(conn *outputHostConnection) { + conn.close() + // This is necessary to shutdown writeAcksPump within the connection + conn.closeAcksBatchCh() +} diff --git a/client/cherami/delivery.go b/client/cherami/delivery.go new file mode 100644 index 0000000..3c22113 --- /dev/null +++ b/client/cherami/delivery.go @@ -0,0 +1,91 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package cherami
+
+import (
+	"bytes"
+	"crypto/md5"
+	"hash/crc32"
+	"strings"
+
+	"github.com/uber/cherami-thrift/.generated/go/cherami"
+)
+
+type (
+	// deliveryImpl pairs a consumed message with the acknowledger used to
+	// ack or nack it.
+	deliveryImpl struct {
+		message      *cherami.ConsumerMessage
+		acknowledger acknowledger
+	}
+
+	// acknowledger can be used to Ack/Nack messages from a specific OutputHost
+	acknowledger interface {
+		GetAcknowledgerID() string
+		Ack(ids []string) error
+		Nack(ids []string) error
+	}
+)
+
+const (
+	// deliveryTokenSplitter separates the acknowledger ID from the ack ID
+	// inside a delivery token.
+	deliveryTokenSplitter = "|"
+)
+
+// newDelivery wraps msg and its acknowledger into a Delivery.
+func newDelivery(msg *cherami.ConsumerMessage, acknowledger acknowledger) Delivery {
+	d := new(deliveryImpl)
+	d.message = msg
+	d.acknowledger = acknowledger
+	return d
+}
+
+// GetMessage returns the wrapped consumer message.
+func (d *deliveryImpl) GetMessage() *cherami.ConsumerMessage {
+	return d.message
+}
+
+// GetDeliveryToken returns "<acknowledger ID>|<ack ID>".
+// This token is parsed in consumer.newDeliveryID
+// Make sure to keep both implementations in sync to serialize/deserialize these tokens
+func (d *deliveryImpl) GetDeliveryToken() string {
+	parts := []string{
+		d.acknowledger.GetAcknowledgerID(),
+		d.message.GetAckId(),
+	}
+	return strings.Join(parts, deliveryTokenSplitter)
+}
+
+// Ack acks this delivery's message through its acknowledger.
+func (d *deliveryImpl) Ack() error {
+	ackIDs := []string{d.message.GetAckId()}
+	return d.acknowledger.Ack(ackIDs)
+}
+
+// Nack nacks this delivery's message through its acknowledger.
+func (d *deliveryImpl) Nack() error {
+	ackIDs := []string{d.message.GetAckId()}
+	return d.acknowledger.Nack(ackIDs)
+}
+
+// VerifyChecksum checks the payload data against the checksum carried with
+// the message. CRC32-IEEE is preferred when set, then MD5. It returns false
+// when the message has no payload and true when no known checksum is set.
+func (d *deliveryImpl) VerifyChecksum() bool {
+	if !d.message.IsSetPayload() {
+		return false
+	}
+
+	payload := d.message.GetPayload()
+	switch {
+	case payload.IsSetCrc32IEEEDataChecksum():
+		computed := int64(crc32.ChecksumIEEE(payload.GetData()))
+		return computed == payload.GetCrc32IEEEDataChecksum()
+	case payload.IsSetMd5DataChecksum():
+		computed := md5.Sum(payload.GetData())
+		return bytes.Equal(computed[:], payload.GetMd5DataChecksum())
+	default:
+		// no known checksum provided, just pass
+		return true
+	}
+}
diff --git a/client/cherami/delivery_test.go b/client/cherami/delivery_test.go
new file mode 100644
index 0000000..bdcae9f
--- /dev/null
+++ b/client/cherami/delivery_test.go
@@ -0,0 +1,96 @@
+// Copyright (c) 2016 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package cherami
+
+import (
+	"crypto/md5"
+	"hash/crc32"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"github.com/stretchr/testify/suite"
+	"github.com/uber/cherami-client-go/common"
+	"github.com/uber/cherami-thrift/.generated/go/cherami"
+)
+
+// DeliverySuite exercises Delivery.VerifyChecksum.
+type DeliverySuite struct {
+	*require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error
+	suite.Suite
+}
+
+func TestDeliverySuite(t *testing.T) {
+	suite.Run(t, new(DeliverySuite))
+}
+
+func (s *DeliverySuite) SetupTest() {
+	s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil
+}
+
+// A matching CRC32-IEEE checksum must verify.
+func (s *DeliverySuite) TestVerifyChecksumCrc32IEEE() {
+	delivery := newTestDelivery()
+	// 3957769958
+	delivery.GetMessage().GetPayload().Crc32IEEEDataChecksum = common.Int64Ptr(int64(crc32.ChecksumIEEE(delivery.GetMessage().GetPayload().GetData())))
+	s.True(delivery.VerifyChecksum(), "Crc32IEEE checksum verification failed")
+}
+
+// A mismatching CRC32-IEEE checksum must be rejected.
+func (s *DeliverySuite) TestVerifyChecksumCrc32IEEEFail() {
+	delivery := newTestDelivery()
+	delivery.GetMessage().GetPayload().Crc32IEEEDataChecksum = common.Int64Ptr(int64(123))
+	s.False(delivery.VerifyChecksum(), "Crc32IEEE checksum verification should fail on mismatch")
+}
+
+// A matching MD5 checksum must verify.
+func (s *DeliverySuite) TestVerifyChecksumMd5() {
+	delivery := newTestDelivery()
+	// 6CD3556DEB0DA54BCA060B4C39479839
+	md5Checksum := md5.Sum(delivery.GetMessage().GetPayload().GetData())
+	delivery.GetMessage().GetPayload().Md5DataChecksum = md5Checksum[:]
+	s.True(delivery.VerifyChecksum(), "Md5 checksum verification failed")
+}
+
+// A mismatching MD5 checksum must be rejected.
+func (s *DeliverySuite) TestVerifyChecksumMd5Fail() {
+	delivery := newTestDelivery()
+	delivery.GetMessage().GetPayload().Md5DataChecksum = []byte("0123456789ABCDEF")
+	s.False(delivery.VerifyChecksum(), "Md5 checksum verification should fail on mismatch")
+}
+
+// With no checksum set, verification passes.
+func (s *DeliverySuite) TestVerifyChecksumNone() {
+	delivery := newTestDelivery()
+	s.True(delivery.VerifyChecksum(), "None checksum verification should just pass")
+}
+
+func (s *DeliverySuite) TestVerifyChecksumMultiple() {
+	delivery := newTestDelivery()
+	// checksums are verified in a fixed order
+	// since the crc32 IEEE checksum is verified first, another invalid checksum won't matter
+	delivery.GetMessage().GetPayload().Crc32IEEEDataChecksum = common.Int64Ptr(int64(crc32.ChecksumIEEE(delivery.GetMessage().GetPayload().GetData())))
+	delivery.GetMessage().GetPayload().Md5DataChecksum = []byte("123")
+	s.True(delivery.VerifyChecksum(), "Crc32IEEE checksum verification should result in pass")
+}
+
+// newTestDelivery returns a Delivery carrying the payload "Hello, world!".
+// Its acknowledger is nil, so only methods that do not ack/nack may be called.
+func newTestDelivery() Delivery {
+	putMessage := cherami.NewPutMessage()
+	putMessage.Data = []byte("Hello, world!")
+	consumerMessage := cherami.NewConsumerMessage()
+	consumerMessage.Payload = putMessage
+	return newDelivery(consumerMessage, nil)
+}
diff --git a/client/cherami/interfaces.go b/client/cherami/interfaces.go
new file mode 100644
index 0000000..f71c464
--- /dev/null
+++ b/client/cherami/interfaces.go
@@ -0,0 +1,249 @@
+// Copyright (c) 2016 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package cherami
+
+import (
+	"time"
+
+	"github.com/uber/cherami-thrift/.generated/go/cherami"
+	"github.com/uber/cherami-client-go/common/metrics"
+
+	"github.com/uber-common/bark"
+)
+
+type (
+	// Client exposes API for destination and consumer group CRUD and capability to publish and consume messages
+	Client interface {
+		Close()
+		CreateConsumerGroup(request *cherami.CreateConsumerGroupRequest) (*cherami.ConsumerGroupDescription, error)
+		CreateDestination(request *cherami.CreateDestinationRequest) (*cherami.DestinationDescription, error)
+		CreateConsumer(request *CreateConsumerRequest) Consumer
+		CreatePublisher(request *CreatePublisherRequest) Publisher
+		DeleteConsumerGroup(request *cherami.DeleteConsumerGroupRequest) error
+		DeleteDestination(request *cherami.DeleteDestinationRequest) error
+		ListConsumerGroups(request *cherami.ListConsumerGroupRequest) (*cherami.ListConsumerGroupResult_, error)
+		ListDestinations(request *cherami.ListDestinationsRequest) (*cherami.ListDestinationsResult_, error)
+		ReadConsumerGroup(request *cherami.ReadConsumerGroupRequest) (*cherami.ConsumerGroupDescription, error)
+		ReadDestination(request *cherami.ReadDestinationRequest) (*cherami.DestinationDescription, error)
+		UpdateConsumerGroup(request *cherami.UpdateConsumerGroupRequest) (*cherami.ConsumerGroupDescription, error)
+		UpdateDestination(request *cherami.UpdateDestinationRequest) (*cherami.DestinationDescription, error)
+		GetQueueDepthInfo(request *cherami.GetQueueDepthInfoRequest) (*cherami.GetQueueDepthInfoResult_, error)
+		MergeDLQForConsumerGroup(request *cherami.MergeDLQForConsumerGroupRequest) error
+		PurgeDLQForConsumerGroup(request *cherami.PurgeDLQForConsumerGroupRequest) error
+		ReadPublisherOptions(path string) (*cherami.ReadPublisherOptionsResult_, error)
+		ReadConsumerGroupHosts(path string, consumerGroupName string) (*cherami.ReadConsumerGroupHostsResult_, error)
+	}
+
+	// Publisher is used by an application to publish messages to Cherami service
+	Publisher interface {
+		Open() error
+		Close()
+		Publish(message *PublisherMessage) *PublisherReceipt
+		PublishAsync(message *PublisherMessage, done chan<- *PublisherReceipt) (string, error)
+	}
+
+	// PublisherMessage is a struct that wraps the message payload and a delay time duration.
+	PublisherMessage struct {
+		Data  []byte
+		Delay time.Duration
+		// UserContext is user specified context to pass through
+		UserContext map[string]string
+	}
+
+	// PublisherReceipt is a token given to the publisher as proof that the
+	// message has been durably stored.
+	PublisherReceipt struct {
+		// ID is the message id passed with message when published
+		ID string
+		// Receipt is a token that contains info where the message is stored
+		Receipt string
+		// Error is the error, if any, associated with the publishing of this message
+		Error error
+		// UserContext is user specified context to pass through
+		UserContext map[string]string
+	}
+
+	// Consumer is used by an application to receive messages from Cherami service
+	Consumer interface {
+		// Open will connect to Cherami nodes and start delivering messages to
+		// a provided Delivery channel for registered consumer group.
+		//
+		// It is ADVISED that deliveryCh's buffer size should be bigger than the
+		// total PrefetchCount in CreateConsumerRequest of the consumers writing
+		// to this channel.
+		Open(deliveryCh chan Delivery) (chan Delivery, error)
+		// Close closes all the connections to Cherami nodes for this consumer
+		Close()
+		// AckDelivery can be used by application to Ack a message so it is not delivered to any other consumer
+		AckDelivery(deliveryToken string) error
+		// NackDelivery can be used by application to Nack a message so it can be delivered to another consumer immediately
+		// without waiting for the timeout to expire
+		NackDelivery(deliveryToken string) error
+	}
+
+	// Delivery is the container which has the actual message returned by Cherami
+	Delivery interface {
+		// Returns the message returned by Cherami
+		GetMessage() *cherami.ConsumerMessage
+		// Returns a delivery token which can be used to Ack/Nack delivery using the Consumer API
+		// Consumer has 2 options to Ack/Nack a delivery:
+		// 1) Simply call the Ack/Nack API on the delivery after processing the message
+		// 2) If the consumer wants to forward the message to downstream component for processing then they can get the
+		// DeliveryToken by calling this function and pass it along. Later the downstream component can call the
+		// API on the Consumer with this token to Ack/Nack the message.
+		GetDeliveryToken() string
+		// Acks this delivery
+		Ack() error
+		// Nacks this delivery
+		Nack() error
+		// VerifyChecksum verifies the checksum of the message if one exists
+		// Consumer needs to perform this verification and decide what to do based on returned result
+		VerifyChecksum() bool
+	}
+
+	// CreatePublisherRequest struct used to call Client.CreatePublisher to create an object used by application to publish messages
+	CreatePublisherRequest struct {
+		Path                             string
+		MaxInflightMessagesPerConnection int
+		// PublisherType represents the mode in which
+		// publishing should be done i.e. either through
+		// websocket streaming or through tchannel batch API
+		// Defaults to websocket streaming. Choose non-streaming
+		// batch API for low throughput publishing.
+		PublisherType PublisherType
+	}
+
+	// CreateConsumerRequest struct is used to call Client.CreateConsumer to create an object used by application to
+	// consume messages
+	CreateConsumerRequest struct {
+		// Path to destination consumer wants to consume messages from
+		Path string
+		// ConsumerGroupName registered with Cherami for a particular destination
+		ConsumerGroupName string
+		// Name of consumer (worker) connecting to Cherami
+		ConsumerName string
+		// Number of messages to buffer locally. Clients which process messages very fast may want to specify larger value
+		// for PrefetchCount for faster throughput. On the flip side larger values for PrefetchCount will result in
+		// more messages being buffered locally, causing a high memory footprint
+		PrefetchCount int
+		// Options used for making API calls to Cherami services
+		Options *ClientOptions
+	}
+
+	// PublisherType represents the type of publisher viz. streaming/non-streaming
+	PublisherType int
+
+	// ClientOptions used by Cherami client
+	ClientOptions struct {
+		Timeout time.Duration
+		// DeploymentStr specifies which deployment(staging,prod,dev,etc) the client should connect to
+		// If the string is empty, client will connect to prod
+		// If the string is 'prod', client will connect to prod
+		// If the string is 'staging' or 'staging2', client will connect to staging or staging2
+		// If the string is 'dev', client will connect to dev server
+		DeploymentStr string
+		// MetricsReporter is the reporter object
+		MetricsReporter metrics.Reporter
+		// Logger is the logger object
+		Logger bark.Logger
+	}
+
+	// Task represents the task queued in Cherami
+	Task interface {
+		// GetType returns the unique type name that can be used to identify the corresponding task handler
+		GetType() string
+		// GetID returns the unique identifier of this specific task
+		GetID() string
+		// GetValue deserializes task value into given struct that matches the type used to publish the task
+		GetValue(instance interface{}) error
+		// GetContext returns the key value pairs context associated with the task when published
+		GetContext() map[string]string
+	}
+
+	// TaskFunc is function signature of task handler
+	TaskFunc func(task Task) error
+
+	// TaskScheduler is used to put tasks into Cherami
+	TaskScheduler interface {
+		// Open gets TaskScheduler for scheduling tasks
+		Open() error
+		// Close makes sure resources are released
+		Close()
+		// ScheduleTask enqueues a task
+		ScheduleTask(request *ScheduleTaskRequest) error
+	}
+
+	// TaskExecutor is used to pull tasks from Cherami and execute their task handlers accordingly
+	TaskExecutor interface {
+		// Register registers task handler with its *unique* task type
+		Register(taskType string, taskFunc TaskFunc)
+		// Start starts dequeuing tasks and executing them
+		Start() error
+		// Stop stops dequeuing/execution of tasks
+		// There's no guarantee to drain scheduled tasks when Stop is invoked
+		Stop()
+	}
+
+	// CreateTaskSchedulerRequest is used to call Client.CreateTaskScheduler to create a task scheduler
+	CreateTaskSchedulerRequest struct {
+		// Path to destination which tasks enqueue into
+		Path string
+		// MaxInflightMessagesPerConnection is number of messages pending confirmation per connection
+		MaxInflightMessagesPerConnection int
+	}
+
+	// CreateTaskExecutorRequest is used to call Client.CreateTaskExecutor to create a task executor
+	CreateTaskExecutorRequest struct {
+		// Concurrency is the number of concurrent workers to execute tasks
+		Concurrency int
+		// Path to destination which tasks dequeued from
+		Path string
+		// ConsumerGroupName registered with Cherami for a particular destination
+		ConsumerGroupName string
+		// ConsumerName is name of consumer (worker) connecting to Cherami
+		ConsumerName string
+		// PrefetchCount is number of messages to buffer locally
+		PrefetchCount int
+		// Timeout is timeout setting used when ack/nack back to Cherami
+		Timeout time.Duration
+	}
+
+	// ScheduleTaskRequest is used to call TaskScheduler.ScheduleTask to schedule a new task
+	ScheduleTaskRequest struct {
+		// TaskType is the unique type name which is used to register task handler with task executor
+		TaskType string
+		// TaskID is the unique identifier of this specific task
+		TaskID string
+		// TaskValue can be anything that represents the task
+		TaskValue interface{}
+		// Context is the key value pairs context associated with the task
+		Context map[string]string
+		// Delay is the time duration before task can be executed
+		Delay time.Duration
+	}
+)
+
+const (
+	// PublisherTypeStreaming indicates a publisher that uses websocket streaming
+	PublisherTypeStreaming PublisherType = iota
+	// PublisherTypeNonStreaming indicates a publisher that uses tchannel batch api
+	PublisherTypeNonStreaming
+)
diff --git a/client/cherami/mockClient_test.go b/client/cherami/mockClient_test.go
new file mode 100644
index 0000000..e3b1fd4
--- /dev/null
+++ b/client/cherami/mockClient_test.go
@@ -0,0 +1,370 @@
+// Copyright (c) 2016 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package cherami
+
+import "github.com/stretchr/testify/mock"
+import "github.com/uber/cherami-thrift/.generated/go/cherami"
+
+// mockClient is a testify-based test double. Each method records the call
+// via mock.Mock and returns whatever the test configured, either as plain
+// values or as a function of the call's arguments.
+type mockClient struct {
+	mock.Mock
+}
+
+// Close provides a mock function with given fields:
+func (_m *mockClient) Close() {
+	_m.Called()
+}
+
+// CreateConsumerGroup provides a mock function with given fields: request
+func (_m *mockClient) CreateConsumerGroup(request *cherami.CreateConsumerGroupRequest) (*cherami.ConsumerGroupDescription, error) {
+	ret := _m.Called(request)
+
+	var r0 *cherami.ConsumerGroupDescription
+	if rf, ok := ret.Get(0).(func(*cherami.CreateConsumerGroupRequest) *cherami.ConsumerGroupDescription); ok {
+		r0 = rf(request)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.ConsumerGroupDescription)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(*cherami.CreateConsumerGroupRequest) error); ok {
+		r1 = rf(request)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// CreateDestination provides a mock function with given fields: request
+func (_m *mockClient) CreateDestination(request *cherami.CreateDestinationRequest) (*cherami.DestinationDescription, error) {
+	ret := _m.Called(request)
+
+	var r0 *cherami.DestinationDescription
+	if rf, ok := ret.Get(0).(func(*cherami.CreateDestinationRequest) *cherami.DestinationDescription); ok {
+		r0 = rf(request)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.DestinationDescription)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(*cherami.CreateDestinationRequest) error); ok {
+		r1 = rf(request)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// CreateConsumer provides a mock function with given fields: request
+func (_m *mockClient) CreateConsumer(request *CreateConsumerRequest) Consumer {
+	ret := _m.Called(request)
+
+	var r0 Consumer
+	if rf, ok := ret.Get(0).(func(*CreateConsumerRequest) Consumer); ok {
+		r0 = rf(request)
+	} else {
+		r0 = ret.Get(0).(Consumer)
+	}
+
+	return r0
+}
+
+// CreatePublisher provides a mock function with given fields: request
+func (_m *mockClient) CreatePublisher(request *CreatePublisherRequest) Publisher {
+	ret := _m.Called(request)
+
+	var r0 Publisher
+	if rf, ok := ret.Get(0).(func(*CreatePublisherRequest) Publisher); ok {
+		r0 = rf(request)
+	} else {
+		r0 = ret.Get(0).(Publisher)
+	}
+
+	return r0
+}
+
+// DeleteConsumerGroup provides a mock function with given fields: request
+func (_m *mockClient) DeleteConsumerGroup(request *cherami.DeleteConsumerGroupRequest) error {
+	ret := _m.Called(request)
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(*cherami.DeleteConsumerGroupRequest) error); ok {
+		r0 = rf(request)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// DeleteDestination provides a mock function with given fields: request
+func (_m *mockClient) DeleteDestination(request *cherami.DeleteDestinationRequest) error {
+	ret := _m.Called(request)
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(*cherami.DeleteDestinationRequest) error); ok {
+		r0 = rf(request)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// ListConsumerGroups provides a mock function with given fields: request
+func (_m *mockClient) ListConsumerGroups(request *cherami.ListConsumerGroupRequest) (*cherami.ListConsumerGroupResult_, error) {
+	ret := _m.Called(request)
+
+	var r0 *cherami.ListConsumerGroupResult_
+	if rf, ok := ret.Get(0).(func(*cherami.ListConsumerGroupRequest) *cherami.ListConsumerGroupResult_); ok {
+		r0 = rf(request)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.ListConsumerGroupResult_)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(*cherami.ListConsumerGroupRequest) error); ok {
+		r1 = rf(request)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// ListDestinations provides a mock function with given fields: request
+func (_m *mockClient) ListDestinations(request *cherami.ListDestinationsRequest) (*cherami.ListDestinationsResult_, error) {
+	ret := _m.Called(request)
+
+	var r0 *cherami.ListDestinationsResult_
+	if rf, ok := ret.Get(0).(func(*cherami.ListDestinationsRequest) *cherami.ListDestinationsResult_); ok {
+		r0 = rf(request)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.ListDestinationsResult_)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(*cherami.ListDestinationsRequest) error); ok {
+		r1 = rf(request)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// ReadConsumerGroup provides a mock function with given fields: request
+func (_m *mockClient) ReadConsumerGroup(request *cherami.ReadConsumerGroupRequest) (*cherami.ConsumerGroupDescription, error) {
+	ret := _m.Called(request)
+
+	var r0 *cherami.ConsumerGroupDescription
+	if rf, ok := ret.Get(0).(func(*cherami.ReadConsumerGroupRequest) *cherami.ConsumerGroupDescription); ok {
+		r0 = rf(request)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.ConsumerGroupDescription)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(*cherami.ReadConsumerGroupRequest) error); ok {
+		r1 = rf(request)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// ReadDestination provides a mock function with given fields: request
+func (_m *mockClient) ReadDestination(request *cherami.ReadDestinationRequest) (*cherami.DestinationDescription, error) {
+	ret := _m.Called(request)
+
+	var r0 *cherami.DestinationDescription
+	if rf, ok := ret.Get(0).(func(*cherami.ReadDestinationRequest) *cherami.DestinationDescription); ok {
+		r0 = rf(request)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.DestinationDescription)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(*cherami.ReadDestinationRequest) error); ok {
+		r1 = rf(request)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// UpdateConsumerGroup provides a mock function with given fields: request
+func (_m *mockClient) UpdateConsumerGroup(request *cherami.UpdateConsumerGroupRequest) (*cherami.ConsumerGroupDescription, error) {
+	ret := _m.Called(request)
+
+	var r0 *cherami.ConsumerGroupDescription
+	if rf, ok := ret.Get(0).(func(*cherami.UpdateConsumerGroupRequest) *cherami.ConsumerGroupDescription); ok {
+		r0 = rf(request)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.ConsumerGroupDescription)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(*cherami.UpdateConsumerGroupRequest) error); ok {
+		r1 = rf(request)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// UpdateDestination provides a mock function with given fields: request
+func (_m *mockClient) UpdateDestination(request *cherami.UpdateDestinationRequest) (*cherami.DestinationDescription, error) {
+	ret := _m.Called(request)
+
+	var r0 *cherami.DestinationDescription
+	if rf, ok := ret.Get(0).(func(*cherami.UpdateDestinationRequest) *cherami.DestinationDescription); ok {
+		r0 = rf(request)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.DestinationDescription)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(*cherami.UpdateDestinationRequest) error); ok {
+		r1 = rf(request)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// GetQueueDepthInfo provides a mock function with given fields: request
+func (_m *mockClient) GetQueueDepthInfo(request *cherami.GetQueueDepthInfoRequest) (*cherami.GetQueueDepthInfoResult_, error) {
+	ret := _m.Called(request)
+
+	var r0 *cherami.GetQueueDepthInfoResult_
+	if rf, ok := ret.Get(0).(func(*cherami.GetQueueDepthInfoRequest) *cherami.GetQueueDepthInfoResult_); ok {
+		r0 = rf(request)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.GetQueueDepthInfoResult_)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(*cherami.GetQueueDepthInfoRequest) error); ok {
+		r1 = rf(request)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// MergeDLQForConsumerGroup provides a mock function with given fields: request
+func (_m *mockClient) MergeDLQForConsumerGroup(request *cherami.MergeDLQForConsumerGroupRequest) error {
+	ret := _m.Called(request)
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(*cherami.MergeDLQForConsumerGroupRequest) error); ok {
+		r0 = rf(request)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// PurgeDLQForConsumerGroup provides a mock function with given fields: request
+func (_m *mockClient) PurgeDLQForConsumerGroup(request *cherami.PurgeDLQForConsumerGroupRequest) error {
+	ret := _m.Called(request)
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(*cherami.PurgeDLQForConsumerGroupRequest) error); ok {
+		r0 = rf(request)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// ReadPublisherOptions provides a mock function with given fields: path
+func (_m *mockClient) ReadPublisherOptions(path string) (*cherami.ReadPublisherOptionsResult_, error) {
+	ret := _m.Called(path)
+
+	var r0 *cherami.ReadPublisherOptionsResult_
+	if rf, ok := ret.Get(0).(func(string) *cherami.ReadPublisherOptionsResult_); ok {
+		r0 = rf(path)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.ReadPublisherOptionsResult_)
+		}
+	}
+
+	var r1 error
+	if rf, ok := ret.Get(1).(func(string) error); ok {
+		r1 = rf(path)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// ReadConsumerGroupHosts provides a mock function with given fields: path, consumerGroupName
+func (_m *mockClient) ReadConsumerGroupHosts(path string, consumerGroupName string) (*cherami.ReadConsumerGroupHostsResult_, error) {
+	ret := _m.Called(path, consumerGroupName)
+
+	var r0 *cherami.ReadConsumerGroupHostsResult_
+	if rf, ok := ret.Get(0).(func(string, string) *cherami.ReadConsumerGroupHostsResult_); ok {
+		r0 = rf(path, consumerGroupName)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*cherami.ReadConsumerGroupHostsResult_)
+ } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(path, consumerGroupName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/client/cherami/outputhostconnection.go b/client/cherami/outputhostconnection.go new file mode 100644 index 0000000..aadca4b --- /dev/null +++ b/client/cherami/outputhostconnection.go @@ -0,0 +1,360 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package cherami + +import ( + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/common" + "github.com/uber/cherami-client-go/common/metrics" + "github.com/uber/cherami-client-go/stream" + + "github.com/uber-common/bark" + "github.com/uber/tchannel-go/thrift" + "golang.org/x/net/context" +) + +type ( + outputHostConnection struct { + outputHostClient cherami.TChanBOut + ackClient cherami.TChanBOut + wsConnector WSConnector + path string + consumerGroupName string + options *ClientOptions + deliveryCh chan<- Delivery + prefetchSize int32 + creditBatchSize int32 + outputHostStream stream.BOutOpenConsumerStreamOutCall + cancel context.CancelFunc + closeChannel chan struct{} + connKey string + protocol cherami.Protocol + reconfigureCh chan<- reconfigureInfo + creditsCh chan int32 + logger bark.Logger + reporter metrics.Reporter + + lk sync.Mutex + opened int32 + closed int32 + + // We need this lock to protect writes to acksBatchCh after it getting closed + acksBatchLk sync.RWMutex + acksBatchCh chan []string + acksBatchClosed chan struct{} + } +) + +const ( + creditsChBuffer = 10 + ackBatchSize = 9 // ackIds are about 133 bytes each and MTU is about 1500, so try to keep it to one packet full + ackBatchDelay = time.Second / 10 +) + +func newOutputHostConnection(client cherami.TChanBOut, ackClient cherami.TChanBOut, wsConnector WSConnector, + path, consumerGroupName string, options *ClientOptions, deliveryCh chan<- Delivery, + reconfigureCh chan<- reconfigureInfo, connKey string, protocol cherami.Protocol, + prefetchSize int32, logger bark.Logger, reporter metrics.Reporter) *outputHostConnection { + + creditBatchSize := prefetchSize / 10 + if creditBatchSize < 1 { + creditBatchSize = 1 + } + + // Don't check for nil options; better to panic here than panic later + common.ValidateTimeout(options.Timeout) + + return &outputHostConnection{ + connKey: connKey, + protocol: protocol, 
+ outputHostClient: client, + ackClient: ackClient, + wsConnector: wsConnector, + path: path, + consumerGroupName: consumerGroupName, + options: options, + deliveryCh: deliveryCh, + prefetchSize: prefetchSize, + creditBatchSize: creditBatchSize, + closeChannel: make(chan struct{}), + reconfigureCh: reconfigureCh, + creditsCh: make(chan int32, creditsChBuffer), + acksBatchCh: make(chan []string, ackBatchSize*2), + acksBatchClosed: make(chan struct{}), + logger: logger, + reporter: reporter, + } +} + +func (conn *outputHostConnection) open() error { + conn.lk.Lock() + defer conn.lk.Unlock() + + if atomic.LoadInt32(&conn.opened) == 0 { + switch conn.protocol { + case cherami.Protocol_WS: + conn.logger.Infof("Using websocket to connect to output host %s", conn.connKey) + wsStream, err := conn.wsConnector.OpenConsumerStream(conn.connKey, http.Header{ + "path": {conn.path}, + "consumerGroupName": {conn.consumerGroupName}, + }) + if err != nil { + conn.logger.Infof("Error opening websocket connection to output host %s: %v", conn.connKey, err) + return err + } + + conn.outputHostStream = wsStream + conn.cancel = nil + + default: + return &cherami.BadRequestError{Message: `Protocol not supported`} + } + + // Now start the message pump + go conn.readMessagesPump() + go conn.writeCreditsPump() + // We only bail out of this pump when acksBatchCh is closed by consumerImpl after this connections is removed. + // This is the only guarantee we will not receive more acks on the channel and it is safe to shutdown the pump. + // Closing the pump earlier has the potential to cause deadlock between consumer writing acks and connection writing + // messages to deliveryCh. 
+ go conn.writeAcksPump() + + atomic.StoreInt32(&conn.opened, 1) + conn.logger.Info("Output host connection opened.") + } + + return nil +} + +func (conn *outputHostConnection) close() { + conn.lk.Lock() + defer conn.lk.Unlock() + + if atomic.LoadInt32(&conn.closed) == 0 { + select { + case conn.reconfigureCh <- reconfigureInfo{eventType: connClosedReconfigureType, reconfigureID: conn.connKey}: + default: + conn.logger.Info("Reconfigure channel is full. Drop reconfigure command due to connection close.") + } + + close(conn.closeChannel) + + atomic.StoreInt32(&conn.closed, 1) + conn.logger.Info("Output host connection closed.") + } +} + +func (conn *outputHostConnection) isOpened() bool { + return atomic.LoadInt32(&conn.opened) != 0 +} + +func (conn *outputHostConnection) isClosed() bool { + return atomic.LoadInt32(&conn.closed) != 0 +} + +func (conn *outputHostConnection) readMessagesPump() { + var localCredits int32 + for { + conn.reporter.UpdateGauge(metrics.ConsumeLocalCredits, nil, int64(localCredits)) + if localCredits >= conn.creditBatchSize { + // Issue more credits + select { + case conn.creditsCh <- localCredits: + localCredits = 0 + default: + conn.logger.Debugf("Credits channel is full. Unable to write to creditsCh.") + } + } + + cmd, err := conn.outputHostStream.Read() + if err != nil { + // Error reading from stream. Time to close and bail out. + conn.logger.Infof("Error reading OutputHost Message Stream: %v", err) + + // Stream is closed. 
Close the connection and bail out + conn.close() + return + } + + if cmd.GetType() == cherami.OutputHostCommandType_MESSAGE { + conn.reporter.IncCounter(metrics.ConsumeMessageRate, nil, 1) + msg := cmd.Message + delivery := newDelivery(msg, conn) + conn.deliveryCh <- delivery + localCredits++ + } else if cmd.GetType() == cherami.OutputHostCommandType_RECONFIGURE { + conn.reporter.IncCounter(metrics.ConsumeReconfigureRate, nil, 1) + reconfigInfo := cmd.Reconfigure + conn.logger.WithField(common.TagReconfigureID, common.FmtReconfigureID(reconfigInfo.GetUpdateUUID())).Info("Reconfigure command received from OutputHost.") + select { + case conn.reconfigureCh <- reconfigureInfo{eventType: reconfigureCmdReconfigureType, reconfigureID: reconfigInfo.GetUpdateUUID()}: + default: + conn.logger.WithField(common.TagReconfigureID, common.FmtReconfigureID(reconfigInfo.GetUpdateUUID())).Info("Reconfigure channel is full. Drop reconfigure command.") + } + } + } +} + +func (conn *outputHostConnection) writeCreditsPump() { + // This will unblock any pending read operations on the stream. 
+ defer func() { + if conn.cancel != nil { + conn.cancel() + } + conn.outputHostStream.Done() + }() + + // Send initial credits to OutputHost + if err := conn.sendCredits(int32(conn.prefetchSize)); err != nil { + conn.logger.Infof("Error sending initialCredits to OutputHost: %v", err) + + conn.close() + return + } + + // Start the write pump + for { + select { + case credits := <-conn.creditsCh: + // TODO: this needs to be converted into a metric + //conn.logger.Infof("Sending credits to output host: %v", credits) + if err := conn.sendCredits(credits); err != nil { + conn.logger.Infof("Error sending creditBatchSize to OutputHost: %v", err) + + conn.close() + } + case <-conn.closeChannel: + conn.logger.Info("WriteCreditsPump closing due to connection closed.") + return + } + } +} + +func (conn *outputHostConnection) writeAcksPump() { + var buffer []string + var bufferTicker <-chan time.Time + var err error + + bufferTicker = time.Tick(time.Second / 10) + + push := func(buf *[]string) { + ackRequest := cherami.NewAckMessagesRequest() + ackRequest.AckIds = *buf + *buf = nil + + ctx, cancel := thrift.NewContext(conn.options.Timeout) + conn.reporter.IncCounter(metrics.ConsumeAckRate, nil, int64(len(ackRequest.AckIds))) + err = conn.ackClient.AckMessages(ctx, ackRequest) + if err != nil { + conn.logger.Infof("error in ack batch: %v", err) + conn.reporter.IncCounter(metrics.ConsumeAckFailedRate, nil, int64(len(ackRequest.AckIds))) + } + cancel() + } + +ackPump: + for { + select { + + case ackIds, ok := <-conn.acksBatchCh: + if !ok { + conn.logger.Info("writeAcksPump closing.") + break ackPump + } + buffer = append(buffer, ackIds...) 
+ + if len(buffer) >= ackBatchSize { + push(&buffer) + } + + case <-bufferTicker: + // Check for excessive idleness and disable the ticker if this happens + if len(buffer) > 0 { + push(&buffer) + } + } + } + + // Make sure we push all acks before bailing out of the pump + if len(buffer) > 0 { + push(&buffer) + } +} + +func (conn *outputHostConnection) sendCredits(credits int32) error { + flows := cherami.NewControlFlow() + flows.Credits = common.Int32Ptr(credits) + + conn.reporter.IncCounter(metrics.ConsumeCreditRate, nil, 1) + sw := conn.reporter.StartTimer(metrics.ConsumeCreditLatency, nil) + defer sw.Stop() + + err := conn.outputHostStream.Write(flows) + if err == nil { + err = conn.outputHostStream.Flush() + } else { + conn.reporter.IncCounter(metrics.ConsumeCreditRate, nil, 1) + } + + return err +} + +func (conn *outputHostConnection) GetAcknowledgerID() string { + return conn.connKey +} + +func (conn *outputHostConnection) Ack(ids []string) error { + conn.acksBatchLk.RLock() + defer conn.acksBatchLk.RUnlock() + + select { + case <-conn.acksBatchClosed: + return nil + default: + conn.acksBatchCh <- ids + } + + return nil +} + +func (conn *outputHostConnection) Nack(ids []string) error { + ackRequest := cherami.NewAckMessagesRequest() + ackRequest.NackIds = ids + + ctx, cancel := thrift.NewContext(conn.options.Timeout) + defer cancel() + + conn.reporter.IncCounter(metrics.ConsumeNackRate, nil, int64(len(ids))) + return conn.ackClient.AckMessages(ctx, ackRequest) +} + +func (conn *outputHostConnection) closeAcksBatchCh() { + conn.acksBatchLk.Lock() + defer conn.acksBatchLk.Unlock() + + close(conn.acksBatchCh) + close(conn.acksBatchClosed) +} diff --git a/client/cherami/outputhostconnection_test.go b/client/cherami/outputhostconnection_test.go new file mode 100644 index 0000000..96a8e2b --- /dev/null +++ b/client/cherami/outputhostconnection_test.go @@ -0,0 +1,243 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package cherami + +import ( + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "errors" + "io" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/common" + "github.com/uber/cherami-client-go/common/metrics" + mc "github.com/uber/cherami-client-go/mocks/clients/cherami" + + log "github.com/Sirupsen/logrus" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/uber-common/bark" +) + +type OutputHostConnectionSuite struct { + *require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error + suite.Suite +} + +func TestOutputHostConnectionSuite(t *testing.T) { + suite.Run(t, new(OutputHostConnectionSuite)) +} + +func (s *OutputHostConnectionSuite) SetupTest() { + s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil +} + +func (s *OutputHostConnectionSuite) TestOutputHostBasic() { + conn, _, _, stream, messagesCh := createOutputHostConnection() + + stream.On("Write", mock.Anything).Return(nil) + stream.On("Flush").Return(nil) + stream.On("Read").Return(wrapMessageInCommand(&cherami.ConsumerMessage{ + AckId: common.StringPtr("test"), + }), nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + delivery := <-messagesCh + s.NotNil(delivery, "Delivery cannot be nil.") + + msg := delivery.GetMessage() + s.NotNil(msg, "Message cannot be nil.") + s.Equal("test", msg.GetAckId()) +} + +func (s *OutputHostConnectionSuite) TestReadFailed() { + conn, _, _, stream, _ := createOutputHostConnection() + + stream.On("Write", mock.Anything).Return(nil) + stream.On("Flush").Return(nil) + stream.On("Read").Return(nil, errors.New("some error")) + stream.On("Done").Return(nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + time.Sleep(10 * time.Millisecond) + + 
s.True(conn.isClosed(), "Connection not opened.") +} + +func (s *OutputHostConnectionSuite) TestReadEOF() { + conn, _, _, stream, _ := createOutputHostConnection() + + stream.On("Write", mock.Anything).Return(nil) + stream.On("Flush").Return(nil) + stream.On("Read").Return(nil, io.EOF) + stream.On("Done").Return(nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + time.Sleep(10 * time.Millisecond) + + s.True(conn.isClosed(), "Connection not opened.") +} + +func (s *OutputHostConnectionSuite) TestCreditsRenewSuccess() { + conn, _, _, stream, messagesCh := createOutputHostConnection() + + initialFlows := cherami.NewControlFlow() + initialFlows.Credits = common.Int32Ptr(conn.prefetchSize) + + renewFlows := cherami.NewControlFlow() + renewFlows.Credits = common.Int32Ptr(conn.creditBatchSize) + + stream.On("Write", initialFlows).Return(nil).Once() + stream.On("Write", renewFlows).Return(nil).Once() + stream.On("Flush").Return(nil) + stream.On("Read").Return(wrapMessageInCommand(&cherami.ConsumerMessage{ + AckId: common.StringPtr("test"), + }), nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + for i := 0; i < int(conn.creditBatchSize); i++ { + delivery := <-messagesCh + s.NotNil(delivery, "Delivery cannot be nil.") + + msg := delivery.GetMessage() + s.NotNil(msg, "Message cannot be nil.") + s.Equal("test", msg.GetAckId()) + } + + time.Sleep(10 * time.Millisecond) + + stream.AssertExpectations(s.T()) +} + +func (s *OutputHostConnectionSuite) TestInitialCreditsWriteFailed() { + conn, _, _, stream, _ := createOutputHostConnection() + + initialFlows := cherami.NewControlFlow() + initialFlows.Credits = common.Int32Ptr(conn.prefetchSize) + + stream.On("Write", initialFlows).Return(errors.New("some error")).After(10 * time.Millisecond).Once() + stream.On("Read").Return(wrapMessageInCommand(&cherami.ConsumerMessage{ + AckId: common.StringPtr("test"), + }), nil) + stream.On("Done").Return(nil) + + conn.open() + 
s.True(conn.isOpened(), "Connection not opened.") + + time.Sleep(20 * time.Millisecond) + s.True(conn.isClosed(), "Connection not closed.") + + stream.AssertExpectations(s.T()) +} + +func (s *OutputHostConnectionSuite) TestInitialCreditsFlushFailed() { + conn, _, _, stream, _ := createOutputHostConnection() + + initialFlows := cherami.NewControlFlow() + initialFlows.Credits = common.Int32Ptr(conn.prefetchSize) + + stream.On("Write", initialFlows).Return(nil).After(10 * time.Millisecond).Once() + stream.On("Read").Return(wrapMessageInCommand(&cherami.ConsumerMessage{ + AckId: common.StringPtr("test"), + }), nil) + stream.On("Flush").Return(errors.New("some error")) + stream.On("Done").Return(nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + time.Sleep(20 * time.Millisecond) + s.True(conn.isClosed(), "Connection not closed.") + + stream.AssertExpectations(s.T()) +} + +func (s *OutputHostConnectionSuite) TestRenewCreditsFailed() { + conn, _, _, stream, messagesCh := createOutputHostConnection() + + initialFlows := cherami.NewControlFlow() + initialFlows.Credits = common.Int32Ptr(conn.prefetchSize) + + renewFlows := cherami.NewControlFlow() + renewFlows.Credits = common.Int32Ptr(conn.creditBatchSize) + + stream.On("Write", initialFlows).Return(nil).Once() + stream.On("Write", renewFlows).Return(errors.New("some error")).Once() + stream.On("Flush").Return(nil) + stream.On("Read").Return(wrapMessageInCommand(&cherami.ConsumerMessage{ + AckId: common.StringPtr("test"), + }), nil) + stream.On("Done").Return(nil) + + conn.open() + s.True(conn.isOpened(), "Connection not opened.") + + for i := 0; i < int(conn.creditBatchSize); i++ { + delivery := <-messagesCh + s.NotNil(delivery, "Delivery cannot be nil.") + + msg := delivery.GetMessage() + s.NotNil(msg, "Message cannot be nil.") + s.Equal("test", msg.GetAckId()) + } + + time.Sleep(10 * time.Millisecond) + s.True(conn.isClosed(), "Connection not closed.") + + stream.AssertExpectations(s.T()) +} + 
+func createOutputHostConnection() (*outputHostConnection, *mc.MockTChanBOutClient, *mc.MockTChanBOutClient, *mc.MockBOutOpenConsumerStreamOutCall, chan Delivery) { + host := "testHost" + outputHostClient := new(mc.MockTChanBOutClient) + ackClient := new(mc.MockTChanBOutClient) + wsConnector := new(mc.MockWSConnector) + stream := new(mc.MockBOutOpenConsumerStreamOutCall) + deliveryCh := make(chan Delivery) + reconfigureCh := make(chan reconfigureInfo, 10) + options := &ClientOptions{Timeout: time.Minute} + + outputHostClient.On("OpenConsumerStream", mock.Anything).Return(stream, nil) + wsConnector.On("OpenConsumerStream", mock.Anything, mock.Anything).Return(stream, nil) + conn := newOutputHostConnection(outputHostClient, ackClient, wsConnector, "/test/outputhostconnection", "/consumer", options, + deliveryCh, reconfigureCh, host, cherami.Protocol_WS, int32(100), bark.NewLoggerFromLogrus(log.StandardLogger()), metrics.NewNullReporter()) + + return conn, outputHostClient, ackClient, stream, deliveryCh +} + +func wrapMessageInCommand(msg *cherami.ConsumerMessage) *cherami.OutputHostCommand { + cmd := cherami.NewOutputHostCommand() + cmd.Type = common.CheramiOutputHostCommandTypePtr(cherami.OutputHostCommandType_MESSAGE) + cmd.Message = msg + + return cmd +} diff --git a/client/cherami/publisher.go b/client/cherami/publisher.go new file mode 100644 index 0000000..81f92fd --- /dev/null +++ b/client/cherami/publisher.go @@ -0,0 +1,347 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package cherami + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/common" + "github.com/uber/cherami-client-go/common/backoff" + "github.com/uber/cherami-client-go/common/metrics" + "github.com/uber/tchannel-go" + "github.com/uber/tchannel-go/thrift" + + "github.com/pborman/uuid" +) + +type ( + publisherImpl struct { + basePublisher + connection *tchannel.Channel + messagesCh chan putMessageRequest + reconfigureCh chan reconfigureInfo + closingCh chan struct{} + closeChannel chan struct{} + isClosing int32 + maxInflightMessagesPerConnection int + lk sync.Mutex + opened bool + closed bool + connections map[string]*connection + wsConnector WSConnector + reconfigurable *reconfigurable + } + + putMessageRequest struct { + message *cherami.PutMessage + messageAck chan<- *PublisherReceipt + } +) + +const ( + maxDuration time.Duration = 1<<62 - 1 + defaultMessageTimeout = time.Minute +) + +var _ Publisher = (*publisherImpl)(nil) + +// NewPublisher constructs a new Publisher object +func NewPublisher(client *clientImpl, path string, maxInflightMessagesPerConnection int) Publisher { + base := basePublisher{ + client: client, + retryPolicy: createDefaultPublisherRetryPolicy(), + path: path, + logger: client.options.Logger.WithField(common.TagDstPth, common.FmtDstPth(path)), + reporter: client.options.MetricsReporter, + } + publisher := &publisherImpl{ + basePublisher: base, + maxInflightMessagesPerConnection: maxInflightMessagesPerConnection, + messagesCh: make(chan putMessageRequest), + reconfigureCh: make(chan reconfigureInfo, reconfigureChBufferSize), + closingCh: make(chan struct{}), + closeChannel: make(chan struct{}), + isClosing: 0, + wsConnector: NewWSConnector(), + } + + return publisher +} + +func (s *publisherImpl) Open() error { + s.lk.Lock() + defer s.lk.Unlock() + + if !s.opened { + publisherOptions, err := s.readPublisherOptions() + if err != nil { + 
s.logger.Errorf("Error resolving input hosts: %v", err) + return err + } + s.checksumOption = publisherOptions.GetChecksumOption() + + ch, err := tchannel.NewChannel(uuid.New(), nil) + if err != nil { + return err + } + s.connection = ch + + chosenProtocol, chosenHostAddresses := s.choosePublishEndpoints(publisherOptions) + s.connections = make(map[string]*connection) + for _, host := range chosenHostAddresses { + connKey := common.GetConnectionKey(host) + // ReadDestinationHosts can return duplicates, so we need to dedupe to make sure we create a single connection for each host + if _, ok := s.connections[connKey]; !ok { + conn, err := s.createInputHostConnection(connKey, chosenProtocol) + if err != nil { + if conn != nil { + conn.close() + } + + // TODO: We should be returning failure only when no connections could be opened to InputHost + return err + } + + s.connections[connKey] = conn + } + } + s.reporter.UpdateGauge(metrics.PublishNumConnections, nil, int64(len(s.connections))) + + s.reconfigurable = newReconfigurable(s.reconfigureCh, s.closingCh, s.reconfigurePublisher, s.logger) + go s.reconfigurable.reconfigurePump() + + s.opened = true + s.logger.Info("Publisher Opened.") + } + + return nil +} + +func (s *publisherImpl) Close() { + if atomic.CompareAndSwapInt32(&s.isClosing, 0, 1) { + close(s.closingCh) + } else { + return + } + + s.lk.Lock() + defer s.lk.Unlock() + if s.connections != nil { + for _, inputHostConn := range s.connections { + inputHostConn.close() + } + s.reporter.UpdateGauge(metrics.PublishNumConnections, nil, 0) + } + + if s.connection != nil { + s.connection.Close() + } + + // closing channel should make all outstanding publish to fail + close(s.closeChannel) + s.closed = true + s.opened = false + s.logger.Info("Publisher Closed.") +} + +// Publish can be used to synchronously publish a message to Cherami +func (s *publisherImpl) Publish(message *PublisherMessage) *PublisherReceipt { + timeoutTimer := time.NewTimer(defaultMessageTimeout) 
+ defer timeoutTimer.Stop() + + var receipt *PublisherReceipt + publishOp := func() error { + srCh := make(chan *PublisherReceipt, 1) + _, err := s.PublishAsync(message, srCh) + if err != nil { + return err + } + + select { + case receipt = <-srCh: + return receipt.Error + case <-timeoutTimer.C: + return ErrMessageTimedout + } + } + + err := backoff.Retry(publishOp, s.retryPolicy, nil) + if err != nil { + return &PublisherReceipt{Error: err} + } + + return receipt +} + +// PublishAsync accepts a message, but returns immediately with the local +// reference ID +func (s *publisherImpl) PublishAsync(message *PublisherMessage, done chan<- *PublisherReceipt) (string, error) { + + if !s.opened { + return "", fmt.Errorf("Cannot publish message to path '%s'. Publisher is not opened.", s.path) + } + + request := putMessageRequest{ + message: s.toPutMessage(message), + messageAck: done, + } + id := request.message.GetID() + s.messagesCh <- request + return id, nil +} + +func (s *publisherImpl) reconfigurePublisher() { + s.lk.Lock() + defer s.lk.Unlock() + + select { + case <-s.closingCh: + s.logger.Info("Publisher is closing. Ignore reconfiguration.") + default: + var conn *connection + + publisherOptions, err := s.readPublisherOptions() + if err != nil { + s.logger.Infof("Error resolving input hosts: %v", err) + if _, ok := err.(*cherami.EntityNotExistsError); ok { + // Destination is deleted. Continue with reconfigure and close all connections + publisherOptions = &cherami.ReadPublisherOptionsResult_{} + } else { + // This is a potentially a transient error. 
+ // Retry on next reconfigure + return + } + } + + chosenProtocol, chosenHostAddresses := s.choosePublishEndpoints(publisherOptions) + + // First remove any closed connections from the connections map + for existingConnKey, existingConn := range s.connections { + if existingConn.isClosed() { + delete(s.connections, existingConnKey) + } + } + + currentHosts := make(map[string]*connection) + for _, host := range chosenHostAddresses { + connKey := common.GetConnectionKey(host) + conn = s.connections[connKey] + if conn == nil || conn.isClosed() { + // Newly assigned host, create a connection + connLogger := s.logger.WithField(common.TagHostIP, common.FmtHostIP(connKey)) + connLogger.Info("Discovered new InputHost during reconfiguration.") + conn, err = s.createInputHostConnection(connKey, chosenProtocol) + if err != nil { + connLogger.Info("Error creating connection to InputHost after reconfiguration.") + if conn != nil { + conn.close() + } + } else { + connLogger.Info("Successfully created connection to InputHost after reconfiguration.") + // Successfully created a connection to new host. 
Add it to current list of input hosts + currentHosts[connKey] = conn + } + } else { + // Existing input host connection, copy it over to current collection of input hosts + currentHosts[connKey] = conn + } + } + + // Now close all remaining list of input host connections + for host, inputHostConn := range s.connections { + if _, ok := currentHosts[host]; !ok { + connLogger := s.logger.WithField(common.TagHostIP, common.FmtHostIP(inputHostConn.connKey)) + connLogger.Info("Closing connection to InputHost after reconfiguration.") + inputHostConn.close() + } + } + + connectedHosts := make([]string, len(currentHosts)) + for k := range currentHosts { + connectedHosts = append(connectedHosts, k) + } + s.logger.WithField(common.TagHosts, connectedHosts).Debug("List of connected input hosts.") + + // Now assign the list of current input hosts to list of connections + s.connections = currentHosts + s.reporter.UpdateGauge(metrics.PublishNumConnections, nil, int64(len(s.connections))) + } +} + +func (s *publisherImpl) createInputHostConnection(connKey string, protocol cherami.Protocol) (*connection, error) { + connLogger := s.logger.WithField(common.TagHostIP, common.FmtHostIP(connKey)) + + // TODO [ljj] to be removed once moved to websocket + c, err := s.createInputHostClient(connKey) + if err != nil { + connLogger.Infof("Error creating InputHost client: %v", err) + return nil, err + } + + conn := newConnection(c, s.wsConnector, s.path, s.messagesCh, s.reconfigureCh, connKey, protocol, s.maxInflightMessagesPerConnection, connLogger, s.reporter) + err = conn.open() + if err != nil { + connLogger.Infof("Error opening InputHost connection: %v", err) + return conn, err + } + + return conn, nil +} + +func (s *publisherImpl) createInputHostClient(hostPort string) (cherami.TChanBIn, error) { + tClient := thrift.NewClient(s.connection, common.InputServiceName, &thrift.ClientOptions{ + HostPort: hostPort, + }) + client := cherami.NewTChanBInClient(tClient) + + return client, nil +} + 
+func (s *publisherImpl) chooseProcotol(hostProtocols []*cherami.HostProtocol) (int, error) { + clientSupportedProtocol := map[cherami.Protocol]bool{cherami.Protocol_WS: true} + clientSupportButDeprecated := -1 + serverSupportedProtocol := make([]cherami.Protocol, 0, len(hostProtocols)) + + for idx, hostProtocol := range hostProtocols { + serverSupportedProtocol = append(serverSupportedProtocol, hostProtocol.GetProtocol()) + if _, found := clientSupportedProtocol[hostProtocol.GetProtocol()]; found { + if !hostProtocol.GetDeprecated() { + // found first supported and non-deprecated one, done + return idx, nil + } else if clientSupportButDeprecated == -1 { + // found first supported but deprecated one, keep looking + clientSupportButDeprecated = idx + } + } + } + + if clientSupportButDeprecated == -1 { + s.logger.WithField(`protocols`, serverSupportedProtocol).Error("No protocol is supported by client") + return clientSupportButDeprecated, &cherami.BadRequestError{Message: `No protocol is supported by client`} + } + + s.logger.WithField(`protocol`, hostProtocols[clientSupportButDeprecated].GetProtocol()).Warn("Client using deprecated protocol") + return clientSupportButDeprecated, nil +} diff --git a/client/cherami/reconfigurable.go b/client/cherami/reconfigurable.go new file mode 100644 index 0000000..821d48d --- /dev/null +++ b/client/cherami/reconfigurable.go @@ -0,0 +1,103 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package cherami + +import ( + "time" + + "github.com/uber/cherami-client-go/common" + "github.com/uber-common/bark" +) + +type ( + reconfigureType int + + reconfigureInfo struct { + eventType reconfigureType + reconfigureID string + } + + reconfigurable struct { + reconfigureCh <-chan reconfigureInfo + reconfigureHandler func() + connClosedHandler func(host string) + closingCh chan struct{} + logger bark.Logger + } +) + +const ( + heartbeatDuration = time.Second * 10 + reconfigureChBufferSize = 1 + limiterDuration = time.Millisecond * 500 +) + +const ( + reconfigureCmdReconfigureType = iota + connClosedReconfigureType +) + +func newReconfigurable(reconfigureCh <-chan reconfigureInfo, closingCh chan struct{}, reconfigureHandler func(), + logger bark.Logger) *reconfigurable { + r := &reconfigurable{ + reconfigureCh: reconfigureCh, + closingCh: closingCh, + reconfigureHandler: reconfigureHandler, + logger: logger, + } + + return r +} + +func (s *reconfigurable) reconfigurePump() { + s.logger.Info("Reconfiguration pump started.") + heartbeat := time.NewTicker(heartbeatDuration) + limiter := time.NewTicker(limiterDuration) + lastReconfigureID := "" + for { + select { + case <-s.closingCh: + s.logger.Info("Reconfigure pump closing.") + // Publisher/Consumer is going away, stop heartbeating and bail out of the pump + heartbeat.Stop() + limiter.Stop() + return + default: + select { + case reconfigure := <-s.reconfigureCh: + <-limiter.C + switch reconfigure.eventType { + case reconfigureCmdReconfigureType: + s.logger.WithField(common.TagReconfigureID, common.FmtReconfigureID(reconfigure.reconfigureID)).Infof("Reconfiguration command received from host connection.") + if lastReconfigureID != reconfigure.reconfigureID { + lastReconfigureID = reconfigure.reconfigureID + s.reconfigureHandler() + } + case connClosedReconfigureType: + s.logger.WithField(common.TagHostIP, common.FmtHostIP(reconfigure.reconfigureID)).Infof("Reconfiguration due to host connection closed.") + 
s.reconfigureHandler() + } + case <-heartbeat.C: + s.reconfigureHandler() + } + } + } +} diff --git a/client/cherami/task.go b/client/cherami/task.go new file mode 100644 index 0000000..4c72ff9 --- /dev/null +++ b/client/cherami/task.go @@ -0,0 +1,52 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
// taskImpl is the JSON envelope used to carry a task through Cherami.
// JSONValue holds the task value as a JSON-encoded string.
type taskImpl struct {
	Type      string            `json:"type"`
	ID        string            `json:"id"`
	JSONValue string            `json:"value"`
	Context   map[string]string `json:"context"`
}

// GetType returns the unique type name that can be used to identify the
// corresponding task handler.
func (t *taskImpl) GetType() string {
	return t.Type
}

// GetID returns the unique identifier of this specific task.
func (t *taskImpl) GetID() string {
	return t.ID
}

// GetValue deserializes the task value into instance, which should match the
// type that was used to publish the task. It returns the JSON decoding error,
// if any.
func (t *taskImpl) GetValue(instance interface{}) error {
	raw := []byte(t.JSONValue)
	return json.Unmarshal(raw, instance)
}

// GetContext returns the key/value context associated with the task when it
// was published.
func (t *taskImpl) GetContext() map[string]string {
	return t.Context
}
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cherami + +import ( + "encoding/json" + "sync" + "sync/atomic" + "time" +) + +type ( + taskExecutorImpl struct { + client Client + consumer Consumer + concurrency int + taskFuncs map[string]TaskFunc + taskFuncsLock sync.RWMutex + + path string + cgName string + consumerName string + prefetchSize int + timeout time.Duration + + waitGroup *sync.WaitGroup + killSignal chan struct{} + initialized uint32 + mu sync.Mutex + } +) + +// NewTaskExecutor creates a task executor +func NewTaskExecutor(client Client, request *CreateTaskExecutorRequest) TaskExecutor { + if client == nil || request == nil { + return nil + } + return &taskExecutorImpl{ + client: client, + consumer: nil, + concurrency: request.Concurrency, + taskFuncs: make(map[string]TaskFunc), + + path: request.Path, + cgName: request.ConsumerGroupName, + consumerName: request.ConsumerName, + prefetchSize: request.PrefetchCount, + timeout: request.Timeout, + + waitGroup: &sync.WaitGroup{}, + killSignal: make(chan struct{}), + } +} + +// Register registers task handler with its *unique* task type +func (t *taskExecutorImpl) Register(taskType string, taskFunc TaskFunc) { + t.taskFuncsLock.Lock() + defer t.taskFuncsLock.Unlock() + t.taskFuncs[taskType] = taskFunc +} + +// Start starts dequeuing tasks and execute them +func (t *taskExecutorImpl) Start() error { + taskCh := make(chan Delivery, t.prefetchSize) + if _, err := t.getConsumer().Open(taskCh); err != nil { + return err + } + + // spin up workers to handle tasks + for i := 0; i < t.concurrency; i++ { + + go func(workerID int) { + defer t.waitGroup.Done() + + LOOP: + for { + select { + + // if asked to stop + case <-t.killSignal: + // fmt.Printf("[Worker %d] Killed\n", workerID) + 
break LOOP + + // or get new task + case taskDelivery := <-taskCh: + // deserialize task data + taskData := taskDelivery.GetMessage().GetPayload().GetData() + task := &taskImpl{} + if err := json.Unmarshal(taskData, task); err != nil { + // fmt.Printf("[Worker %d] Failed to json unmarshal task: %v\n", workerID, err) + taskDelivery.Nack() + continue + } + + // process + if taskFunc, found := t.getTaskFunc(task.GetType()); found { + if err := taskFunc(task); err != nil { + // fmt.Printf("[Worker %d] Failed execute task %s: %v\n", workerID, task.GetID(), err) + taskDelivery.Nack() + continue + } + } else { + // fmt.Printf("[Worker %d] Task %s type %s not registered\n", workerID, task.GetID(), task.GetType()) + taskDelivery.Nack() + continue + } + + // ack back to Cherami + taskDelivery.Ack() + } + } + }(i) + + t.waitGroup.Add(1) + } + + return nil +} + +// Stop stops dequeuing/exeuction of tasks +// There's no guarantee to drain scheduled tasks when Stop is invoked +func (t *taskExecutorImpl) Stop() { + // signal task kill + close(t.killSignal) + + // wait for worker to be done + t.waitGroup.Wait() + + // close consumer + if atomic.LoadUint32(&t.initialized) == 1 { + t.consumer.Close() + } +} + +func (t *taskExecutorImpl) getConsumer() Consumer { + if atomic.LoadUint32(&t.initialized) == 1 { + return t.consumer + } + + t.mu.Lock() + defer t.mu.Unlock() + + if t.initialized == 0 { + t.consumer = t.client.CreateConsumer(&CreateConsumerRequest{ + Path: t.path, + ConsumerGroupName: t.cgName, + ConsumerName: t.consumerName, + PrefetchCount: t.prefetchSize, + Options: &ClientOptions{ + Timeout: t.timeout, + }, + }) + atomic.StoreUint32(&t.initialized, 1) + } + + return t.consumer +} + +func (t *taskExecutorImpl) getTaskFunc(taskType string) (taskFunc TaskFunc, found bool) { + t.taskFuncsLock.RLock() + defer t.taskFuncsLock.RUnlock() + taskFunc, found = t.taskFuncs[taskType] + return +} diff --git a/client/cherami/taskscheduler.go b/client/cherami/taskscheduler.go new file 
mode 100644 index 0000000..e93f6ac --- /dev/null +++ b/client/cherami/taskscheduler.go @@ -0,0 +1,108 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package cherami + +import ( + "encoding/json" + "sync" + "sync/atomic" +) + +type ( + taskSchedulerImpl struct { + client Client + publisher Publisher + + path string + maxInflightMessagesPerConnection int + + initialized uint32 + mu sync.Mutex + } +) + +// NewTaskScheduler creates a task scheduler +func NewTaskScheduler(client Client, request *CreateTaskSchedulerRequest) TaskScheduler { + if client == nil || request == nil { + return nil + } + return &taskSchedulerImpl{ + client: client, + publisher: nil, + path: request.Path, + maxInflightMessagesPerConnection: request.MaxInflightMessagesPerConnection, + } +} + +// Open gets TaskScheduler for scheduling tasks +func (t *taskSchedulerImpl) Open() error { + return t.getPublisher().Open() +} + +// Close make sure resources are released +func (t *taskSchedulerImpl) Close() { + if atomic.LoadUint32(&t.initialized) == 1 { + t.publisher.Close() + } +} + +// ScheduleTask enqueues a task +func (t *taskSchedulerImpl) ScheduleTask(request *ScheduleTaskRequest) error { + jsonValue, err := json.Marshal(request.TaskValue) + if err != nil { + return err + } + + messageData, err := json.Marshal(&taskImpl{ + Type: request.TaskType, + ID: request.TaskID, + JSONValue: string(jsonValue), + Context: request.Context, + }) + if err != nil { + return err + } + + receipt := t.getPublisher().Publish(&PublisherMessage{ + Data: messageData, + Delay: request.Delay, + }) + return receipt.Error +} + +func (t *taskSchedulerImpl) getPublisher() Publisher { + if atomic.LoadUint32(&t.initialized) == 1 { + return t.publisher + } + + t.mu.Lock() + defer t.mu.Unlock() + + if t.initialized == 0 { + t.publisher = t.client.CreatePublisher(&CreatePublisherRequest{ + Path: t.path, + MaxInflightMessagesPerConnection: t.maxInflightMessagesPerConnection, + }) + atomic.StoreUint32(&t.initialized, 1) + } + + return t.publisher +} diff --git a/client/cherami/tchanPublisher.go b/client/cherami/tchanPublisher.go new file mode 100644 index 0000000..188102e --- 
/dev/null +++ b/client/cherami/tchanPublisher.go @@ -0,0 +1,427 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
type (
	// tchannelBatchPublisher is a Publisher that sends messages to the input
	// hosts in batches over a tchannel/thrift client.
	tchannelBatchPublisher struct {
		basePublisher
		sync.RWMutex
		opened         int32                   // set to 1 by Open; read atomically
		closed         int32                   // set to 1 by Close; read atomically
		tchan          *tchannel.Channel       // created in Open
		endpoints      map[string]struct{}     // host:port keys currently added as tchannel peers
		thriftClient   cherami.TChanBIn        // created in Open on top of tchan
		reconfigureCh  chan reconfigureInfo    // event source handed to the reconfigure pump
		reconfigurable *reconfigurable         // created in Open
		messagesCh     chan *putMessageRequest // queue between PublishAsync and processor
		closeCh        chan struct{}           // closed by Close to stop the background goroutines
	}
)

var errInvalidAckID = errors.New("invalid msg id found in ack")
var errPublisherClosed = errors.New("publisher is closed")
var errPublisherUnopened = errors.New("publish is not open")

const (
	maxBatchSize              = 16 // max number of messages in a single batch to server
	endpointsInitialSz        = 8  // initial size of the endpoints map
	messageBatchThriftTimeout = 30 * time.Second // timeout of a single PutMessageBatch call
	enqueueTimeout            = time.Minute      // max time PublishAsync waits to enqueue a message
	inputServiceTChannelPort  = "4240"           // port joined with each endpoint host to form the peer key
)

// compile-time check that tchannelBatchPublisher implements Publisher
var _ Publisher = (*tchannelBatchPublisher)(nil)

// newTChannelBatchPublisher creates an unopened batch publisher for path.
// The message queue is buffered with maxBatchSize entries and the logger is
// tagged with the destination path.
func newTChannelBatchPublisher(client Client, path string, logger bark.Logger, metricsReporter metrics.Reporter) Publisher {
	base := basePublisher{
		client:      client,
		retryPolicy: createDefaultPublisherRetryPolicy(),
		path:        path,
		logger:      logger.WithField(common.TagDstPth, common.FmtDstPth(path)),
		reporter:    metricsReporter,
	}
	return &tchannelBatchPublisher{
		basePublisher: base,
		reconfigureCh: make(chan reconfigureInfo, 1),
		messagesCh:    make(chan *putMessageRequest, maxBatchSize),
		closeCh:       make(chan struct{}),
		endpoints:     make(map[string]struct{}, endpointsInitialSz),
	}
}

// Open prepares the publisher for message publishing.
//
// It reads the publisher options, creates the tchannel and its thrift client,
// registers every chosen endpoint as a peer and starts the reconfigure pump
// and the batch processor goroutines. Calling Open on an already opened
// publisher is a no-op. It returns the error from reading the publisher
// options or from creating the channel.
func (p *tchannelBatchPublisher) Open() error {

	p.Lock()
	defer p.Unlock()

	// already opened, nothing to do
	if atomic.LoadInt32(&p.opened) == 1 {
		return nil
	}

	publisherOptions, err := p.readPublisherOptions()
	if err != nil {
		p.logger.Errorf("Error resolving input hosts: %v", err)
		return err
	}

	ch, err := tchannel.NewChannel(uuid.New(), nil)
	if err != nil {
		return err
	}

	p.tchan = ch
	p.checksumOption = publisherOptions.GetChecksumOption()

	// register every chosen endpoint as a tchannel peer and remember its key
	_, addrs := p.choosePublishEndpoints(publisherOptions)
	for _, addr := range addrs {
		key := net.JoinHostPort(addr.GetHost(), inputServiceTChannelPort)
		p.endpoints[key] = struct{}{}
		p.tchan.Peers().Add(key)
	}

	p.thriftClient = cherami.NewTChanBInClient(thrift.NewClient(p.tchan, common.InputServiceName, nil))

	p.reporter.UpdateGauge(metrics.PublishNumConnections, nil, int64(len(addrs)))

	// start the background goroutines; both stop when closeCh is closed
	p.reconfigurable = newReconfigurable(p.reconfigureCh, p.closeCh, p.reconfigureHandler, p.logger)
	go p.reconfigurable.reconfigurePump()
	go p.processor()
	// mark as opened only after everything is set up
	atomic.StoreInt32(&p.opened, 1)
	p.logger.WithField(`endpoints`, addrs).Info("Publisher Opened.")

	return nil
}
+func (p *tchannelBatchPublisher) Close() { + p.Lock() + defer p.Unlock() + + if p.isClosed() { + return + } + + close(p.closeCh) + if p.tchan != nil { + p.tchan.Close() + } + + p.drain() + p.reporter.UpdateGauge(metrics.PublishNumConnections, nil, int64(0)) + atomic.StoreInt32(&p.closed, 1) + p.logger.Info("Publisher Closed.") +} + +// Publish publishes a message to cherami +func (p *tchannelBatchPublisher) Publish(message *PublisherMessage) *PublisherReceipt { + + if !p.isOpened() { + return &PublisherReceipt{Error: errPublisherUnopened} + } + if p.isClosed() { + return &PublisherReceipt{Error: errPublisherClosed} + } + + var receipt *PublisherReceipt + + publishOp := func() error { + ackCh := make(chan *PublisherReceipt, 1) + _, err := p.PublishAsync(message, ackCh) + if err != nil { + return err + } + receipt = <-ackCh + if receipt.Error != nil { + return receipt.Error + } + return nil + } + + err := backoff.Retry(publishOp, p.retryPolicy, nil) + if err != nil { + return &PublisherReceipt{Error: err} + } + + return receipt +} + +// PublishAsync publishes a message asynchronously. +// On completion, the receipt will be enqueued into +// the done channel. +func (p *tchannelBatchPublisher) PublishAsync(message *PublisherMessage, done chan<- *PublisherReceipt) (string, error) { + + if !p.isOpened() { + return "", errPublisherUnopened + } + if p.isClosed() { + return "", errPublisherClosed + } + + putMsg := &putMessageRequest{ + message: p.toPutMessage(message), + messageAck: done, + } + + msgID := putMsg.message.GetID() + timer := time.NewTimer(enqueueTimeout) + defer timer.Stop() + + select { + case p.messagesCh <- putMsg: + case <-timer.C: + return "", ErrMessageTimedout + } + + return msgID, nil +} + +// publishBatch publishes the given batch of messages +// to cherami. On success, returns a slice of receipts +// where the order of the receipts is the same as +// the order of the given messages. 
If an error is +// encountered before publishing the whole batch, the +// returned receipts will be for a subset of messages. +// +// This func will return err != nil if and only if no +// messages can be published. +func (p *tchannelBatchPublisher) publishBatch(putMessages []*cherami.PutMessage) ([]*PublisherReceipt, error) { + + if p.isClosed() { + return nil, errPublisherClosed + } + + batchRequest := &cherami.PutMessageBatchRequest{ + DestinationPath: common.StringPtr(p.path), + Messages: putMessages, + } + + p.reporter.IncCounter(metrics.PublishMessageRate, nil, 1) + sw := p.reporter.StartTimer(metrics.PublishMessageLatency, nil) + + result, err := p.putMessageBatch(batchRequest) + sw.Stop() + if err != nil { + p.reporter.IncCounter(metrics.PublishMessageFailedRate, nil, 1) + return nil, err + } + + receipts := make([]*PublisherReceipt, len(putMessages)) + + if e := p.processAcks(result.GetSuccessMessages(), receipts); e != nil { + return nil, e + } + if e := p.processAcks(result.GetFailedMessages(), receipts); e != nil { + return nil, e + } + return receipts, nil +} + +func (p *tchannelBatchPublisher) putMessageBatch(request *cherami.PutMessageBatchRequest) (*cherami.PutMessageBatchResult_, error) { + ctx, cancel := thrift.NewContext(messageBatchThriftTimeout) + defer cancel() + return p.thriftClient.PutMessageBatch(ctx, request) +} + +// processAcks takes a set of acks received in response +// to putMessageBatch, converts them into receipts and +// stores them into the receipts slice. Stores receipt +// for ack with ID:id into receipts[id]. 
+func (p *tchannelBatchPublisher) processAcks(acks []*cherami.PutMessageAck, receipts []*PublisherReceipt) error { + for _, ack := range acks { + id, err := p.hexStrToMsgID(ack.GetID()) + if err != nil { + return err + } + if id < 0 || id >= len(receipts) { + p.logger.WithField(`id`, ack.GetID()).Error("putMessageBatch ack result contains invalid message id") + return errInvalidAckID + } + receipts[id] = &PublisherReceipt{ + ID: ack.GetID(), + Receipt: ack.GetReceipt(), + UserContext: ack.GetUserContext(), + } + if ack.GetStatus() != cherami.Status_OK { + receipts[id].Error = newPublishError(ack.GetStatus()) + } + } + return nil +} + +// processor is the main loop that dequeues +// messages and publishes them in batch +func (p *tchannelBatchPublisher) processor() { + + msgIDs := make([]string, maxBatchSize) // original message ids + putMessages := make([]*cherami.PutMessage, maxBatchSize) + ackChannels := make([]chan<- *PublisherReceipt, maxBatchSize) + + for { + + batchSz := 0 + + select { + case <-p.closeCh: + return + case m := <-p.messagesCh: + + msgLoop: + for m != nil { + + msgIDs[batchSz] = m.message.GetID() + putMessages[batchSz] = m.message + putMessages[batchSz].ID = common.StringPtr(p.msgIDToHexStr(batchSz)) + ackChannels[batchSz] = m.messageAck + batchSz++ + + if batchSz == maxBatchSize { + break msgLoop + } + + select { + case m = <-p.messagesCh: + default: + m = nil + } + } + + receipts, err := p.publishBatch(putMessages[:batchSz]) + + for i := 0; i < batchSz; i++ { + if err != nil { + ackChannels[i] <- &PublisherReceipt{Error: err} + } else { + receipts[i].ID = msgIDs[i] + ackChannels[i] <- receipts[i] + } + putMessages[i] = nil + ackChannels[i] = nil + msgIDs[i] = "" + } + } + } +} + +// reconfigureHandler re-disovers the publish endpoints +// and updates the tchannel peers list. 
+func (p *tchannelBatchPublisher) reconfigureHandler() { + + publisherOptions, err := p.readPublisherOptions() + if err != nil { + return + } + + if err != nil { + p.logger.Infof("Error resolving input hosts: %v", err) + if _, ok := err.(*cherami.EntityNotExistsError); ok { + // Destination is deleted. Continue with reconfigure + // remove all addrs from the peers list + publisherOptions = &cherami.ReadPublisherOptionsResult_{} + } else { + // This is a potentially a transient error. + // Retry on next reconfigure + return + } + } + + newEndpoints := make(map[string]struct{}, endpointsInitialSz) + _, addrs := p.choosePublishEndpoints(publisherOptions) + for _, addr := range addrs { + key := net.JoinHostPort(addr.GetHost(), inputServiceTChannelPort) + if _, ok := p.endpoints[key]; !ok { + p.tchan.Peers().Add(key) + } else { + delete(p.endpoints, key) + } + newEndpoints[key] = struct{}{} + } + + for addr := range p.endpoints { + p.tchan.Peers().Remove(addr) + } + + p.endpoints = newEndpoints + p.reporter.UpdateGauge(metrics.PublishNumConnections, nil, int64(len(p.endpoints))) +} + +func (p *tchannelBatchPublisher) isClosed() bool { + return atomic.LoadInt32(&p.closed) == 1 +} + +func (p *tchannelBatchPublisher) isOpened() bool { + return atomic.LoadInt32(&p.opened) == 1 +} + +func (p *tchannelBatchPublisher) drain() { + for { + select { + case m := <-p.messagesCh: + m.messageAck <- &PublisherReceipt{Error: errPublisherClosed} + default: + return + } + } +} + +func (p *tchannelBatchPublisher) msgIDToHexStr(id int) string { + switch { + case id >= 0 && id <= 9: + return string(byte('0') + byte(id)) + case id > 9 && id < 16: + return string(byte('A') + byte(id-10)) + default: + p.logger.WithField(`id`, id).Fatal("msgIDToHexStr() encountered invalid msgID") + } + return "" +} + +func (p *tchannelBatchPublisher) hexStrToMsgID(id string) (int, error) { + val := byte(id[0]) + switch { + case val >= '0' && val <= '9': + return int(val - byte('0')), nil + case val >= 'A' && 
val <= 'F': + return 10 + int(val-byte('A')), nil + default: + return 0, errInvalidAckID + } +} diff --git a/client/cherami/tchanPublisher_test.go b/client/cherami/tchanPublisher_test.go new file mode 100644 index 0000000..fb0da55 --- /dev/null +++ b/client/cherami/tchanPublisher_test.go @@ -0,0 +1,233 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
type (
	// TChanBatchPublisherSuite exercises the batch processor of
	// tchannelBatchPublisher against a mocked input host client.
	TChanBatchPublisherSuite struct {
		*require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error
		suite.Suite
		client    *mockClient
		logger    bark.Logger
		publisher *tchannelBatchPublisher
	}
)

// TestTChanBatchPublisherSuite runs the suite under `go test`.
func TestTChanBatchPublisherSuite(t *testing.T) {
	suite.Run(t, new(TChanBatchPublisherSuite))
}

// SetupTest builds a fresh publisher for every test. The publisher is marked
// as opened directly instead of calling Open, so no channel is created and
// no background goroutine is started.
func (s *TChanBatchPublisherSuite) SetupTest() {
	s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil
	s.logger = bark.NewLoggerFromLogrus(log.StandardLogger())
	s.client = new(mockClient)
	s.publisher = newTChannelBatchPublisher(s.client, "/test/tchanBatchPublisher", s.logger, metrics.NewNullReporter()).(*tchannelBatchPublisher)
	s.publisher.opened = 1
}

// TestBatchSzWithinRange guards the assumption that a batch position fits in
// a single hex digit.
func (s *TChanBatchPublisherSuite) TestBatchSzWithinRange() {
	s.True(maxBatchSize > 0, "invalid maxBatchSz")
	s.True(maxBatchSize <= 16, "invalid maxBatchSz, cannot be greater than 16")
}

// TestPublishBatchSuccess enqueues sz messages, runs the processor and
// expects a successful receipt carrying the original message ID for each.
func (s *TChanBatchPublisherSuite) TestPublishBatchSuccess() {
	for _, sz := range []int{1, 16, 16, 5} {
		ackIDs := make([]string, common.MinInt(16, sz))
		for i := range ackIDs {
			ackIDs[i] = s.publisher.msgIDToHexStr(i)
		}
		result := newBatchResult(ackIDs, []string{})
		mockInput := new(mc.MockTChanBInClient)
		mockInput.On("PutMessageBatch", mock.Anything, mock.Anything).Return(result, nil).Times((sz + 15) / 16)
		s.publisher.thriftClient = mockInput

		// enqueue before the processor runs; sz never exceeds the queue buffer
		msg := newPublisherMessage()
		ids := make([]string, sz)
		ackChs := make([]chan *PublisherReceipt, sz)
		for i := range ackIDs {
			ackChs[i] = make(chan *PublisherReceipt, 1)
			id, err := s.publisher.PublishAsync(msg, ackChs[i])
			s.Nil(err, "publishAsync returned unexpected error")
			ids[i] = id
		}

		// fresh close channel, the previous iteration closed the old one
		s.publisher.closeCh = make(chan struct{})

		wg := sync.WaitGroup{}
		wg.Add(1)
		go func() {
			s.publisher.processor()
			wg.Done()
		}()

		for i, ch := range ackChs {
			receipt := <-ch
			s.Nil(receipt.Error, "publishBatch receipt contains unexpected error")
			s.Equal(ids[i], receipt.ID, "publishBatch receipt contains invalid message id")
		}

		// stop the processor and wait for it to exit
		close(s.publisher.closeCh)
		wg.Wait()
	}
}

// TestPublishBatchFailure is the counterpart of TestPublishBatchSuccess where
// the server fails every message: each receipt must carry an error and still
// report the original message ID.
func (s *TChanBatchPublisherSuite) TestPublishBatchFailure() {
	for _, sz := range []int{1, 16, 16, 5} {
		ackIDs := make([]string, common.MinInt(16, sz))
		for i := range ackIDs {
			ackIDs[i] = s.publisher.msgIDToHexStr(i)
		}
		result := newBatchResult([]string{}, ackIDs)
		mockInput := new(mc.MockTChanBInClient)
		mockInput.On("PutMessageBatch", mock.Anything, mock.Anything).Return(result, nil).Times((sz + 15) / 16)
		s.publisher.thriftClient = mockInput

		msg := newPublisherMessage()
		ids := make([]string, sz)
		ackChs := make([]chan *PublisherReceipt, sz)
		for i := range ackIDs {
			ackChs[i] = make(chan *PublisherReceipt, 1)
			id, err := s.publisher.PublishAsync(msg, ackChs[i])
			s.Nil(err, "publishAsync returned unexpected error")
			ids[i] = id
		}

		// fresh close channel, the previous iteration closed the old one
		s.publisher.closeCh = make(chan struct{})

		wg := sync.WaitGroup{}
		wg.Add(1)
		go func() {
			s.publisher.processor()
			wg.Done()
		}()

		for i, ch := range ackChs {
			receipt := <-ch
			s.NotNil(receipt.Error, "publishBatch receipt contains unexpected error")
			s.Equal(ids[i], receipt.ID, "publishBatch receipt contains invalid message id")
		}

		close(s.publisher.closeCh)
		wg.Wait()
	}
}

// TestPublishBatchPartialFailure publishes a full batch in which the server
// acks the even positions as succeeded and the odd positions as failed, and
// checks that every receipt reflects the outcome of its own message.
func (s *TChanBatchPublisherSuite) TestPublishBatchPartialFailure() {
	sz := 16
	succIDs := make([]string, 0, common.MinInt(8, sz))
	failIDs := make([]string, 0, common.MinInt(8, sz))
	for i := 0; i < sz; i++ {
		if i%2 == 0 {
			succIDs = append(succIDs, s.publisher.msgIDToHexStr(i))
			continue
		}
		failIDs = append(failIDs, s.publisher.msgIDToHexStr(i))
	}

	result := newBatchResult(succIDs, failIDs)
	mockInput := new(mc.MockTChanBInClient)
	mockInput.On("PutMessageBatch", mock.Anything, mock.Anything).Return(result, nil).Times((sz + 15) / 16)
	s.publisher.thriftClient = mockInput

	msg := newPublisherMessage()
	ids := make([]string, sz)
	ackChs := make([]chan *PublisherReceipt, sz)
	for i := 0; i < sz; i++ {
		ackChs[i] = make(chan *PublisherReceipt, 1)
		id, err := s.publisher.PublishAsync(msg, ackChs[i])
		s.Nil(err, "publishAsync returned unexpected error")
		ids[i] = id
	}

	s.publisher.closeCh = make(chan struct{})

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		s.publisher.processor()
		wg.Done()
	}()

	for i, ch := range ackChs {
		receipt := <-ch
		s.Equal(ids[i], receipt.ID, "publishBatch receipt contains invalid message id")
		if i%2 == 0 {
			s.Nil(receipt.Error, "publishBatch receipt contains unexpected error")
		} else {
			s.NotNil(receipt.Error, "publishBatch receipt contains unexpected error")
		}
	}

	close(s.publisher.closeCh)
	wg.Wait()
}

// newBatchResult builds a PutMessageBatch result with one OK ack per id in
// succIDs and one FAILED ack per id in failIDs. Both lists are shuffled, so
// the acks do not arrive in message order.
func newBatchResult(succIDs []string, failIDs []string) *cherami.PutMessageBatchResult_ {
	result := cherami.NewPutMessageBatchResult_()
	for _, id := range shuffle(succIDs) {
		ack := cherami.NewPutMessageAck()
		ack.Status = cherami.StatusPtr(cherami.Status_OK)
		ack.ID = common.StringPtr(id)
		ack.Receipt = common.StringPtr(id)
		result.SuccessMessages = append(result.SuccessMessages, ack)
	}
	for _, id := range shuffle(failIDs) {
		ack := cherami.NewPutMessageAck()
		ack.Status = cherami.StatusPtr(cherami.Status_FAILED)
		ack.Message = common.StringPtr("Internal service error")
		ack.ID = common.StringPtr(id)
		ack.Receipt = common.StringPtr(id)
		result.FailedMessages = append(result.FailedMessages, ack)
	}
	return result
}

// shuffle returns a randomly permuted copy of input; input is not modified.
func shuffle(input []string) []string {
	perms := rand.Perm(len(input))
	output := make([]string, len(input))
	for i, srcIdx := range perms {
		output[i] = input[srcIdx]
	}
	return output
}

// newPublisherMessage returns a small fixed message used by all tests.
func newPublisherMessage() *PublisherMessage {
	return &PublisherMessage{
		Data:        []byte("test"),
		UserContext: map[string]string{"test-name": "tchanelBatchPublisher"},
	}
}
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package cherami + +import ( + "fmt" + "net/http" + "reflect" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/common" + "github.com/uber/cherami-client-go/common/websocket" + "github.com/uber/cherami-client-go/stream" +) + +type ( + // WSConnector takes care of establishing connection via websocket stream + WSConnector interface { + OpenPublisherStream(hostPort string, requestHeader http.Header) (stream.BInOpenPublisherStreamOutCall, error) + OpenConsumerStream(hostPort string, requestHeader http.Header) (stream.BOutOpenConsumerStreamOutCall, error) + } + + wsConnectorImpl struct { + wsHub websocket.Hub + + openPublisherOutStreamReadType reflect.Type + openConsumerOutStreamReadType reflect.Type + } + + // OpenPublisherOutWebsocketStream is a wrapper for websocket to work with OpenPublisherStream + OpenPublisherOutWebsocketStream struct { + stream websocket.StreamClient + } + + // OpenConsumerOutWebsocketStream is a wrapper for websocket to work with OpenPublisherStream + OpenConsumerOutWebsocketStream struct { + stream websocket.StreamClient + } +) + +// interface implementation check +var _ WSConnector = (*wsConnectorImpl)(nil) +var _ stream.BInOpenPublisherStreamOutCall = &OpenPublisherOutWebsocketStream{} +var _ stream.BOutOpenConsumerStreamOutCall = &OpenConsumerOutWebsocketStream{} + +// NewWSConnector creates a WSConnector +func NewWSConnector() WSConnector { + return &wsConnectorImpl{ + wsHub: websocket.NewWebsocketHub(), + openPublisherOutStreamReadType: reflect.TypeOf((*cherami.InputHostCommand)(nil)).Elem(), + openConsumerOutStreamReadType: reflect.TypeOf((*cherami.OutputHostCommand)(nil)).Elem(), + } +} + +func (c *wsConnectorImpl) OpenPublisherStream(hostPort string, requestHeader http.Header) (stream.BInOpenPublisherStreamOutCall, error) { + url := fmt.Sprintf(common.WSUrlFormat, hostPort, common.EndpointOpenPublisherStream) + + stream := websocket.NewStreamClient(url, requestHeader, c.wsHub, 
c.openPublisherOutStreamReadType) + + if err := stream.Start(); err != nil { + return nil, err + } + + return &OpenPublisherOutWebsocketStream{stream: stream}, nil +} + +// Write writes a result to the response stream +func (s *OpenPublisherOutWebsocketStream) Write(arg *cherami.PutMessage) error { + return s.stream.Write(arg) +} + +// Flush flushes all written arguments. +func (s *OpenPublisherOutWebsocketStream) Flush() error { + return s.stream.Flush() +} + +// Done closes the request stream and should be called after all arguments have been written. +func (s *OpenPublisherOutWebsocketStream) Done() error { + return s.stream.Done() +} + +// Read returns the next argument, if any is available. +func (s *OpenPublisherOutWebsocketStream) Read() (*cherami.InputHostCommand, error) { + + msg, err := s.stream.Read() + if err != nil { + return nil, err + } + return msg.(*cherami.InputHostCommand), err +} + +// ResponseHeaders is defined to conform to the tchannel-stream .*OutCall interface +func (s *OpenPublisherOutWebsocketStream) ResponseHeaders() (map[string]string, error) { + return map[string]string{}, nil +} + +func (c *wsConnectorImpl) OpenConsumerStream(hostPort string, requestHeader http.Header) (stream.BOutOpenConsumerStreamOutCall, error) { + + url := fmt.Sprintf(common.WSUrlFormat, hostPort, common.EndpointOpenConsumerStream) + + stream := websocket.NewStreamClient(url, requestHeader, c.wsHub, c.openConsumerOutStreamReadType) + + if err := stream.Start(); err != nil { + return nil, err + } + + return &OpenConsumerOutWebsocketStream{stream: stream}, nil +} + +// Write writes a result to the response stream +func (s *OpenConsumerOutWebsocketStream) Write(arg *cherami.ControlFlow) error { + return s.stream.Write(arg) +} + +// Flush flushes all written arguments. +func (s *OpenConsumerOutWebsocketStream) Flush() error { + return s.stream.Flush() +} + +// Done closes the request stream and should be called after all arguments have been written. 
+func (s *OpenConsumerOutWebsocketStream) Done() error { + return s.stream.Done() +} + +// Read returns the next argument, if any is available. +func (s *OpenConsumerOutWebsocketStream) Read() (*cherami.OutputHostCommand, error) { + + msg, err := s.stream.Read() + if err != nil { + return nil, err + } + return msg.(*cherami.OutputHostCommand), err +} + +// ResponseHeaders is defined to conform to the tchannel-stream .*OutCall interface +func (s *OpenConsumerOutWebsocketStream) ResponseHeaders() (map[string]string, error) { + return map[string]string{}, nil +} diff --git a/common/backoff/retry.go b/common/backoff/retry.go new file mode 100644 index 0000000..7ce7098 --- /dev/null +++ b/common/backoff/retry.go @@ -0,0 +1,69 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package backoff + +import "time" + +type ( + // Operation to retry + Operation func() error + + // IsRetryable handler can be used to exclude certain errors during retry + IsRetryable func(error) bool +) + +// Retry function can be used to wrap any call with retry logic using the passed in policy +func Retry(operation Operation, policy RetryPolicy, isRetryable IsRetryable) error { + var err error + var next time.Duration + + r := NewRetrier(policy, SystemClock) + for { + // operation completed successfully. No need to retry. + if err = operation(); err == nil { + return nil + } + + if next = r.NextBackOff(); next == done { + return err + } + + // Check if the error is retryable + if isRetryable != nil && !isRetryable(err) { + return err + } + + time.Sleep(next) + } +} + +// IgnoreErrors can be used as IsRetryable handler for Retry function to exclude certain errors from the retry list +func IgnoreErrors(errorsToExclude []error) func(error) bool { + return func(err error) bool { + for _, errorToExclude := range errorsToExclude { + if err == errorToExclude { + return false + } + } + + return true + } +} diff --git a/common/backoff/retry_test.go b/common/backoff/retry_test.go new file mode 100644 index 0000000..c2723f9 --- /dev/null +++ b/common/backoff/retry_test.go @@ -0,0 +1,141 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package backoff + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ( + RetrySuite struct { + *require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error + suite.Suite + } + + someError struct{} +) + +func TestRetrySuite(t *testing.T) { + suite.Run(t, new(RetrySuite)) +} + +func (s *RetrySuite) SetupTest() { + s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. 
If we did it earlier, s.T() will return nil +} + +func (s *RetrySuite) TestRetrySuccess() { + i := 0 + op := func() error { + i++ + + if i == 5 { + return nil + } + + return &someError{} + } + + policy := NewExponentialRetryPolicy(1 * time.Millisecond) + policy.SetMaximumInterval(5 * time.Millisecond) + policy.SetMaximumAttempts(10) + + err := Retry(op, policy, nil) + s.NoError(err) + s.Equal(5, i) +} + +func (s *RetrySuite) TestRetryFailed() { + i := 0 + op := func() error { + i++ + + if i == 7 { + return nil + } + + return &someError{} + } + + policy := NewExponentialRetryPolicy(1 * time.Millisecond) + policy.SetMaximumInterval(5 * time.Millisecond) + policy.SetMaximumAttempts(5) + + err := Retry(op, policy, nil) + s.Error(err) +} + +func (s *RetrySuite) TestIsRetryableSuccess() { + i := 0 + op := func() error { + i++ + + if i == 5 { + return nil + } + + return &someError{} + } + + isRetryable := func(err error) bool { + if _, ok := err.(*someError); ok { + return true + } + + return false + } + + policy := NewExponentialRetryPolicy(1 * time.Millisecond) + policy.SetMaximumInterval(5 * time.Millisecond) + policy.SetMaximumAttempts(10) + + err := Retry(op, policy, isRetryable) + s.NoError(err, "Retry count: %v", i) + s.Equal(5, i) +} + +func (s *RetrySuite) TestIsRetryableFailure() { + i := 0 + op := func() error { + i++ + + if i == 5 { + return nil + } + + return &someError{} + } + + policy := NewExponentialRetryPolicy(1 * time.Millisecond) + policy.SetMaximumInterval(5 * time.Millisecond) + policy.SetMaximumAttempts(10) + + err := Retry(op, policy, IgnoreErrors([]error{&someError{}})) + s.Error(err) + s.Equal(1, i) +} + +func (e *someError) Error() string { + return "Some Error" +} diff --git a/common/backoff/retrypolicy.go b/common/backoff/retrypolicy.go new file mode 100644 index 0000000..77b261d --- /dev/null +++ b/common/backoff/retrypolicy.go @@ -0,0 +1,197 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package backoff + +import ( + "math" + "math/rand" + "time" +) + +const ( + done time.Duration = -1 + noMaximumAttempts = 0 + noInterval = 0 + + defaultBackoffCoefficient = 2.0 + defaultMaximumInterval = 10 * time.Second + defaultExpirationInterval = time.Minute + defaultMaximumAttempts = noMaximumAttempts +) + +type ( + // RetryPolicy is the API which needs to be implemented by various retry policy implementations + RetryPolicy interface { + ComputeNextDelay(elapsedTime time.Duration, numAttempts int) time.Duration + } + + // Retrier manages the state of retry operation + Retrier interface { + NextBackOff() time.Duration + Reset() + } + + // Clock used by ExponentialRetryPolicy implementation to get the current time. 
Mainly used for unit testing + Clock interface { + Now() time.Time + } + + // ExponentialRetryPolicy provides the implementation for retry policy using a coefficient to compute the next delay. + // Formula used to compute the next delay is: initialInterval * math.Pow(backoffCoefficient, currentAttempt) + ExponentialRetryPolicy struct { + initialInterval time.Duration + backoffCoefficient float64 + maximumInterval time.Duration + expirationInterval time.Duration + maximumAttempts int + } + + systemClock struct{} + + retrierImpl struct { + policy RetryPolicy + clock Clock + currentAttempt int + startTime time.Time + } +) + +// SystemClock implements Clock interface that uses time.Now(). +var SystemClock = systemClock{} + +// NewExponentialRetryPolicy returns an instance of ExponentialRetryPolicy using the provided initialInterval +func NewExponentialRetryPolicy(initialInterval time.Duration) *ExponentialRetryPolicy { + p := &ExponentialRetryPolicy{ + initialInterval: initialInterval, + backoffCoefficient: defaultBackoffCoefficient, + maximumInterval: defaultMaximumInterval, + expirationInterval: defaultExpirationInterval, + maximumAttempts: defaultMaximumAttempts, + } + + return p +} + +// NewRetrier is used for creating a new instance of Retrier +func NewRetrier(policy RetryPolicy, clock Clock) Retrier { + return &retrierImpl{ + policy: policy, + clock: clock, + startTime: clock.Now(), + currentAttempt: 0, + } +} + +// SetInitialInterval sets the initial interval used by ExponentialRetryPolicy for the very first retry +// All later retries are computed using the following formula: +// initialInterval * math.Pow(backoffCoefficient, currentAttempt) +func (p *ExponentialRetryPolicy) SetInitialInterval(initialInterval time.Duration) { + p.initialInterval = initialInterval +} + +// SetBackoffCoefficient sets the coefficient used by ExponentialRetryPolicy to compute next delay for each retry +// All retries are computed using the following formula: +// initialInterval * 
math.Pow(backoffCoefficient, currentAttempt) +func (p *ExponentialRetryPolicy) SetBackoffCoefficient(backoffCoefficient float64) { + p.backoffCoefficient = backoffCoefficient +} + +// SetMaximumInterval sets the maximum interval for each retry +func (p *ExponentialRetryPolicy) SetMaximumInterval(maximumInterval time.Duration) { + p.maximumInterval = maximumInterval +} + +// SetExpirationInterval sets the absolute expiration interval for all retries +func (p *ExponentialRetryPolicy) SetExpirationInterval(expirationInterval time.Duration) { + p.expirationInterval = expirationInterval +} + +// SetMaximumAttempts sets the maximum number of retry attempts +func (p *ExponentialRetryPolicy) SetMaximumAttempts(maximumAttempts int) { + p.maximumAttempts = maximumAttempts +} + +// ComputeNextDelay returns the next delay interval. This is used by Retrier to delay calling the operation again +func (p *ExponentialRetryPolicy) ComputeNextDelay(elapsedTime time.Duration, numAttempts int) time.Duration { + // Check to see if we ran out of maximum number of attempts + if p.maximumAttempts != noMaximumAttempts && numAttempts >= p.maximumAttempts { + return done + } + + // Stop retrying after expiration interval is elasped + if p.expirationInterval != noInterval && elapsedTime > p.expirationInterval { + return done + } + + nextInterval := float64(p.initialInterval) * math.Pow(p.backoffCoefficient, float64(numAttempts)) + // Disallow retries if initialInterval is negative or nextInterval overflows + if nextInterval <= 0 { + return done + } + if p.maximumInterval != noInterval { + nextInterval = math.Min(nextInterval, float64(p.maximumInterval)) + } + + if p.expirationInterval != noInterval { + remainingTime := float64(math.Max(0, float64(p.expirationInterval-elapsedTime))) + nextInterval = math.Min(remainingTime, nextInterval) + } + + // Bail out if the next interval is smaller than initial retry interval + nextDuration := time.Duration(nextInterval) + if nextDuration < 
p.initialInterval { + return done + } + + // add jitter to avoid global synchronization + jitterPortion := int(0.2 * nextInterval) + // Prevent overflow + if jitterPortion < 1 { + jitterPortion = 1 + } + nextInterval = nextInterval*0.8 + float64(rand.Intn(jitterPortion)) + + return time.Duration(nextInterval) +} + +// Now returns the current time using the system clock +func (t systemClock) Now() time.Time { + return time.Now() +} + +// Reset will set the Retrier into initial state +func (r *retrierImpl) Reset() { + r.startTime = r.clock.Now() + r.currentAttempt = 0 +} + +// NextBackOff returns the next delay interval. This is used by Retry to delay calling the operation again +func (r *retrierImpl) NextBackOff() time.Duration { + nextInterval := r.policy.ComputeNextDelay(r.getElapsedTime(), r.currentAttempt) + + // Now increment the current attempt + r.currentAttempt++ + return nextInterval +} + +func (r *retrierImpl) getElapsedTime() time.Duration { + return r.clock.Now().Sub(r.startTime) +} diff --git a/common/backoff/retrypolicy_test.go b/common/backoff/retrypolicy_test.go new file mode 100644 index 0000000..6797383 --- /dev/null +++ b/common/backoff/retrypolicy_test.go @@ -0,0 +1,231 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package backoff + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ( + RetryPolicySuite struct { + *require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error + suite.Suite + } + + TestClock struct { + currentTime time.Time + } +) + +func TestRetryPolicySuite(t *testing.T) { + suite.Run(t, new(RetryPolicySuite)) +} + +func (s *RetryPolicySuite) SetupTest() { + s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. 
If we did it earlier, s.T() will return nil +} + +func (s *RetryPolicySuite) TestExponentialBackoff() { + policy := createPolicy(time.Second) + policy.SetMaximumInterval(10 * time.Second) + + expectedResult := []time.Duration{1, 2, 4, 8, 10} + for i, d := range expectedResult { + expectedResult[i] = d * time.Second + } + + r, _ := createRetrier(policy) + for _, expected := range expectedResult { + min, max := getNextBackoffRange(expected) + next := r.NextBackOff() + s.True(next >= min, "NextBackoff too low") + s.True(next < max, "NextBackoff too high") + } +} + +func (s *RetryPolicySuite) TestNumberOfAttempts() { + policy := createPolicy(time.Second) + policy.SetMaximumAttempts(5) + + r, _ := createRetrier(policy) + var next time.Duration + for i := 0; i < 6; i++ { + next = r.NextBackOff() + } + + s.Equal(done, next) +} + +// Test to make sure relative maximum interval for each retry is honoured +func (s *RetryPolicySuite) TestMaximumInterval() { + policy := createPolicy(time.Second) + policy.SetMaximumInterval(10 * time.Second) + + expectedResult := []time.Duration{1, 2, 4, 8, 10, 10, 10, 10, 10, 10} + for i, d := range expectedResult { + expectedResult[i] = d * time.Second + } + + r, _ := createRetrier(policy) + for _, expected := range expectedResult { + min, max := getNextBackoffRange(expected) + next := r.NextBackOff() + s.True(next >= min, "NextBackoff too low") + s.True(next < max, "NextBackoff too high") + } +} + +func (s *RetryPolicySuite) TestBackoffCoefficient() { + policy := createPolicy(2 * time.Second) + policy.SetBackoffCoefficient(1.0) + + r, _ := createRetrier(policy) + min, max := getNextBackoffRange(2 * time.Second) + for i := 0; i < 10; i++ { + next := r.NextBackOff() + s.True(next >= min, "NextBackoff too low") + s.True(next < max, "NextBackoff too high") + } +} + +func (s *RetryPolicySuite) TestExpirationInterval() { + policy := createPolicy(2 * time.Second) + policy.SetExpirationInterval(5 * time.Minute) + + r, clock := createRetrier(policy) 
+ clock.moveClock(6 * time.Minute) + next := r.NextBackOff() + + s.Equal(done, next) +} + +func (s *RetryPolicySuite) TestExpirationOverflow() { + policy := createPolicy(2 * time.Second) + policy.SetExpirationInterval(5 * time.Second) + + r, clock := createRetrier(policy) + next := r.NextBackOff() + min, max := getNextBackoffRange(2 * time.Second) + s.True(next >= min, "NextBackoff too low") + s.True(next < max, "NextBackoff too high") + + clock.moveClock(2 * time.Second) + + next = r.NextBackOff() + min, max = getNextBackoffRange(3 * time.Second) + s.True(next >= min, "NextBackoff too low") + s.True(next < max, "NextBackoff too high") +} + +func (s *RetryPolicySuite) TestDefaultPublishRetryPolicy() { + policy := NewExponentialRetryPolicy(50 * time.Millisecond) + policy.SetExpirationInterval(time.Minute) + policy.SetMaximumInterval(10 * time.Second) + + r, clock := createRetrier(policy) + expectedResult := []time.Duration{ + 50 * time.Millisecond, + 100 * time.Millisecond, + 200 * time.Millisecond, + 400 * time.Millisecond, + 800 * time.Millisecond, + 1600 * time.Millisecond, + 3200 * time.Millisecond, + 6400 * time.Millisecond, + 10000 * time.Millisecond, + 10000 * time.Millisecond, + 10000 * time.Millisecond, + 10000 * time.Millisecond, + 6000 * time.Millisecond, + 1300 * time.Millisecond, + done, + } + + for _, expected := range expectedResult { + next := r.NextBackOff() + if expected == done { + s.Equal(done, next, "backoff not done yet!!!") + } else { + min, max := getNextBackoffRange(expected) + s.True(next >= min, "NextBackoff too low: actual: %v, expected: %v", next, expected) + s.True(next < max, "NextBackoff too high: actual: %v, expected: %v", next, expected) + clock.moveClock(expected) + } + } +} + +func (s *RetryPolicySuite) TestNoMaxAttempts() { + policy := createPolicy(50 * time.Millisecond) + policy.SetExpirationInterval(time.Minute) + policy.SetMaximumInterval(10 * time.Second) + + r, clock := createRetrier(policy) + for i := 0; i < 100; i++ { + 
next := r.NextBackOff() + //print("Iter: ", i, ", Next Backoff: ", next.String(), "\n") + s.True(next > 0 || next == done, "Unexpected value for next retry duration: %v", next) + clock.moveClock(next) + } +} + +func (s *RetryPolicySuite) TestUnbounded() { + policy := createPolicy(50 * time.Millisecond) + + r, clock := createRetrier(policy) + for i := 0; i < 100; i++ { + next := r.NextBackOff() + //print("Iter: ", i, ", Next Backoff: ", next.String(), "\n") + s.True(next > 0 || next == done, "Unexpected value for next retry duration: %v", next) + clock.moveClock(next) + } +} + +func (c *TestClock) Now() time.Time { + return c.currentTime +} + +func (c *TestClock) moveClock(duration time.Duration) { + c.currentTime = c.currentTime.Add(duration) +} + +func createPolicy(initialInterval time.Duration) *ExponentialRetryPolicy { + policy := NewExponentialRetryPolicy(initialInterval) + policy.SetBackoffCoefficient(2) + policy.SetMaximumInterval(noInterval) + policy.SetExpirationInterval(noInterval) + policy.SetMaximumAttempts(noMaximumAttempts) + + return policy +} + +func createRetrier(policy RetryPolicy) (Retrier, *TestClock) { + clock := &TestClock{currentTime: time.Time{}} + return NewRetrier(policy, clock), clock +} + +func getNextBackoffRange(duration time.Duration) (time.Duration, time.Duration) { + rangeMin := time.Duration(0.8 * float64(duration)) + return rangeMin, duration +} diff --git a/common/constants.go b/common/constants.go new file mode 100644 index 0000000..a87f4cd --- /dev/null +++ b/common/constants.go @@ -0,0 +1,77 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package common + +import "math" + +const ( + // ClientVersion is current client library's version + // It uses Semantic Versions MAJOR.MINOR.PATCH (https://blog.gopheracademy.com/advent-2015/semver/) + // ClientVersion needs to be updated to reflect client library changes: + // 1. MAJOR version when you make incompatible API changes, + // 2. MINOR version when you add functionality in a backwards-compatible manner, and + // 3. PATCH version when you make backwards-compatible bug fixes. 
+ ClientVersion = "0.1.0" + // HeaderClientVersion is the name of thrift context header contains client version + HeaderClientVersion = "client-version" + // HeaderUserName is the name of thrift context header contains current user name + HeaderUserName = "user-name" + // HeaderHostName is the name of thrift context header contains current host name + HeaderHostName = "host-name" + + // SequenceBegin refers to the beginning of an extent + SequenceBegin = 0 + // SequenceEnd refers to the end of an extent + SequenceEnd = math.MaxInt64 + + // UUIDStringLength is the length of an UUID represented as a hex string + UUIDStringLength = 36 // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + + // InputServiceName refers to the name of the cherami in service + InputServiceName = "cherami-inputhost" + // OutputServiceName refers to the name of the cherami out service + OutputServiceName = "cherami-outputhost" + // FrontendServiceName refers to the name of the cherami frontend service + FrontendServiceName = "cherami-frontendhost" + // FrontendStagingServiceName refers to the name of the cherami staging frontend service + FrontendStagingServiceName = "cherami-frontendhost-staging" + // ControllerServiceName refers to the name of the cherami controller service + ControllerServiceName = "cherami-controllerhost" + // StoreServiceName refers to the name of the cherami store service + StoreServiceName = "cherami-storehost" + + // EndpointOpenPublisherStream is websocket endpoint name for OpenPublisherStream + EndpointOpenPublisherStream = "open_publisher_stream" + // EndpointOpenConsumerStream is websocket endpoint name for OpenConsumerStream + EndpointOpenConsumerStream = "open_consumer_stream" + // EndpointOpenAppendStream is websocket endpoint name for OpenAppendStream + EndpointOpenAppendStream = "open_append_stream" + // EndpointOpenReadStream is websocket endpoint name for OpenReadStream + EndpointOpenReadStream = "open_read_stream" + // EndpointOpenReplicationRemoteReadStream is 
websocket endpoint name for OpenReplicationRemoteReadStream + EndpointOpenReplicationRemoteReadStream = "open_replication_remote_read_stream" + // EndpointOpenReplicationReadStream is websocket endpoint name for OpenReplicationReadStream + EndpointOpenReplicationReadStream = "open_replication_read_stream" + // HTTPHandlerPattern is pattern format for http handler, eg "/endpoint" + HTTPHandlerPattern = "/%s" + // WSUrlFormat is url format for websocket, eg "ws://host:port/endpoint" + WSUrlFormat = "ws://%s/%s" +) diff --git a/common/convert.go b/common/convert.go new file mode 100644 index 0000000..7caa9aa --- /dev/null +++ b/common/convert.go @@ -0,0 +1,98 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package common + +import "github.com/uber/cherami-thrift/.generated/go/cherami" + +// IntPtr makes a copy and returns the pointer to an int. 
+func IntPtr(v int) *int { + return &v +} + +// Int32Ptr makes a copy and returns the pointer to an int32. +func Int32Ptr(v int32) *int32 { + return &v +} + +// Int64Ptr makes a copy and returns the pointer to an int64. +func Int64Ptr(v int64) *int64 { + return &v +} + +// Uint32Ptr makes a copy and returns the pointer to a uint32. +func Uint32Ptr(v uint32) *uint32 { + return &v +} + +// Uint64Ptr makes a copy and returns the pointer to a uint64. +func Uint64Ptr(v uint64) *uint64 { + return &v +} + +// Float64Ptr makes a copy and returns the pointer to an int64. +func Float64Ptr(v float64) *float64 { + return &v +} + +// BoolPtr makes a copy and returns the pointer to a bool. +func BoolPtr(v bool) *bool { + return &v +} + +// StringPtr makes a copy and returns the pointer to a string. +func StringPtr(v string) *string { + return &v +} + +// CheramiStatusPtr makes a copy and returns the pointer to a CheramiStatus. +func CheramiStatusPtr(status cherami.Status) *cherami.Status { + return &status +} + +// CheramiInputHostCommandTypePtr makes a copy and returns the pointer to a +// CheramiInputHostCommandType. +func CheramiInputHostCommandTypePtr(cmdType cherami.InputHostCommandType) *cherami.InputHostCommandType { + return &cmdType +} + +// CheramiOutputHostCommandTypePtr makes a copy and returns the pointer to a +// CheramiOutputHostCommandType. +func CheramiOutputHostCommandTypePtr(cmdType cherami.OutputHostCommandType) *cherami.OutputHostCommandType { + return &cmdType +} + +// CheramiDestinationTypePtr makes a copy and returns the pointer to a +// CheramiDestinationType. +func CheramiDestinationTypePtr(destType cherami.DestinationType) *cherami.DestinationType { + return &destType +} + +// CheramiDestinationStatusPtr makes a copy and returns the pointer to a +// CheramiDestinationStatus. 
+func CheramiDestinationStatusPtr(status cherami.DestinationStatus) *cherami.DestinationStatus { + return &status +} + +// CheramiConsumerGroupStatusPtr makes a copy and returns the pointer to a +// CheramiConsumerGroupStatus. +func CheramiConsumerGroupStatusPtr(status cherami.ConsumerGroupStatus) *cherami.ConsumerGroupStatus { + return &status +} diff --git a/common/log.go b/common/log.go new file mode 100644 index 0000000..d8c60ba --- /dev/null +++ b/common/log.go @@ -0,0 +1,235 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package common + +import ( + "fmt" + "regexp" + "strings" +) + +// TagErr is the tag for error object message +const TagErr = `err` + +// TagDst is the tag for Destination UUID +const TagDst = `destID` + +// TagCnsm is the logging tag for Consumer Group UUID +const TagCnsm = `cnsmID` + +// TagExt is the logging tag for Extent UUID +const TagExt = `extnID` + +// TagIn is the logging tag for Inputhost UUID +const TagIn = `inhoID` + +// TagOut is the logging tag for Outputhost UUID +const TagOut = `outhID` + +// TagCtrl is the logging tag for Extent Controller UUID +const TagCtrl = `ctrlID` + +// TagFrnt is the logging tag for Frontend UUID +const TagFrnt = `frntID` + +// TagStor is the logging tag for StoreHost UUID +const TagStor = `storID` + +// TagDstPth is the logging tag for Destination Path +const TagDstPth = `dstPth` + +// TagCnsPth is the logging tag for Consumer group Path +const TagCnsPth = `cnsPth` + +// TagMsgID is the logging tag for MsgId +const TagMsgID = `msgID` + +// TagAckID is the logging tag for AckId +const TagAckID = `ackID` + +// TagHostIP is the logging tag for host IP +const TagHostIP = `hostIP` + +// TagReconfigureID is the logging tag for reconfiguration identifiers +const TagReconfigureID = `reconfigID` + +// TagDLQID is the logging tag for a Dead Letter Queue destination UUID +const TagDLQID = `dlqID` + +// TagReconfigureType is the logging tag for reconfiguration type +const TagReconfigureType = `reconfigType` + +// TagInPutAckID is the logging tag for PutMessageAck ID +const TagInPutAckID = `inPutAckID` + +// TagInPubConnID is the logging tag for input pubconnection ID +const TagInPubConnID = `inPubConnID` + +// TagInReplicaHost is the logging tag for replica host on input +const TagInReplicaHost = `inReplicaHost` + +// TagUpdateUUID is the logging tag for reconfiguration update UUIDs +const TagUpdateUUID = `updateUUID` + +// TagService is the log tag for the service +const TagService = "service" + +// TagHostPort is the log tag for 
hostport +const TagHostPort = "hostport" + +// TagHosts is the log tag for list of hosts +const TagHosts = "hosts" + +const checkFormatAndPanic = true // TODO : Disable in production + +var longLowercaseGUIDRegex = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) +var shortLowercaseGUIDRegex = regexp.MustCompile(`^[0-9a-f]{8}$`) + +// Create a shortened, lowercased GUID suitable for low-volume object like host IDs. If there are +// more than 9,200 active objects with random GUIDs, then the birthday problem indicates that there +// will be a >1% chance of a collision +// Ex: `354754BD-B73E-4D20-8021-AB93A3D145C0` => `354754bd` +func fmtShortGUID(s string) string { + s = ShortenGUIDString(strings.ToLower(s)) + + if checkFormatAndPanic { + if !shortLowercaseGUIDRegex.MatchString(s) { + panic(fmt.Errorf("Format error on string %#q", s)) + } + } + + return s +} + +// Create a lowercased GUID. Panics if checkFormatAndPanic is enabled and the final GUID doesn't +// match the regular expression +// ex: `354754BD-B73E-4D20-8021-AB93A3D145C0` => `354754bd-b73e-4d20-8021-ab93a3d145c0` +func fmtGUID(s string) string { + s = strings.ToLower(s) + + if checkFormatAndPanic { + if !longLowercaseGUIDRegex.MatchString(s) { + panic(fmt.Errorf("Format error on string %#q", s)) + } + } + + return s +} + +// FmtDst formats a string to be used with TagDst +func FmtDst(s string) string { + return fmtGUID(s) +} + +// FmtCnsm formats a string to be used with TagCnsm +func FmtCnsm(s string) string { + return fmtGUID(s) +} + +// FmtExt formats a string to be used with TagExt +func FmtExt(s string) string { + return fmtGUID(s) +} + +// FmtIn formats a string to be used with TagIn +func FmtIn(s string) string { + return fmtShortGUID(s) +} + +// FmtOut formats a string to be used with TagOut +func FmtOut(s string) string { + return fmtShortGUID(s) +} + +// FmtCtrl formats a string to be used with TagCtrl +func FmtCtrl(s string) string { + return fmtShortGUID(s) 
+} + +// FmtFrnt formats a string to be used with TagFrnt +func FmtFrnt(s string) string { + return fmtShortGUID(s) +} + +// FmtStor formats a string to be used with TagStor +func FmtStor(s string) string { + return fmtShortGUID(s) +} + +// FmtDstPth formats a string to be used with TagDstPth +func FmtDstPth(s string) string { + return s +} + +// FmtCnsPth formats a string to be used with TagCnsPth +func FmtCnsPth(s string) string { + return s +} + +// FmtMsgID formats a string to be used with TagMsgID +func FmtMsgID(s string) string { + return s +} + +// FmtAckID formats a string to be used with TagAckID +func FmtAckID(s string) string { + return s +} + +// FmtHostIP formats a string to be used with TagHostIP +func FmtHostIP(s string) string { + return s +} + +// FmtReconfigureID formats a string to be used with TagReconfigureID +func FmtReconfigureID(s string) string { + return s +} + +// FmtInPutAckID formats a string to be used with TagInPutAckID +func FmtInPutAckID(s string) string { + return s +} + +// FmtInPubConnID formats an int to be used with TagInPubConnID +func FmtInPubConnID(s int) string { + return fmt.Sprintf("%v", s) +} + +// FmtInReplicaHost formats a string to be used with TagInReplicaHost +func FmtInReplicaHost(s string) string { + return s +} + +// FmtDLQID formats a string to be used with TagDLQID +func FmtDLQID(s string) string { + return fmtGUID(s) +} + +// FmtService formats a string to be used with TagService +func FmtService(s string) string { + return s +} + +// FmtHostPort formats a string to be used with TagHostPort +func FmtHostPort(s string) string { + return s +} diff --git a/common/metrics/interfaces.go b/common/metrics/interfaces.go new file mode 100644 index 0000000..0ddb5b2 --- /dev/null +++ b/common/metrics/interfaces.go @@ -0,0 +1,80 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package metrics + +import ( + "time" +) + +type ( + // Reporter is the interface used to report stats. + Reporter interface { + // InitMetrics is used to initialize the metrics map + // with the respective type + InitMetrics(metricMap map[MetricName]MetricType) + + // GetChildReporter is used to get a child reporter from the parent + // this also makes sure we have all the tags from the parent in + // addition to the tags supplied here + GetChildReporter(tags map[string]string) Reporter + + // GetTags gets the tags for this reporter object + GetTags() map[string]string + + // IncCounter should be used for Counter style metrics + IncCounter(name string, tags map[string]string, delta int64) + + // UpdateGauge should be used for Gauge style metrics + UpdateGauge(name string, tags map[string]string, value int64) + + // StartTimer should be used for measuring latency. 
+ // this returns a Stopwatch which can be used to stop the timer + StartTimer(name string, tags map[string]string) Stopwatch + + // RecordTimer should be used for measuring latency when you cannot start the stop watch. + RecordTimer(name string, tags map[string]string, d time.Duration) + } + + // Client is the interface used to report metrics to m3 backend. + Client interface { + // IncCounter increments a counter and emits + // to m3 backend + IncCounter(scope int, counter int) + // AddCounter adds delta to the counter and + // emits to the m3 backend + AddCounter(scope int, counter int, delta int64) + // StartTimer starts a timer for the given + // metric name + StartTimer(scope int, timer int) Stopwatch + // RecordTimer records the duration d for the given + // metric name + RecordTimer(scope int, timer int, d time.Duration) + // UpdateGauge reports Gauge type metric to M3 + UpdateGauge(scope int, gauge int, delta int64) + // GetParentReporter returns the parent Reporter + GetParentReporter() Reporter + } + + // Stopwatch is the interface to stop the timer + Stopwatch interface { + Stop() time.Duration + } +) diff --git a/common/metrics/names.go b/common/metrics/names.go new file mode 100644 index 0000000..30ad9eb --- /dev/null +++ b/common/metrics/names.go @@ -0,0 +1,99 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package metrics + +// MetricName is the name of the metric +type MetricName string + +// MetricType is the type of the metric, which can be one of the 3 below +type MetricType int + +// MetricTypes which are supported +const ( + Counter MetricType = iota + Timer + Gauge +) + +const ( + // ServiceNameTagName is the tag name to identify the partner service which uses the Cherami client + ServiceNameTagName = "serviceName" + // DeploymentTagName is the tag name to identify current deployment name + DeploymentTagName = "deployment" + + // PublishMessageRate is the rate of messages written to input + PublishMessageRate = "cherami.publish.message.rate" + // PublishMessageFailedRate is the rate of messages that failed to be written to input + PublishMessageFailedRate = "cherami.publish.message.failed" + // PublishMessageLatency is the latency of messages written to input + PublishMessageLatency = "cherami.publish.message.latency" + // PublishAckRate is the rate of acks received from input + PublishAckRate = "cherami.publish.ack.rate" + // PublishReconfigureRate is the rate of reconfiguration happening + PublishReconfigureRate = "cherami.publish.reconfigure.rate" + // PublishNumConnections is the number of connections with input + PublishNumConnections = "cherami.publish.connections" + // PublishNumInflightMessagess is the number of inflight messages held locally by the publisher + PublishNumInflightMessagess = "cherami.publish.message.inflights" + + // ConsumeMessageRate is the rate of messages received from output + 
ConsumeMessageRate = "cherami.consume.message.rate" + // ConsumeCreditRate is the rate of credits sent to output + ConsumeCreditRate = "cherami.consume.credit.rate" + // ConsumeCreditFailedRate is the rate of credits that failed to be sent to output + ConsumeCreditFailedRate = "cherami.consume.credit.failed" + // ConsumeCreditLatency is the latency of credits sent to output + ConsumeCreditLatency = "cherami.consume.credit.latency" + // ConsumeAckRate is the rate of acks sent to output + ConsumeAckRate = "cherami.consume.ack.rate" + // ConsumeAckFailedRate is the rate of acks that failed to be sent to output + ConsumeAckFailedRate = "cherami.consume.ack.failed" + // ConsumeNackRate is the rate of nacks sent to output + ConsumeNackRate = "cherami.consume.nack.rate" + // ConsumeReconfigureRate is the rate of reconfiguration happening + ConsumeReconfigureRate = "cherami.consume.reconfigure.rate" + // ConsumeNumConnections is the number of connections with output + ConsumeNumConnections = "cherami.consume.connections" + // ConsumeLocalCredits is the number of credits held locally by the consumer + ConsumeLocalCredits = "cherami.consume.credit.local" +) + +// MetricDefs maps each metric name to its metric type +var MetricDefs = map[MetricName]MetricType{ + PublishMessageRate: Counter, + PublishMessageFailedRate: Counter, + PublishMessageLatency: Timer, + PublishAckRate: Counter, + PublishReconfigureRate: Counter, + PublishNumConnections: Gauge, + PublishNumInflightMessagess: Gauge, + + ConsumeMessageRate: Counter, + ConsumeCreditRate: Counter, + ConsumeCreditFailedRate: Counter, + ConsumeCreditLatency: Timer, + ConsumeAckRate: Counter, + ConsumeAckFailedRate: Counter, + ConsumeNackRate: Counter, + ConsumeReconfigureRate: Counter, + ConsumeNumConnections: Gauge, + ConsumeLocalCredits: Gauge, +} diff --git a/common/metrics/nullreporter.go b/common/metrics/nullreporter.go new file mode 100644 index 0000000..302354d --- /dev/null +++ b/common/metrics/nullreporter.go @@ -0,0 
+1,94 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package metrics + +import "time" + +type ( + // NullReporter is a dummy reporter which implements the Reporter interface + NullReporter struct { + tags map[string]string + } + + nullStopWatch struct { + startTime time.Time + elasped time.Duration + } +) + +// NewNullReporter creates a NullReporter with an empty tag map and returns it as a Reporter +func NewNullReporter() Reporter { + reporter := &NullReporter{ + tags: make(map[string]string), + } + + return reporter +} + +// InitMetrics is a no-op for the null reporter; metricMap is ignored +func (r *NullReporter) InitMetrics(metricMap map[MetricName]MetricType) { + // This is a no-op: the null reporter keeps no metric definitions, only its tags +} + +// GetChildReporter returns this same reporter; the supplied tags are ignored +func (r *NullReporter) GetChildReporter(tags map[string]string) Reporter { + return r +} + +// GetTags returns the tags for this reporter object +func (r *NullReporter) GetTags() map[string]string { + return r.tags +} + +// IncCounter is a no-op for the null reporter +func (r *NullReporter) IncCounter(name string, tags map[string]string, delta int64) { + // not implemented +} + +// UpdateGauge is a no-op for the null reporter +func (r *NullReporter) UpdateGauge(name string, tags map[string]string, value int64) { + // Not implemented +} + +// Start records the current time as the start time of the stop watch +func (w *nullStopWatch) Start() { + w.startTime = time.Now() +} + +// Stop stores and returns the time elapsed since Start was called +func (w *nullStopWatch) Stop() time.Duration { + w.elasped = time.Since(w.startTime) + + return w.elasped +} + +// StartTimer returns a started Stopwatch; stopping it only returns the elapsed time +func (r *NullReporter) StartTimer(name string, tags map[string]string) Stopwatch { + w := &nullStopWatch{} + w.Start() + return w +} + +// RecordTimer should be used for measuring latency when you cannot start the stop watch. 
+func (r *NullReporter) RecordTimer(name string, tags map[string]string, d time.Duration) { + // Record the time as counter of time in milliseconds + // not implemented +} diff --git a/common/thrift_util.go b/common/thrift_util.go new file mode 100644 index 0000000..8284ca8 --- /dev/null +++ b/common/thrift_util.go @@ -0,0 +1,117 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package common + +import ( + "reflect" + + "github.com/apache/thrift/lib/go/thrift" +) + +// TSerialize is used to serialize thrift TStruct to []byte +func TSerialize(msg thrift.TStruct) (b []byte, err error) { + return thrift.NewTSerializer().Write(msg) +} + +// TSerializeString is used to serialize thrift TStruct to string +func TSerializeString(msg thrift.TStruct) (s string, err error) { + return thrift.NewTSerializer().WriteString(msg) +} + +// TListSerialize is used to serialize list of thrift TStruct to []byte +func TListSerialize(msgs []thrift.TStruct) (b []byte, err error) { + if msgs == nil { + return + } + + t := thrift.NewTSerializer() + t.Transport.Reset() + + if e := t.Protocol.WriteListBegin(thrift.STRING, len(msgs)); e != nil { + err = thrift.PrependError("error writing list begin: ", e) + return + } + + for _, v := range msgs { + if e := v.Write(t.Protocol); e != nil { + err = thrift.PrependError("error writing TStruct: ", e) + return + } + } + + if e := t.Protocol.WriteListEnd(); e != nil { + err = thrift.PrependError("error writing list end: ", e) + return + } + + if err = t.Protocol.Flush(); err != nil { + return + } + + if err = t.Transport.Flush(); err != nil { + return + } + + b = append(b, t.Transport.Bytes()...) 
+ return +} + +// TDeserialize is used to deserialize []byte to thrift TStruct +func TDeserialize(msg thrift.TStruct, b []byte) (err error) { + return thrift.NewTDeserializer().Read(msg, b) +} + +// TDeserializeString is used to deserialize string to thrift TStruct +func TDeserializeString(msg thrift.TStruct, s string) (err error) { + return thrift.NewTDeserializer().ReadString(msg, s) +} + +// TListDeserialize is used to deserialize []byte to list of thrift TStruct +func TListDeserialize(msgType reflect.Type, b []byte) (msgs []thrift.TStruct, err error) { + t := thrift.NewTDeserializer() + err = nil + if _, err = t.Transport.Write(b); err != nil { + return + } + + _, size, e := t.Protocol.ReadListBegin() + if e != nil { + err = thrift.PrependError("error reading list begin: ", e) + return + } + + msgs = make([]thrift.TStruct, 0, size) + for i := 0; i < size; i++ { + msg := reflect.New(msgType).Interface().(thrift.TStruct) + if e := msg.Read(t.Protocol); e != nil { + err = thrift.PrependError("error reading TStruct: ", e) + return + } + msgs = append(msgs, msg) + } + + if e := t.Protocol.ReadListEnd(); e != nil { + err = thrift.PrependError("error reading list end: ", e) + return + } + + return +} diff --git a/common/util.go b/common/util.go new file mode 100644 index 0000000..5878b2c --- /dev/null +++ b/common/util.go @@ -0,0 +1,294 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package common + +import ( + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + "io/ioutil" + "math/rand" + "net" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + + log "github.com/Sirupsen/logrus" + "github.com/uber/tchannel-go" + "github.com/uber/tchannel-go/hyperbahn" +) + +const hyperbahnPort int16 = 21300 +const rpAppNamePrefix string = "cherami" +const maxRpJoinTimeout = 2 * 60 * time.Second +const defaultNumReplicas = 3 + +const ( + inputHostAdminChannelName = "inputhost-admin-client" + outputHostAdminChannelName = "outputhost-admin-client" + storeHostClientChannelName = "storehost-client" +) + +// CreateTChannel creates the top level TChannel object for the +// specified servicer +func CreateTChannel(service string) *tchannel.Channel { + ch, err := tchannel.NewChannel(service, &tchannel.ChannelOptions{ + DefaultConnectionOptions: tchannel.ConnectionOptions{ + FramePool: tchannel.NewSyncFramePool(), + }, + }) + if err != nil { + log.Fatalf("Failed to create tchannel: %v", err) + } + return ch +} + +func getHyperbahnInitialNodes(bootstrapFile string) []string { + ip, _ := tchannel.ListenIP() + ret := []string{fmt.Sprintf("%s:%d", ip.String(), hyperbahnPort)} + + if len(bootstrapFile) < 1 { + return ret + } + + blob, err := ioutil.ReadFile(bootstrapFile) + if err != nil { + return ret + } + + err = json.Unmarshal(blob, &ret) + if err != nil { + return ret + } + + return ret +} + +func isPersistentService(name string) bool { + return (strings.Compare(name, StoreServiceName) == 0) +} + +// CreateHyperbahnClient returns a hyperbahn client +func CreateHyperbahnClient(ch *tchannel.Channel, bootstrapFile string) *hyperbahn.Client { + initialNodes := getHyperbahnInitialNodes(bootstrapFile) + config := hyperbahn.Configuration{InitialNodes: initialNodes} + if len(config.InitialNodes) == 0 { + log.Fatalf("No Hyperbahn nodes to connect to.") + } + hClient, _ := hyperbahn.NewClient(ch, config, 
nil) + return hClient +} + +// AdvertiseInHyperbahn advertises this node in Hyperbahn +func AdvertiseInHyperbahn(ch *tchannel.Channel, bootstrapFile string) *hyperbahn.Client { + hbClient := CreateHyperbahnClient(ch, bootstrapFile) + if err := hbClient.Advertise(); err != nil { + log.Errorf("Failed to advertise in Hyperbahn: %v", err) + return nil + } + return hbClient +} + +// This is just a utility to satisfy the RPM service so that listen hosts is +// map[string][]string +func convertListenHosts(cfgListenHosts map[string]string) map[string][]string { + listenHosts := make(map[string][]string) + + for service, hosts := range cfgListenHosts { + listenHosts[service] = strings.Split(hosts, ",") + } + + return listenHosts +} + +// SplitHostPort takes an x.x.x.x:yyyy string and splits it into host and port; NOTE: it indexes parts[1] unchecked, so an input without ':' panics +func SplitHostPort(hostPort string) (string, int, error) { + parts := strings.Split(hostPort, ":") + port, err := strconv.Atoi(parts[1]) + return parts[0], port, err +} + +var guidRegex = regexp.MustCompile(`([[:xdigit:]]{8})-[[:xdigit:]]{4}-[[:xdigit:]]{4}-[[:xdigit:]]{4}-[[:xdigit:]]{12}`) + +// ShortenGUIDString takes a string with one or more GUIDs and elides them to make it more human readable. 
It turns +// "354754bd-b73e-4d20-8021-ab93a3d145c0:67af70c5-f45e-4b3d-9d20-6758195e2ff4:3:2" into "354754bd:67af70c5:3:2" +func ShortenGUIDString(s string) string { + return guidRegex.ReplaceAllString(s, "$1") +} + +// ConditionFunc represents an expression that evaluates to +// true when some condition is satisfied and false otherwise +type ConditionFunc func() bool + +// SpinWaitOnCondition busy waits for a given condition to be true until the timeout +// Returns true if the condition was satisfied, false on timeout +func SpinWaitOnCondition(condition ConditionFunc, timeout time.Duration) bool { + + timeoutCh := time.After(timeout) + + for !condition() { + select { + case <-timeoutCh: + return false + default: + time.Sleep(time.Millisecond * 5) + } + } + + return true +} + +// AwaitWaitGroup calls Wait on the given wait group +// Returns true if the Wait() call succeeded before the timeout +// Returns false if the Wait() did not return before the timeout +func AwaitWaitGroup(wg *sync.WaitGroup, timeout time.Duration) bool { + + doneC := make(chan struct{}) + + go func() { + wg.Wait() + close(doneC) + }() + + select { + case <-doneC: + return true + case <-time.After(timeout): + return false + } +} + +// IsRetryableTChanErr returns true if the given tchannel +// error is a retryable error. 
+func IsRetryableTChanErr(err error) bool { + return (err == tchannel.ErrTimeout || + err == tchannel.ErrServerBusy || + err == tchannel.ErrRequestCancelled) +} + +// GetDirectoryName function gives the directory name given a path used for destination or consumer groups +func GetDirectoryName(path string) (string, error) { + parts := strings.Split(path, "/") + if len(parts) < 3 { + return "", fmt.Errorf("Invalid path: %v", path) + } + + return parts[1], nil +} + +// GetDateTag returns the current date used for tagging daily metric +func GetDateTag() string { + return time.Now().Format("2006-01-02") +} + +// GetConnectionKey is used to create a key used by connections for looking up connections +func GetConnectionKey(host *cherami.HostAddress) string { + return net.JoinHostPort(host.GetHost(), strconv.Itoa(int(host.GetPort()))) +} + +// GetRandInt64 is used to get a 64 bit random number between min and max +func GetRandInt64(min int64, max int64) int64 { + // we need to get a random number between min and max + return min + rand.Int63n(max-min) +} + +// UUIDHashCode is a hash function for hashing string uuid +// if the uuid is malformed, then the hash function always +// returns 0 as the hash value +func UUIDHashCode(key string) uint32 { + if len(key) != UUIDStringLength { + return 0 + } + // Use the first 4 bytes of the uuid as the hash + b, err := hex.DecodeString(key[:8]) + if err != nil { + return 0 + } + return binary.BigEndian.Uint32(b) +} + +// SequenceNumber is an int64 number represents the sequence of messages in Extent +type SequenceNumber int64 + +// UnixNanoTime is Unix time as nanoseconds since Jan 1st, 1970, 00:00 GMT +type UnixNanoTime int64 + +// Seconds is time as seconds, either relative or absolute since the epoch +type Seconds float64 + +// ToSeconds turns a relative or absolute UnixNanoTime to float Seconds +func (u UnixNanoTime) ToSeconds() Seconds { + return Seconds(float64(u) / float64(1e9)) +} + +// DurationToSeconds converts a 
time.Duration to Seconds +func DurationToSeconds(t time.Duration) Seconds { + return Seconds(float64(int64(t)) / float64(int64(time.Second))) +} + +// Now is the version to return UnixNanoTime +func Now() UnixNanoTime { + return UnixNanoTime(time.Now().UnixNano()) +} + +// CalculateRate does a simple rate calculation +func CalculateRate(last, curr SequenceNumber, lastTime, currTime UnixNanoTime) float64 { + deltaV := float64(curr - last) + deltaT := float64(currTime-lastTime) / float64(1e9) // 10^9 nanoseconds per second + return float64(deltaV / deltaT) +} + +// GeometricRollingAverage is the value of a geometrically diminishing rolling average +type GeometricRollingAverage float64 + +// SetGeometricRollingAverage adds a value to the geometric rolling average +func (avg *GeometricRollingAverage) SetGeometricRollingAverage(val float64) { + const rollingAverageFalloff = 100 + *avg -= *avg / GeometricRollingAverage(rollingAverageFalloff) + *avg += GeometricRollingAverage(val) / GeometricRollingAverage(rollingAverageFalloff) +} + +// GetGeometricRollingAverage returns the result of the geometric rolling average +func (avg *GeometricRollingAverage) GetGeometricRollingAverage() float64 { + return float64(*avg) +} + +// ValidateTimeout panics if the passed timeout is unreasonable +func ValidateTimeout(t time.Duration) { + if t >= time.Millisecond*100 && t <= time.Minute*5 { + return + } + + panic(fmt.Sprintf(`Configured timeout is out of range: %v`, t)) +} + +// MinInt returns the minimum of values (a, b) +func MinInt(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/common/websocket/base_test.go b/common/websocket/base_test.go new file mode 100644 index 0000000..25f9bf8 --- /dev/null +++ b/common/websocket/base_test.go @@ -0,0 +1,260 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package websocket + +import ( + "fmt" + "net/http" + "reflect" + "sync" + + "github.com/uber/cherami-thrift/.generated/go/cherami" + log "github.com/Sirupsen/logrus" +) + +// -- Websocket Streaming Server -- // +type streamHandler func(recvC <-chan int, sendC chan<- int, flushC chan<- interface{}) + +func startTestServer(port int, handler streamHandler) error { + + mux := http.NewServeMux() + wsHub := NewWebsocketHub() + readMsgType := reflect.TypeOf((*cherami.PutMessage)(nil)).Elem() + + mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { + + stream := NewStreamServer(w, r, wsHub, readMsgType) + + if err := stream.Start(); err != nil { + log.Errorf("testServer.startTestServer (start): %v", err) + return + } + + serve(stream, handler) + }) + + // listen-and-serve on a separate go-routine + go http.ListenAndServe(fmt.Sprintf(":%d", port), mux) + + return nil + +} + +func serve(stream StreamServer, handler streamHandler) { + + recvC, sendC, flushC := make(chan int), make(chan int), make(chan interface{}) + + var wg sync.WaitGroup + + // start read/write pumps + wg.Add(3) + go readPump(stream, recvC, &wg) + go writePump(sendC, flushC, stream, &wg) + go func() { + defer wg.Done() + handler(recvC, sendC, flushC) + }() + + // wait until the read/write/handler pumps are done + wg.Wait() +} + +func readPump(stream StreamServer, recvC chan<- int, wg *sync.WaitGroup) { + + defer wg.Done() + defer close(recvC) // close outgoing pipe + + for { + tMsg, err := stream.Read() + + if err != nil { + // log.Debugf("testServer.readPump (read): %v", err) + return + } + + msg, ok := tMsg.(*cherami.PutMessage) + + if !ok { + log.Errorf("testServer.readPump: invalid msg type") + return + } + + // log.Debugf("testServer.readPump: recv %d", getMsgID(msg)) + + recvC <- getMsgID(msg) + } +} + +func writePump(sendC chan int, flushC <-chan interface{}, stream StreamServer, wg *sync.WaitGroup) { + + defer wg.Done() + defer func() { + // log.Debugf("testServer.writePump: 
Done") + if err := stream.Done(); err != nil { // close outgoing pipe + log.Errorf("testServer.writePump (done): %v", err) + } + }() + + for { + select { + case i, ok := <-sendC: + if !ok { + // log.Debugf("testServer.writePump: closed") + return + } + + // log.Debugf("testServer.writePump: Write [%d]", i) + err := stream.Write(newTestMsg(i)) + + if err != nil { + log.Errorf("testServer.writePump (write): %v", err) + return + } + + // log.Debugf("testServer.writePump: sent %d", i) + + case <-flushC: + // log.Debugf("testServer.writePump: Flush") + err := stream.Flush() + + if err != nil { + log.Errorf("testServer.writePump (flush): %v", err) + return + } + } + } +} + +// -- Websocket Streaming Client -- // + +type testClient struct { + recvC chan int + sendC chan int + flushC chan interface{} + stream StreamClient + + wg sync.WaitGroup + doneC chan struct{} +} + +func startTestClient(port int) (*testClient, error) { + + readMsgType := reflect.TypeOf((*cherami.PutMessage)(nil)).Elem() + + stream := NewStreamClient(fmt.Sprintf("ws://localhost:%d/test", port), http.Header{}, NewWebsocketHub(), readMsgType) + + if err := stream.Start(); err != nil { + log.Errorf("error starting client: %v", err) + return nil, err + } + + recvC, sendC, flushC := make(chan int), make(chan int), make(chan interface{}) + + t := &testClient{ + recvC: recvC, + sendC: sendC, + flushC: flushC, + stream: stream, + doneC: make(chan struct{}), + } + + // start read/write pumps + t.wg.Add(2) + go t.readPump(stream, recvC) + go t.writePump(sendC, flushC, stream) + + go func() { + t.wg.Wait() // wait until the read/write/handler pumps are done + close(t.doneC) // signal doneC + }() + + return t, nil +} + +func (t *testClient) done() <-chan struct{} { + return t.doneC +} + +func (t *testClient) readPump(stream StreamClient, recvC chan<- int) { + + defer t.wg.Done() + defer close(recvC) // close outgoing pipe + + for { + tMsg, err := stream.Read() + + if err != nil { + // 
log.Debugf("testClient.readPump: %v", err) + return + } + + msg, ok := tMsg.(*cherami.PutMessage) + + if !ok { + log.Errorf("testClient.readPump: invalid msg type") + return + } + + // log.Debugf("testClient.readPump: recv %d", getMsgID(msg)) + + recvC <- getMsgID(msg) + } +} + +func (t *testClient) writePump(sendC chan int, flushC <-chan interface{}, stream StreamClient) { + + defer t.wg.Done() + defer func() { + // log.Debugf("testClient.writePump: Done") + if err := stream.Done(); err != nil { // close outgoing pipe + log.Errorf("testClient.writePump (done): %v", err) + } + }() + + for { + select { + case i, ok := <-sendC: + if !ok { + // log.Debugf("testClient.writePump: closed") + return + } + + // log.Debugf("testClient.writePump: Write [%d]", i) + err := stream.Write(newTestMsg(i)) + + if err != nil { + log.Errorf("testClient.writePump: write error=%v", err) + return + } + + // log.Debugf("testClient.writePump: sent %d", i) + + case <-flushC: + // log.Debugf("testClient.writePump: Flush") + err := stream.Flush() + + if err != nil { + log.Errorf("testClient.writePump: flush error=%v", err) + return + } + } + } +} diff --git a/common/websocket/client.go b/common/websocket/client.go new file mode 100644 index 0000000..525dc46 --- /dev/null +++ b/common/websocket/client.go @@ -0,0 +1,105 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package websocket + +import ( + "fmt" + "net/http" + "reflect" + "sync/atomic" + "time" + + "github.com/apache/thrift/lib/go/thrift" +) + +const ( + defaultClientFlushThreshold = 64 // 640 kiB + defaultPingInterval = 60 * time.Second // defaultPingInterval is the interval at which websocket PingMessages are sent +) + +type ( + // Dialer interface + Dialer interface { + // Dial opens a websocket connection from the client + Dial(urlStr string, requestHeader http.Header) (Conn, *http.Response, error) + } + + // StreamClient defines the interface for the 'client' side of the stream + StreamClient interface { + // Start starts the stream-client + Start() error + // Write writes a message into the write buffer + Write(thrift.TStruct) error + // Flush flushes out buffered messages + Flush() error + // Read reads and returns a message + Read() (thrift.TStruct, error) + // Done is used to indicate that the client is done + Done() error + } + + // streamClient is a wrapper for websocket-stream client that batches responses/requests + streamClient struct { + Stream + + url string + requestHeader http.Header + dialer Dialer + readMsgType reflect.Type + + started int32 + } +) + +// NewStreamClient initializes a websocket-streaming client to the given url +func NewStreamClient(url string, requestHeader http.Header, dialer Dialer, readMsgType reflect.Type) StreamClient { + + return &streamClient{ + url: url, + requestHeader: requestHeader, + dialer: dialer, + readMsgType: readMsgType, + } +} 
+ +func (s *streamClient) Start() error { + + // make Start idempotent + if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { + return nil + } + + wsConn, _, err := s.dialer.Dial(s.url, s.requestHeader) + if err != nil { + return fmt.Errorf("Dial failed: %v", err) + } + + s.Stream = NewStream(wsConn, &StreamOpts{ReadMsgType: s.readMsgType, FlushThreshold: defaultClientFlushThreshold, PingInterval: defaultPingInterval}) + + return s.Stream.Start() +} + +// Done closes the write-path on the client, leaving the read-path open so it can be drained +func (s *streamClient) Done() error { + + // shutdown only the write-path, leaving the read-path open so it can be drained by the application + return s.Stream.CloseWrite() +} diff --git a/common/websocket/conn.go b/common/websocket/conn.go new file mode 100644 index 0000000..c1c4e14 --- /dev/null +++ b/common/websocket/conn.go @@ -0,0 +1,51 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package websocket + +import ( + "time" + + gorilla "github.com/gorilla/websocket" +) + +type ( + // Conn interfaces out the underlying (gorilla) websocket implementation, so we + // can mock it out easily when testing, etc. + Conn interface { + ReadMessage() (messageType int, p []byte, err error) + WriteMessage(messageType int, data []byte) error + WriteControl(messageType int, data []byte, deadline time.Time) error + SetCloseHandler(h func(code int, text string) error) + Close() error + } + + connImpl struct { + *gorilla.Conn + } +) + +// interface implementation check +var _ Conn = (*connImpl)(nil) + +// NewWebsocketConn creates a websocket Conn +func NewWebsocketConn(c *gorilla.Conn) Conn { + return &connImpl{Conn: c} +} diff --git a/common/websocket/hub.go b/common/websocket/hub.go new file mode 100644 index 0000000..e5602f9 --- /dev/null +++ b/common/websocket/hub.go @@ -0,0 +1,73 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package websocket + +import ( + "net/http" + + gorilla "github.com/gorilla/websocket" +) + +type ( + // Hub is the interface for websocket Hub that can be used to establish websocket connection + Hub interface { + // Dial handshakes websocket connection from client side + Dial(urlStr string, requestHeader http.Header) (Conn, *http.Response, error) + // Upgrade handshakes and upgrades websocket connection from server side + Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (Conn, error) + } + + hubImpl struct { + // dialer to handshake websocket connection from client side + dialer *gorilla.Dialer + // upgrader to handshake and upgrade websocket connection from server side + upgrader *gorilla.Upgrader + } +) + +// interface implementation check +var _ Hub = (*hubImpl)(nil) + +// NewWebsocketHub creates a websocket Hub +func NewWebsocketHub() Hub { + return &hubImpl{ + upgrader: &gorilla.Upgrader{ + ReadBufferSize: 102400, + WriteBufferSize: 102400, + }, + dialer: &gorilla.Dialer{ + ReadBufferSize: 102400, + WriteBufferSize: 102400, + }, + } +} + +// Dial handshakes websocket connection from client side +func (hub *hubImpl) Dial(urlStr string, requestHeader http.Header) (Conn, *http.Response, error) { + wsConn, httpResp, err := hub.dialer.Dial(urlStr, requestHeader) + return NewWebsocketConn(wsConn), httpResp, err +} + +// Upgrade handshakes and upgrades websocket connection from server side +func (hub *hubImpl) Upgrade(w http.ResponseWriter, r 
*http.Request, responseHeader http.Header) (Conn, error) { + wsConn, err := hub.upgrader.Upgrade(w, r, responseHeader) + return NewWebsocketConn(wsConn), err +} diff --git a/common/websocket/server.go b/common/websocket/server.go new file mode 100644 index 0000000..4b3537f --- /dev/null +++ b/common/websocket/server.go @@ -0,0 +1,103 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package websocket + +import ( + "fmt" + "net/http" + "reflect" + "sync/atomic" + + "github.com/apache/thrift/lib/go/thrift" +) + +const ( + defaultServerFlushThreshold = 64 // 640 kiB +) + +type ( + // Upgrader interface + Upgrader interface { + // Upgrade handshakes and upgrades websocket connection from server side + Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (Conn, error) + } + + // StreamServer defines the interface for the 'server' side of the stream that batches responses/requests + StreamServer interface { + // Start starts the stream-server + Start() error + // Write writes a message into the write buffer + Write(msg thrift.TStruct) (err error) + // Flush flushes out buffered messages + Flush() (err error) + // Read reads and returns a message + Read() (msg thrift.TStruct, err error) + // Done is used to indicate that the server is done + Done() (err error) + } + + // streamServer is a wrapper for websocket-stream server that batches responses/requests + streamServer struct { + Stream + + httpRespWriter http.ResponseWriter + httpRequest *http.Request + upgrader Upgrader + readMsgType reflect.Type + + started int32 + } +) + +// NewStreamServer initializes a new websocket-streaming server +func NewStreamServer(httpRespWriter http.ResponseWriter, httpRequest *http.Request, upgrader Upgrader, readMsgType reflect.Type) StreamServer { + + return &streamServer{ + httpRespWriter: httpRespWriter, + httpRequest: httpRequest, + upgrader: upgrader, + readMsgType: readMsgType, + } +} + +func (s *streamServer) Start() error { + + // make Start idempotent + if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { + return nil + } + + wsConn, err := s.upgrader.Upgrade(s.httpRespWriter, s.httpRequest, nil) + if err != nil { + return fmt.Errorf("Upgrade failed: %v", err) + } + + s.Stream = NewStream(wsConn, &StreamOpts{ReadMsgType: s.readMsgType, FlushThreshold: defaultServerFlushThreshold, Server: true}) + + return s.Stream.Start() +} + +// 
Done closes both the read and write-paths and the underlying connection on the server +func (s *streamServer) Done() error { + + // tear-down the underlying connection + return s.Stream.Close() +} diff --git a/common/websocket/serverclient_test.go b/common/websocket/serverclient_test.go new file mode 100644 index 0000000..64ab6f2 --- /dev/null +++ b/common/websocket/serverclient_test.go @@ -0,0 +1,180 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package websocket + +import ( + "fmt" + "os" + "testing" + "time" + + // "github.com/stretchr/testify/mock" + log "github.com/Sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ( + // ServerStreamSuite tests websocket stream implementation + // This test suite uses thrift PutMessage as test message struct (or it can be anything thrift struct) + ServerStreamSuite struct { + *require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error + suite.Suite + } +) + +func TestServerStreamSuite(t *testing.T) { + suite.Run(t, new(ServerStreamSuite)) +} + +func (s *ServerStreamSuite) SetupTest() { + s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil + if testing.Verbose() { + log.SetOutput(os.Stdout) + log.SetLevel(log.DebugLevel) // test logs at debug level + } +} + +func (s *ServerStreamSuite) TearDownTest() { +} + +func (s *ServerStreamSuite) TestStreaming() { + + port := 6192 // use ephemeral port + + stopC := make(chan struct{}) + + serverPump := func(recvC <-chan int, sendC chan<- int, flushC chan<- interface{}) { + + defer close(sendC) // close outgoing pipe (to stop stream) + + flushTicker := time.NewTicker(100 * time.Millisecond) + defer flushTicker.Stop() + + for { + select { + case i, ok := <-recvC: + + if !ok { + return + } + + sendC <- -i // negate and send back + // log.Debugf("serverPump: recv=%d send=%d", i, -i) + + case <-flushTicker.C: + flushC <- true + + case <-stopC: + return + } + } + } + + // log.Debugf("starting test server:") + err := startTestServer(port, serverPump) + + if err != nil { + log.Errorf("error starting testServer: %v", err) + return + } + + // -- TEST 1: client going away gracefully -- // + numSend := 4099 + + // log.Debugf("starting test client:") + client, err := startTestClient(port) + + if err != nil 
{ + log.Errorf("error starting testClient: %v", err) + return + } + + // do the writes + for i := 1; i <= numSend; i++ { + client.sendC <- i + } + + // client.flushC <- true + // log.Debugf("client done sending") + close(client.sendC) // stop writePump; effectively calls stream.Done() + + // now drain all the reads + var numRecv int + for i := range client.recvC { + // log.Debugf("client recv: %d", i) + numRecv++ + s.Equal(numRecv, -i) + } + + // log.Debugf("sent=%d msgs; recv=%d msgs", numSend, numRecv) + + s.Equal(numSend, numRecv, fmt.Sprintf("numSend=%d != numRecv=%d", numSend, numRecv)) + s.True(waitFor(5000, func() bool { + select { + case <-client.done(): + return true + default: + return false + } + }), "timed out waiting for client to be done") + + /* + // -- TEST 2: server going away abruptly -- // + numSend := 4099 + + // log.Debugf("starting test client:") + client, err := startTestClient(port) + + if err != nil { + log.Errorf("error starting testClient: %v", err) + return + }rl + + // do the writes + for i := 1; i <= numSend; i++ { + client.sendC <- i + } + + // client.flushC <- true + close(client.sendC) // stop writePump; effectively calls stream.Done() + + // now drain all the reads + var numRecv int + for i := range client.recvC { + // log.Debugf("client recv: %d", i) + numRecv++ + s.Equal(numRecv, -i) + } + + // log.Debugf("client stream closed (sent=%d msgs; recv=%d msgs)", numSend, numRecv) + + s.Equal(numSend, numRecv, fmt.Sprintf("numSend=%d != numRecv=%d", numSend, numRecv)) + s.True(waitFor(5000, func() bool { + select { + case <-client.done(): + return true + default: + return false + } + }), "timed out waiting for client to be done") + */ +} diff --git a/common/websocket/stream.go b/common/websocket/stream.go new file mode 100644 index 0000000..0eb7149 --- /dev/null +++ b/common/websocket/stream.go @@ -0,0 +1,401 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package websocket + +import ( + "bytes" + "fmt" + "reflect" + "sync/atomic" + "time" + + "github.com/uber/cherami-client-go/common" + "github.com/apache/thrift/lib/go/thrift" + gorilla "github.com/gorilla/websocket" +) + +type ( + // Stream defines the interface for the "stream", that implements the shared + // functionality between the server and the client ends of the stream + Stream interface { + // Start starts the stream + Start() error + // Write writes a message into the buffer + Write(thrift.TStruct) error + // Flush flushes out buffered messages + Flush() error + // Read reads and returns a message + Read() (thrift.TStruct, error) + // CloseWrite closes the write-path + CloseWrite() error + // Close shuts down read/write path and closes the underlying connection + Close() error + } + + // stream provides a buffered websocket connection for use with streaming + stream struct { + conn Conn // the underlying websocket connection + + readMsgs []thrift.TStruct // buffered read messages + readMsgIndex int // index to the next message that will be returned + readMsgType reflect.Type // the actual type of read-message (for reflection) + + writeMsgs []thrift.TStruct // buffered write messages + writeMsgLen int // max number of messages that will be buffered + writeMsgIndex int // index to where the next message would be buffered + + server bool // indicates that this is the 'server' end of the stream + pingInterval time.Duration // interval at which 'ping' messages should be sent + + started int32 // indicates that the stream has been started + writeClosed int32 // indicates that the write-path has been closed + connClosed int32 // indicates that the underlying connection has been closed + closed int32 // indicates that close has been called on the stream + closeC chan struct{} // channel used to signal the ping-pump to stop + } + + // StreamOpts are the options passed when creating the stream + StreamOpts struct { + // Server indicates whether the stream should 
be configured to behave + // like the "server" end of the connection. + Server bool + + // FlushThreshold is the number of messages to buffer before flushing automatically + FlushThreshold int + + // PingInterval is the interval at which websocket ping-messages need to be sent; if + // this is zero, pings will not be sent + PingInterval time.Duration + + // readMsgType is the type of read-message + ReadMsgType reflect.Type + } +) + +const closeMessageTimeout = 10 * time.Second + +var errNotStarted = fmt.Errorf("Stream not started") + +// NewStream initializes a new stream object +func NewStream(conn Conn, opts *StreamOpts) Stream { + + // ensure we buffer at least one message + if opts.FlushThreshold < 1 { + opts.FlushThreshold = 1 + } + + // create and return stream object + return &stream{ + conn: conn, + + readMsgType: opts.ReadMsgType, + readMsgIndex: 0, + + writeMsgs: make([]thrift.TStruct, opts.FlushThreshold), + writeMsgLen: opts.FlushThreshold, + writeMsgIndex: 0, + + server: opts.Server, + pingInterval: opts.PingInterval, + closeC: make(chan struct{}), + } +} + +// Start starts the stream, starting the ping-pump, if needed. 
+func (s *stream) Start() error { + + // make Start idempotent + if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { + return nil + } + + // customize response to a close-message, based on if this is the + // server or the client end of the stream + if s.server { + + // on the server-end of the stream, simply return an error to + // the ReadMessage call, giving an opportunity to flush out + // any pending messages before closing out the connection + s.conn.SetCloseHandler(func(code int, text string) error { + // when the close-handler returns 'nil' here, the gorilla + // implementation returns a 'CloseError' to the Read + return nil + }) + + } else { + + // on the client-end of the stream, if we have received a + // close-message from the server, there's no point sending + // any more messages to the server, so simply respond back + // back with a close-message + s.conn.SetCloseHandler(func(code int, text string) error { + // send a close-message; NB: after a close-message is sent, the gorilla-websocket + // implementation fails any future writes on this connection + return s.conn.WriteControl(gorilla.CloseMessage, gorilla.FormatCloseMessage(code, "close ack"), time.Now().Add(closeMessageTimeout)) + }) + } + + // start ping pump, if needed + if s.pingInterval.Nanoseconds() != 0 { + go s.pingPump() + } + + return nil +} + +// CloseWrite closes the write path after flushing out any buffered writes. +// NB: Write(), Flush(), Close() and CloseWrite() are not thread-safe between +// each other; in other words, the caller needs to ensure that only one of the +// them is invoked at any given time. 
+func (s *stream) CloseWrite() error { + + // ensure Start had been called + if atomic.LoadInt32(&s.started) != 1 { + return errNotStarted + } + + // make CloseWrite idempotent + if !atomic.CompareAndSwapInt32(&s.writeClosed, 0, 1) { + return nil + } + + err0 := s.Flush() // flush out any buffered messages before sending close message + + // send a close-message; NB: after a close-message is sent, the gorilla-websocket + // implementation fails any future writes on this connection + err1 := s.conn.WriteControl(gorilla.CloseMessage, gorilla.FormatCloseMessage(gorilla.CloseGoingAway, "write closed"), time.Now().Add(closeMessageTimeout)) + + switch { + case err0 != nil: + return fmt.Errorf("CloseWrite error: flush failed: %v", err0) + + case err1 != nil: + return fmt.Errorf("CloseWrite error: write close msg failed: %v", err1) + + default: + return nil + } +} + +func (s *stream) closeConn() error { + + // ensure we close the conn only once + if !atomic.CompareAndSwapInt32(&s.connClosed, 0, 1) { + return nil + } + + // stop ping-pump + close(s.closeC) + + // send a close-message; NB: after a close-message is sent, the gorilla-websocket + // implementation fails any future writes on this connection + err0 := s.conn.WriteControl(gorilla.CloseMessage, gorilla.FormatCloseMessage(gorilla.CloseGoingAway, "stream closed"), time.Now().Add(closeMessageTimeout)) + + // cose underlying connection + err1 := s.conn.Close() // close underlying connection + + switch { + case err0 != nil: + return fmt.Errorf("write close msg failed: %v", err0) + + case err1 != nil: + return fmt.Errorf("conn close failed: %v", err1) + + default: + return nil + } +} + +// Close flushes any buffered messages before closing the underlying websocket connection. 
+func (s *stream) Close() error { + + // ensure Start had been called + if atomic.LoadInt32(&s.started) != 1 { + return errNotStarted + } + + // ensure we close the conn only once + if !atomic.CompareAndSwapInt32(&s.closed, 0, 1) { + return nil + } + + err0 := s.Flush() // flush out any buffered messages before sending close message + + err1 := s.closeConn() // close underlying connection + + switch { + case err0 != nil: + return fmt.Errorf("Close error: flush: %v", err0) + + case err1 != nil: + return fmt.Errorf("Close error: closeConn: %v", err1) + + default: + return nil + } +} + +// setFlushThreshold sets the write-buffer size (in number of messages); this +// is primarily intended for testing. +func (s *stream) setFlushThreshold(flushThreshold int) { + + if flushThreshold < 1 { + flushThreshold = 1 // ensure we buffer at least one message + } + + // flush out any buffered messages + if s.writeMsgIndex > 0 { + if err := s.Flush(); err != nil { + return + } + } + + s.writeMsgs = make([]thrift.TStruct, flushThreshold) + s.writeMsgLen = flushThreshold +} + +// Read returns one message from the read-buffer; if there aren't any buffered +// messages available, it reads a 'payload' from the wire, deserializes them +// into messages and returns one. +// NB: Read is *not* thread-safe; in other words, the caller needs to ensure +// that there there will *not* be more than one concurrent calls into Read(). +func (s *stream) Read() (thrift.TStruct, error) { + + // if we don't have any buffered messages, read from the wire + if s.readMsgIndex >= len(s.readMsgs) { + + msgType, payload, err := s.conn.ReadMessage() + if err != nil { + // if we got an error from read, then tear down the connection + // if this is the client-end of the stream (since the server has + // gone away). if this is the server-end of the stream, then + // simply return an error back. when the Close call comes in + // we would flush out any buffered messages before tearing down + // the connection. 
+ if !s.server { + s.closeConn() + } + + return nil, fmt.Errorf("Read error: %v", err) + } + + switch msgType { + case gorilla.BinaryMessage: + // deserialize into 'readMsgType' messages + s.readMsgs, err = common.TListDeserialize(s.readMsgType, payload) + if err != nil { + if !s.server { + s.closeConn() // on any read error, close connection + } + + return nil, fmt.Errorf("Deserialize error: %v", err) + } + + s.readMsgIndex = 0 + + case gorilla.TextMessage: + // earlier versions of gorilla-websocket used to automatically + // respond to a close-message preventing our code from being able + // to flush out any buffered messages. to workaround that, we used + // to send out a special 'TextMessage' to convey close (from the + // client end of the stream). we still respond to that and handle + // it appropriately, in case we have an older client connect in + // to a newer server + if bytes.Equal(payload, []byte("close")) { + return nil, fmt.Errorf("Closed") + } + + fallthrough + + default: + return nil, fmt.Errorf("Invalid message type [%d]", msgType) + } + } + + msg := s.readMsgs[s.readMsgIndex] + s.readMsgIndex++ + + return msg, nil +} + +// Write buffers a message; if the write-message buffer is full, it triggers a flush +// of all the buffered messages. +// NB: Write(), Flush(), Close() and CloseWrite() are not thread-safe between +// each other; in other words, the caller needs to ensure that only one of the +// them is invoked at any given time. +func (s *stream) Write(msg thrift.TStruct) error { + + s.writeMsgs[s.writeMsgIndex] = msg + s.writeMsgIndex++ + + // trigger a flush if the write buffer is full + if s.writeMsgIndex >= s.writeMsgLen { + return s.Flush() + } + + return nil +} + +// Flush flushes any buffered messages. +// NB: Write(), Flush(), Close() and CloseWrite() are not thread-safe between +// each other; in other words, the caller needs to ensure that only one of the +// them is invoked at any given time. 
+func (s *stream) Flush() error { + + // no-op, if there's nothing to flush + if s.writeMsgIndex <= 0 { + return nil + } + + // serialize all the buffered messages together into one payload + payload, err := common.TListSerialize(s.writeMsgs[:s.writeMsgIndex]) + if err != nil { + return fmt.Errorf("Serialize error: %v", err) + } + + for i := 0; i < s.writeMsgIndex; i++ { + s.writeMsgs[i] = nil // free-up for GC + } + + s.writeMsgIndex = 0 // reset index into write-buffer + + return s.conn.WriteMessage(gorilla.BinaryMessage, payload) // flush! +} + +// pingPump sends ping-messages at regular intervals +func (s *stream) pingPump() { + + pingTicker := time.NewTicker(s.pingInterval) + defer pingTicker.Stop() + + for { + select { + case <-pingTicker.C: + // keeping sending 'pings', until the connection is closed (or the ping-write fails) + if err := s.conn.WriteControl(gorilla.PingMessage, []byte{}, time.Now().Add(s.pingInterval)); err != nil { + return + } + + case <-s.closeC: + return + } + } +} diff --git a/common/websocket/stream_test.go b/common/websocket/stream_test.go new file mode 100644 index 0000000..a865fba --- /dev/null +++ b/common/websocket/stream_test.go @@ -0,0 +1,576 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package websocket

import (
	"errors"
	"log"
	"os"
	"reflect"
	"strconv"
	"sync/atomic"
	"testing"
	"time"

	"github.com/apache/thrift/lib/go/thrift"
	gorilla "github.com/gorilla/websocket"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"

	"github.com/uber/cherami-thrift/.generated/go/cherami"
	"github.com/uber/cherami-client-go/common"
	mockWS "github.com/uber/cherami-client-go/mocks/common/websocket"
)

type (
	// StreamSuite tests the websocket stream implementation.
	// This test suite uses the thrift PutMessage as its test message struct (any thrift struct would do).
	StreamSuite struct {
		*require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error
		suite.Suite
	}
)

// TestStreamSuite runs the StreamSuite under 'go test'.
func TestStreamSuite(t *testing.T) {
	suite.Run(t, new(StreamSuite))
}

// SetupTest installs the require-based assertions and, in verbose mode,
// sends log output to stdout.
func (s *StreamSuite) SetupTest() {
	s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil
	if testing.Verbose() {
		log.SetOutput(os.Stdout)
	}
}

// TearDownTest is intentionally empty.
func (s *StreamSuite) TearDownTest() {
}

// testMsgType is the reflect.Type of cherami.PutMessage; it is passed as the
// ReadMsgType stream option and to TListDeserialize.
var testMsgType = reflect.TypeOf((*cherami.PutMessage)(nil)).Elem()

// newTestMsg builds a PutMessage whose ID and Data are both the decimal
// string form of 'id'.
func newTestMsg(id int) *cherami.PutMessage {
	stringID := strconv.Itoa(id)
	return &cherami.PutMessage{
		ID:   common.StringPtr(stringID),
		Data: []byte(stringID),
	}
}

// getMsgID parses the message's ID back into an int; a parse error is
// ignored (Atoi's result is returned as-is).
func getMsgID(msg *cherami.PutMessage) int {
	id, _ := strconv.Atoi(msg.GetID())
	return id
}

// newTestPayload serializes 'n' test messages (ids 0..n-1) into one payload;
// a serialization error is ignored.
func newTestPayload(n int) []byte {
	messages := make([]thrift.TStruct, 0, n)
	for i := 0; i < n; i++ {
		messages = append(messages, newTestMsg(i))
	}
	payload, _ := common.TListSerialize(messages)
	return payload
}

// TestOpenClose tests normal open/close operations
func (s *StreamSuite) TestOpenClose() {

	mockConn := &mockWS.MockWebsocketConn{}
	mockConn.On("SetCloseHandler", mock.Anything).Return(nil).Once()
	mockConn.On("WriteControl", gorilla.CloseMessage, mock.Anything, mock.Anything).Return(nil).Once()
	mockConn.On("Close").Return(nil)

	stream := NewStream(mockConn, &StreamOpts{})
	stream.Start()

	err := stream.Close()
	s.NoError(err)

	mockConn.AssertExpectations(s.T())
}

// TestOpenCloseError tests that an error from the connection's Close is
// reported by the stream's Close
func (s *StreamSuite) TestOpenCloseError() {

	mockConn := &mockWS.MockWebsocketConn{}
	mockConn.On("SetCloseHandler", mock.Anything).Return(nil).Once()
	mockConn.On("WriteControl", gorilla.CloseMessage, mock.Anything, mock.Anything).Return(nil).Once()
	mockConn.On("Close").Return(errors.New("Close error"))

	stream := NewStream(mockConn, &StreamOpts{})
	stream.Start()

	err := stream.Close()
	s.Error(err)

	mockConn.AssertExpectations(s.T())
}

// TestPing tests open with ping pump enabled
func (s *StreamSuite) TestPing() {

	mockConn := &mockWS.MockWebsocketConn{}
	mockConn.On("SetCloseHandler", mock.Anything).Return(nil).Once()
	mockConn.On("WriteControl", gorilla.PingMessage, mock.Anything, mock.Anything).Return(nil)
	mockConn.On("WriteControl", gorilla.CloseMessage, mock.Anything, mock.Anything).Return(nil).Once()
	mockConn.On("Close").Return(nil)

	pingInternal := 10 * time.Millisecond

	stream := NewStream(mockConn, &StreamOpts{PingInterval: pingInternal})
	stream.Start()

	time.Sleep(2 * pingInternal) // wait for at least one ping to go through

	err := stream.Close()
	s.NoError(err)

	mockConn.AssertExpectations(s.T())
}

// TestPingError tests ping pump error out
func (s *StreamSuite) TestPingError() {

	mockConn := &mockWS.MockWebsocketConn{}
	mockConn.On("SetCloseHandler", mock.Anything).Return(nil).Once()
	mockConn.On("WriteControl", gorilla.PingMessage, mock.Anything, mock.Anything).Return(errors.New("WriteControl error"))
	mockConn.On("WriteControl", gorilla.CloseMessage, mock.Anything, mock.Anything).Return(nil).Once()
	mockConn.On("Close").Return(nil)

	pingInternal := 10 * time.Millisecond

	stream := NewStream(mockConn, &StreamOpts{PingInterval: pingInternal})
	stream.Start()

	time.Sleep(2 * pingInternal) // wait for at least one ping to go through

	err := stream.Close()
	s.NoError(err) // ping pump error should not impact stream's normal operations

	mockConn.AssertExpectations(s.T())
}

// TestRead tests multiple normal read operations
func (s *StreamSuite) TestRead() {

	msgCount := 3
	readLoop := 2

	mockConn := &mockWS.MockWebsocketConn{}
	mockConn.On("SetCloseHandler", mock.Anything).Return(nil).Once()
	mockConn.On("ReadMessage").Return(gorilla.BinaryMessage, newTestPayload(msgCount), nil)
	mockConn.On("WriteControl", gorilla.CloseMessage, mock.Anything, mock.Anything).Return(nil).Once()
	mockConn.On("Close").Return(nil)

	stream := NewStream(mockConn, &StreamOpts{ReadMsgType: testMsgType})
	stream.Start()

	// each read loop will read all messages in one payload
	for i := 0; i < readLoop; i++ {
		for j := 0; j < msgCount; j++ {

			msg, err := stream.Read()
			s.NoError(err)

			typedMsg := msg.(*cherami.PutMessage)
			id, _ := strconv.Atoi(typedMsg.GetID())
			s.Equal(j, id, "id mismatch")
		}
	}

	err := stream.Close()
	s.NoError(err)

	mockConn.AssertExpectations(s.T())
}

// TestReadError tests normal read error out
func (s *StreamSuite) TestReadError() {

	mockConn := &mockWS.MockWebsocketConn{}
	mockConn.On("SetCloseHandler", mock.Anything).Return(nil).Once()
	mockConn.On("ReadMessage").Return(gorilla.BinaryMessage, newTestPayload(1), nil).Once()
	mockConn.On("ReadMessage").Return(gorilla.BinaryMessage, nil, errors.New("ReadMessage error")).Once()
	mockConn.On("WriteControl", gorilla.CloseMessage, mock.Anything, mock.Anything).Return(nil).Once()
	mockConn.On("Close").Return(nil)

	stream := NewStream(mockConn, &StreamOpts{ReadMsgType: testMsgType})
	stream.Start()

	_, err := stream.Read()
	s.NoError(err)

	_, err = stream.Read()
	s.Error(err)

	err = stream.Close()
	s.NoError(err)

	mockConn.AssertExpectations(s.T())
}

// TestWrite tests multiple normal write operations
func (s *StreamSuite) TestWrite() {

	flushThreshold := 1
	writeLoop := 3
	totalMsgCount := flushThreshold * writeLoop

	mockConn := &mockWS.MockWebsocketConn{}
	mockConn.On("SetCloseHandler", mock.Anything).Return(nil).Once()
	mockConn.On("WriteMessage", gorilla.BinaryMessage, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
		payload := args.Get(1).([]byte)
		msgs, err := common.TListDeserialize(testMsgType, payload)
		s.NoError(err)
		s.Equal(flushThreshold, len(msgs), "unexpected number of messages")
		typedMsg := msgs[0].(*cherami.PutMessage)
		id, _ := strconv.Atoi(typedMsg.GetID())
		s.True(id >= 1 && id <= totalMsgCount, "id out of range")
	})
	mockConn.On("WriteControl", gorilla.CloseMessage, mock.Anything, mock.Anything).Return(nil).Once()
	mockConn.On("Close").Return(nil)

	stream := NewStream(mockConn, &StreamOpts{ReadMsgType: testMsgType, FlushThreshold: flushThreshold})
	stream.Start()

	// write multiple messages
	for i := 1; i <= totalMsgCount; i++ {
		err := stream.Write(newTestMsg(i))
		s.NoError(err)
	}

	err := stream.Close()
	s.NoError(err)
	mockConn.AssertExpectations(s.T())
}

// TestWriteFlushThreshold tests the flush mechanisms:
// 1. flush automatically, when we have buffered messages equal to the flushThreshold
// 2. flush, when Flush is called
// 3. flush, when Close is called and there were buffered messages
func (s *StreamSuite) TestWriteFlushThreshold() {

	flushThreshold := 3
	writeLoop := 2
	flushMsgs := 1
	tailMsgs := 2

	mockConn := &mockWS.MockWebsocketConn{}
	mockConn.On("SetCloseHandler", mock.Anything).Return(nil).Once()
	// the first 'writeLoop' payloads are the threshold-triggered flushes
	mockConn.On("WriteMessage", gorilla.BinaryMessage, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
		payload := args.Get(1).([]byte)
		msgs, err := common.TListDeserialize(testMsgType, payload)
		s.NoError(err)
		s.Equal(flushThreshold, len(msgs), "unexpected number of messages")
		for i := 0; i < flushThreshold; i++ {
			typedMsg := msgs[i].(*cherami.PutMessage)
			id, _ := strconv.Atoi(typedMsg.GetID())
			s.Equal(i, id, "id mismatch")
		}
	}).Times(writeLoop)
	// the next payload is the one produced by the explicit Flush call
	mockConn.On("WriteMessage", gorilla.BinaryMessage, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
		payload := args.Get(1).([]byte)
		msgs, err := common.TListDeserialize(testMsgType, payload)
		s.NoError(err)
		s.Equal(flushMsgs, len(msgs), "unexpected number of messages")
		for i := 0; i < flushMsgs; i++ {
			typedMsg := msgs[i].(*cherami.PutMessage)
			id, _ := strconv.Atoi(typedMsg.GetID())
			s.Equal(i, id, "id mismatch")
		}
	}).Once()
	// the last payload carries the "tail" messages flushed by Close
	mockConn.On("WriteMessage", gorilla.BinaryMessage, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
		payload := args.Get(1).([]byte)
		msgs, err := common.TListDeserialize(testMsgType, payload)
		s.NoError(err)
		s.Equal(tailMsgs, len(msgs), "unexpected number of messages")
		for i := 0; i < tailMsgs; i++ {
			typedMsg := msgs[i].(*cherami.PutMessage)
			id, _ := strconv.Atoi(typedMsg.GetID())
			s.Equal(i, id, "id mismatch")
		}
	}).Once()
	mockConn.On("WriteControl", gorilla.CloseMessage, mock.Anything, mock.Anything).Return(nil).Once()
	mockConn.On("Close").Return(nil)

	stream := NewStream(mockConn, &StreamOpts{ReadMsgType: testMsgType, FlushThreshold: flushThreshold})
	stream.Start()

	// each write loop will write batch of messages up to flush threshold
	for i := 0; i < writeLoop; i++ {
		for j := 0; j < flushThreshold; j++ {

			err := stream.Write(newTestMsg(j))
			s.NoError(err)
		}
	}

	// write some messages and call Flush
	for j := 0; j < flushMsgs; j++ {

		err := stream.Write(newTestMsg(j))
		s.NoError(err)
	}
	err := stream.Flush()
	s.NoError(err)

	// write some extra "tail" messages, that should get flushed with the close
	for j := 0; j < tailMsgs; j++ {

		err = stream.Write(newTestMsg(j))
		s.NoError(err)
	}

	err = stream.Close()
	s.NoError(err)

	mockConn.AssertExpectations(s.T())
}

// TestWriteError tests write error out
func (s *StreamSuite) TestWriteError() {

	flushThreshold := 2

	mockConn := &mockWS.MockWebsocketConn{}
	mockConn.On("SetCloseHandler", mock.Anything).Return(nil).Once()
	mockConn.On("WriteMessage", gorilla.BinaryMessage, mock.Anything).Return(errors.New("WriteMessage error"))
	mockConn.On("WriteControl", gorilla.CloseMessage, mock.Anything, mock.Anything).Return(nil).Once()
	mockConn.On("Close").Return(nil)

	stream := NewStream(mockConn, &StreamOpts{ReadMsgType: testMsgType, FlushThreshold: flushThreshold})
	stream.Start()

	err := stream.Write(newTestMsg(0))
	s.NoError(err) // the first msg just gets buffered and will not see an error

	err = stream.Write(newTestMsg(1))
	s.Error(err) // the second msg will fill the buffer, causing a flush that will see the error

	err = stream.Close()
	s.NoError(err)

	mockConn.AssertExpectations(s.T())
}

// mocks the low-level 'Conn' and exposes it as a pair of IO and error channels
type mockConn struct {
	writeMsgC chan []byte   // WriteMessage writes out messages onto this channel
	writeErrC chan error    // WriteMessage returns errors sent to this channel
	readMsgC  chan []byte   // ReadMessage reads and returns messages sent to this channel
	readErrC  chan error    // ReadMessage returns errors sent to this channel
	closeC    chan struct{} // Close closes this channel
}

type msgType int

const (
	msgBinary = iota
	msgClose
	msgPing
	msgPong
	msgError
)

type msg struct {
	msgType msgType
	payload []byte
}

func newBinaryMsg(payload []byte) *msg {
	return &msg{msgType: msgBinary, payload: payload}
}

// newMockConn creates a mockConn with unbuffered message channels; the two
// error channels are left nil, so their select cases never fire.
func newMockConn() *mockConn {
	return &mockConn{
		writeMsgC: make(chan []byte),
		readMsgC:  make(chan []byte),
		closeC:    make(chan struct{}),
	}
}

// ReadMessage blocks until a payload arrives on readMsgC (returned as a
// binary message), readMsgC is closed (returned as an `EOF` error), or an
// error arrives on readErrC.
func (t *mockConn) ReadMessage() (int, []byte, error) {
	select {
	case msg, ok := <-t.readMsgC:
		if !ok {
			// if readMsgC was closed, return an error
			return gorilla.BinaryMessage, nil, errors.New(`EOF`)
		}
		return gorilla.BinaryMessage, msg, nil
	case err := <-t.readErrC:
		return gorilla.BinaryMessage, nil, err

	}
}

// WriteMessage blocks until the payload is taken off writeMsgC, an error
// arrives on writeErrC, or the conn is closed (which yields a `closed` error).
func (t *mockConn) WriteMessage(messageType int, msg []byte) (err error) {
	select {
	case t.writeMsgC <- msg:
		return nil
	case err = <-t.writeErrC:
		return err
	case <-t.closeC:
		return errors.New(`closed`)
	}
}

// WriteControl accepts and discards every control message.
func (t *mockConn) WriteControl(messageType int, data []byte, deadline time.Time) error {
	return nil
}

// SetCloseHandler ignores the handler.
func (t *mockConn) SetCloseHandler(h func(code int, text string) error) {
	return
}

// Close closes readMsgC and closeC; calling it a second time would panic on
// the double close.
func (t *mockConn) Close() error {
	close(t.readMsgC) // on 'conn' close, the websocket layer returns messages on read, until the end
	close(t.closeC)
	return nil
}

// wait for given 'timeout' (in milliseconds) for condition 'cond' to satisfy;
// 'cond' is polled every 10ms
func waitFor(timeout int, cond func() bool) bool {
	for i := 0; i < timeout/10; i++ {
		if cond() {
			return true
		}
		time.Sleep(10 * time.Millisecond)
	}
	return false
}

// TestWritesWithClose tests to verify that all writes are received despite a close
func (s *StreamSuite) TestWritesWithClose() {

	numMsgs := 1500 // greater than defaultChannelSize (= 1024)
	flushThreshold := 32
	writeMessagesDelay := 10 * time.Millisecond

	var lastMsgID int64

	conn := newMockConn()

	// the following responds to conn.WriteMessage by validating that the messages
	// are in order each time also storing the id of the message in 'lastMsgID'.
	// in addition, this sleeps to simulate a "slow" write .. this would have the
	// effect of keeping the 'writeChan' in websocket-stream to fill up.
	// in the meantime, the call to conn.ReadMessages remains blocked.
	go func() {
	pump:
		for {
			select {
			case payload := <-conn.writeMsgC: // mockConn.WriteMessage puts messages into writeMsgC
				msgs, err := common.TListDeserialize(testMsgType, payload)
				s.NoError(err)

				for _, m := range msgs {
					typedMsg := m.(*cherami.PutMessage)
					id, _ := strconv.Atoi(typedMsg.GetID())
					s.Equal(atomic.LoadInt64(&lastMsgID)+1, int64(id), "msg out of order")
					atomic.StoreInt64(&lastMsgID, int64(id))
				}

				time.Sleep(writeMessagesDelay) // slow writes, to cause 'writeChan' to fill up

			case <-conn.closeC: // mockConn.Close() closes closeC
				break pump
			}
		}
	}()

	stream := NewStream(conn, &StreamOpts{ReadMsgType: testMsgType, FlushThreshold: flushThreshold})
	stream.Start()

	// write multiple messages
	for i := 1; i <= numMsgs; i++ {

		err := stream.Write(newTestMsg(i))
		s.NoError(err)
	}

	err := stream.Close()
	s.NoError(err)

	// wait 2 seconds until all the messages have been flushed out
	s.True(waitFor(2000, func() bool {
		return int64(numMsgs) == atomic.LoadInt64(&lastMsgID)
	}), `missing messages`)
}

// TestReadsWithClose tests to verify that all reads are received despite a close
func (s *StreamSuite) TestReadsWithClose() {

	numMsgs := 3*1024 + 5 // ensure there's more than 'defaultChannelSize' payloads
	payloadMsgs := 3      // every 3 messages, make a payload and "send"

	var lastMsgID int64

	conn := newMockConn()

	stream := NewStream(conn, &StreamOpts{ReadMsgType: testMsgType})
	stream.Start()

	// the following goroutine feeds conn.ReadMessage: it batches the test
	// messages into payloads of 'payloadMsgs' messages each (the last payload
	// may be smaller) and sends each payload on readMsgC; once every payload
	// has been handed over, it calls Close on the stream.
	go func() {

		payload := make([]thrift.TStruct, 0, payloadMsgs)

		for i := 1; i <= numMsgs; i++ {

			payload = append(payload, newTestMsg(i))

			if len(payload) >= payloadMsgs || i == numMsgs {

				pl, err := common.TListSerialize(payload)
				if err != nil {
					panic("TListSerialize failed")
				}

				conn.readMsgC <- pl
				payload = payload[:0] // truncate, keeping capacity
			}
		}

		// once all the payloads are on the connection-stream, call 'Close' on the stream
		stream.Close()
	}()

	time.Sleep(time.Second) // wait for messages to be written and the stream closed

	// we should still be able to "drain" out all messages on the stream
	for {
		tMsg, err := stream.Read()

		if err != nil {
			break
		}

		msg := tMsg.(*cherami.PutMessage)

		id, _ := strconv.Atoi(msg.GetID())
		s.Equal(lastMsgID+1, int64(id), "msg out of order")
		lastMsgID = int64(id)
	}

	s.EqualValues(numMsgs, lastMsgID) // we should have received all the messages on the stream
}
diff --git a/example.go b/example.go new file mode 100644 index 0000000..ce58f94 --- /dev/null +++ b/example.go @@ -0,0 +1,197 @@ +// Copyright (c) 2016 Uber Technologies, Inc.
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package main + +import ( + "fmt" + "os" + "runtime/debug" + "time" + + cthrift "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/client/cherami" + "github.com/apache/thrift/lib/go/thrift" + "github.com/uber/tchannel-go" +) + +const ( + port = 4922 // cherami-frontend running on this port by default +) + +// helper function to print out a thrift object in json +func jsonify(obj thrift.TStruct) string { + transport := thrift.NewTMemoryBufferLen(1024) + protocol := thrift.NewTSimpleJSONProtocol(transport) + obj.Write(protocol) + protocol.Flush() + transport.Flush() + return transport.String() +} + +func exitIfError(err error) { + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + debug.PrintStack() + os.Exit(1) + } +} + +func main() { + // get the IP address of the interface + host, _ := tchannel.ListenIP() + + // First, create the client to interact with Cherami + // Here we directly connect to cherami running on host:port + cClient, err := cherami.NewClient("cherami-example", host.String(), port, &cherami.ClientOptions{ + Timeout: time.Minute, + }) + exitIfError(err) + + // Now, create a destination with timestamp to avoid collision. 
+ path := fmt.Sprintf("/test/test_%d", time.Now().UnixNano()) + dType := cthrift.DestinationType_PLAIN + consumedMessagesRetention := int32(3600) + unconsumedMessagesRetention := int32(7200) + ownerEmail := "cherami-client-example@cherami" + + desc, err := cClient.CreateDestination(&cthrift.CreateDestinationRequest{ + Path: &path, + Type: &dType, + ConsumedMessagesRetention: &consumedMessagesRetention, + UnconsumedMessagesRetention: &unconsumedMessagesRetention, + OwnerEmail: &ownerEmail, + }) + + exitIfError(err) + fmt.Printf("%v\n", jsonify(desc)) + + // Create a consumer group for that destination + name := fmt.Sprintf("%s_reader", path) + startTime := int64(0) + lockTimeout := int32(60) + maxDelivery := int32(3) + skipOlder := int32(3600) + + cdesc, err := cClient.CreateConsumerGroup(&cthrift.CreateConsumerGroupRequest{ + DestinationPath: &path, + ConsumerGroupName: &name, + StartFrom: &startTime, + LockTimeoutInSeconds: &lockTimeout, + MaxDeliveryCount: &maxDelivery, + SkipOlderMessagesInSeconds: &skipOlder, + OwnerEmail: &ownerEmail, + }) + + exitIfError(err) + fmt.Printf("%v\n", jsonify(cdesc)) + + // To publish, we need to create a Publisher for the specific destination + publisher := cClient.CreatePublisher(&cherami.CreatePublisherRequest{ + Path: path, + }) + + err = publisher.Open() + exitIfError(err) + + // We will do async publishing, so we need to have a channel to receive + // publish receipts. Spin up a goroutine to print the receipt or error. + receiptCh := make(chan *cherami.PublisherReceipt) + go func() { + for receipt := range receiptCh { + if receipt.Error != nil { + fmt.Fprintf(os.Stdout, "Error for publish ID %s is %s. With context userMsgID: %s\n", receipt.ID, receipt.Error.Error(), receipt.UserContext["userMsgID"]) + } else { + fmt.Fprintf(os.Stdout, "Receipt for publish ID %s is %s. 
With context userMsgID: %s\n", receipt.ID, receipt.Receipt, receipt.UserContext["userMsgID"]) + } + } + }() + + // To consume, we need to create a Consumer object to handle the consumption + // from the destination. The Consumer is part of the Consumer Group. + consumer := cClient.CreateConsumer(&cherami.CreateConsumerRequest{ + Path: path, + ConsumerGroupName: name, + ConsumerName: "", + PrefetchCount: 1, + Options: &cherami.ClientOptions{ + Timeout: 15 * time.Second, + }, + }) + + // The messages will be delivered via a channel. Spin up a goroutine to print out the message content. + ch := make(chan cherami.Delivery, 1) + _, err = consumer.Open(ch) + + doneCh := make(chan struct{}) + go func() { + i := 0 + for delivery := range ch { + msg := delivery.GetMessage() + fmt.Fprintf(os.Stdout, "msg: '%s', ack_token: %s\n", string(msg.GetPayload().GetData()), delivery.GetDeliveryToken()) + delivery.Ack() + i++ + if i == 10 { + doneCh <- struct{}{} + return + } + } + }() + + // Start publishing + for i := 0; i < 10; i++ { + var id string + data := fmt.Sprintf("message %d", i) + userMsgID := fmt.Sprintf("user-msg-%d", i) + id, err = publisher.PublishAsync(&cherami.PublisherMessage{ + Data: []byte(data), + UserContext: map[string]string{"userMsgID": userMsgID}, + }, receiptCh) + + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + break + } + + fmt.Fprintf(os.Stdout, "Local publish ID for message '%s': %s. With context userMsgID: %s\n", data, id, userMsgID) + } + + publisher.Close() + close(receiptCh) + + // Wait for all messages are consumed. + <-doneCh + + // Clean up consumer group and destination. System will take care of actual deleting of messages. 
+ err = cClient.DeleteConsumerGroup(&cthrift.DeleteConsumerGroupRequest{ + DestinationPath: &path, + ConsumerGroupName: &name, + }) + + exitIfError(err) + + err = cClient.DeleteDestination(&cthrift.DeleteDestinationRequest{ + Path: &path, + }) + + exitIfError(err) + + println("end") +} diff --git a/glide.lock b/glide.lock new file mode 100644 index 0000000..e32fa4a --- /dev/null +++ b/glide.lock @@ -0,0 +1,68 @@ +hash: af047361c4b9178f93fda3c1c7f7b06d7abd462e48183c1fe36bbea933b38933 +updated: 2016-12-27T08:53:57.896554928-08:00 +imports: +- name: github.com/apache/thrift + version: f39d4c8535472db962930fb22d733a4f32ed6fc1 + subpackages: + - lib/go/thrift +- name: github.com/cactus/go-statsd-client + version: d8eabe07bc70ff9ba6a56836cde99d1ea3d005f7 + subpackages: + - statsd +- name: github.com/davecgh/go-spew + version: 346938d642f2ec3594ed81d874461961cd0faa76 + subpackages: + - spew +- name: github.com/gorilla/websocket + version: 3ab3a8b8831546bd18fd182c20687ca853b2bb13 +- name: github.com/opentracing/opentracing-go + version: ac5446f53f2c0fc68dc16dc5f426eae1cd288b34 + subpackages: + - ext + - log +- name: github.com/pborman/uuid + version: 5007efa264d92316c43112bc573e754bc889b7b1 +- name: github.com/pmezard/go-difflib + version: 792786c7400a136282c1664665ae0a8db921c6c2 + subpackages: + - difflib +- name: github.com/Sirupsen/logrus + version: 08a8a7c27e3d058a8989316a850daad1c10bf4ab +- name: github.com/stretchr/objx + version: 1a9d0bb9f541897e62256577b352fdbc1fb4fd94 +- name: github.com/stretchr/testify + version: 2402e8e7a02fc811447d11f881aa9746cdc57983 + subpackages: + - assert + - mock + - require + - suite +- name: github.com/uber-common/bark + version: 8841a0f8e7ca869284ccb29c08a14cf3f4310f46 +- name: github.com/uber-go/atomic + version: 3b8db5e93c4c02efbc313e17b2e796b0914a01fb +- name: github.com/uber/cherami-thrift + version: 07226f4510ac9818d905a94708b8dcbe9dc24cc0 + subpackages: + - .generated/go/cherami +- name: github.com/uber/tchannel-go + version: 
2caa315516e1836b7b2eff70d2fcbd23538d1b22 + subpackages: + - hyperbahn + - hyperbahn/gen-go/hyperbahn + - json + - relay + - thrift + - thrift/gen-go/meta + - tnet + - trand + - typed +- name: golang.org/x/net + version: 45e771701b814666a7eb299e6c7a57d0b1799e91 + subpackages: + - context +- name: golang.org/x/sys + version: d4feaf1a7e61e1d9e79e6c4e76c6349e9cab0a03 + subpackages: + - unix +testImports: [] diff --git a/glide.yaml b/glide.yaml new file mode 100644 index 0000000..e1439ad --- /dev/null +++ b/glide.yaml @@ -0,0 +1,28 @@ +package: github.com/uber/cherami-client-go +import: +- package: github.com/uber/cherami-thrift + subpackages: + - .generated/go/cherami +- package: github.com/Sirupsen/logrus +- package: github.com/apache/thrift + subpackages: + - lib/go/thrift +- package: github.com/gorilla/websocket +- package: github.com/pborman/uuid +- package: github.com/stretchr/testify + subpackages: + - mock +- package: github.com/uber-common/bark +- package: github.com/uber/tchannel-go + version: ^1.2.1 + subpackages: + - hyperbahn + - hyperbahn/gen-go/hyperbahn + - json + - thrift + - thrift/gen-go/meta + - tnet + - typed +- package: golang.org/x/net + subpackages: + - context diff --git a/mocks/README.md b/mocks/README.md new file mode 100644 index 0000000..835f05f --- /dev/null +++ b/mocks/README.md @@ -0,0 +1,9 @@ +This directory contains all the mocks required for unittesting. + +For now we have the mocks of the following: +1. common/ - contains all the stuff required by all the common services + (i.e, SCommon interface & ExtController interface) + +2. storehost/ - contains the mock of the storehost (and client) + +3. 
inputhost/ - contains the mock of the inputhost (and client) diff --git a/mocks/clients/cherami/MockBInOpenPublisherStreamOutCall.go b/mocks/clients/cherami/MockBInOpenPublisherStreamOutCall.go new file mode 100644 index 0000000..82e3687 --- /dev/null +++ b/mocks/clients/cherami/MockBInOpenPublisherStreamOutCall.go @@ -0,0 +1,73 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cherami + +import ( + "github.com/uber/cherami-thrift/.generated/go/cherami" + + "github.com/stretchr/testify/mock" +) + +// MockBInOpenPublisherStreamOutCall is a mock for BInOpenPublisherStreamOutCall used for unit testing of cherami client +type MockBInOpenPublisherStreamOutCall struct { + mock.Mock +} + +// Write writes an argument to the request stream. The written items may not +// be sent till Flush or Done is called. 
+func (m *MockBInOpenPublisherStreamOutCall) Write(arg *cherami.PutMessage) error { + args := m.Called(arg) + return args.Error(0) +} + +// Flush flushes all written arguments. +func (m *MockBInOpenPublisherStreamOutCall) Flush() error { + args := m.Called() + return args.Error(0) +} + +// Done closes the request stream and should be called after all arguments have been written. +func (m *MockBInOpenPublisherStreamOutCall) Done() error { + args := m.Called() + return args.Error(0) +} + +// Read returns the next result, if any is available. If there are no more +// results left, it will return io.EOF. +func (m *MockBInOpenPublisherStreamOutCall) Read() (*cherami.InputHostCommand, error) { + args := m.Called() + var cmd *cherami.InputHostCommand + if args.Get(0) != nil { + cmd = args.Get(0).(*cherami.InputHostCommand) + } + return cmd, args.Error(1) +} + +// ResponseHeaders returns the response headers sent from the server. This will +// block until server headers have been received. +func (m *MockBInOpenPublisherStreamOutCall) ResponseHeaders() (map[string]string, error) { + args := m.Called() + var headers map[string]string + if args.Error(1) == nil { + headers = args.Get(0).(map[string]string) + } + return headers, args.Error(1) +} diff --git a/mocks/clients/cherami/MockBOutOpenConsumerStreamOutCall.go b/mocks/clients/cherami/MockBOutOpenConsumerStreamOutCall.go new file mode 100644 index 0000000..d69bd7b --- /dev/null +++ b/mocks/clients/cherami/MockBOutOpenConsumerStreamOutCall.go @@ -0,0 +1,73 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cherami + +import ( + "github.com/uber/cherami-thrift/.generated/go/cherami" + + "github.com/stretchr/testify/mock" +) + +// MockBOutOpenConsumerStreamOutCall is a mock for BOutOpenConsumerStreamOutCall used for unit testing of cherami client +type MockBOutOpenConsumerStreamOutCall struct { + mock.Mock +} + +// Write writes an argument to the request stream. The written items may not +// be sent till Flush or Done is called. +func (m *MockBOutOpenConsumerStreamOutCall) Write(arg *cherami.ControlFlow) error { + args := m.Called(arg) + return args.Error(0) +} + +// Flush flushes all written arguments. +func (m *MockBOutOpenConsumerStreamOutCall) Flush() error { + args := m.Called() + return args.Error(0) +} + +// Done closes the request stream and should be called after all arguments have been written. 
+func (m *MockBOutOpenConsumerStreamOutCall) Done() error { + args := m.Called() + return args.Error(0) +} + +// Read returns the next result, if any is available. If there are no more +// results left, it will return io.EOF. +func (m *MockBOutOpenConsumerStreamOutCall) Read() (*cherami.OutputHostCommand, error) { + args := m.Called() + var cmd *cherami.OutputHostCommand + if args.Get(0) != nil { + cmd = args.Get(0).(*cherami.OutputHostCommand) + } + return cmd, args.Error(1) +} + +// ResponseHeaders returns the response headers sent from the server. This will +// block until server headers have been received. +func (m *MockBOutOpenConsumerStreamOutCall) ResponseHeaders() (map[string]string, error) { + args := m.Called() + var headers map[string]string + if args.Error(1) == nil { + headers = args.Get(0).(map[string]string) + } + return headers, args.Error(1) +} diff --git a/mocks/clients/cherami/MockTChanBInClient.go b/mocks/clients/cherami/MockTChanBInClient.go new file mode 100644 index 0000000..d7e5f5e --- /dev/null +++ b/mocks/clients/cherami/MockTChanBInClient.go @@ -0,0 +1,56 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cherami + +import ( + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/stream" + + "github.com/stretchr/testify/mock" + "github.com/uber/tchannel-go/thrift" +) + +// MockTChanBInClient is a mock for TChanBInClient used for unit testing of cherami client +type MockTChanBInClient struct { + mock.Mock +} + +func (m *MockTChanBInClient) OpenPublisherStream(ctx thrift.Context) (stream.BInOpenPublisherStreamOutCall, error) { + args := m.Called(ctx) + + var r0 stream.BInOpenPublisherStreamOutCall + if args.Get(0) != nil { + r0 = args.Get(0).(stream.BInOpenPublisherStreamOutCall) + } + + return r0, args.Error(1) +} + +func (m *MockTChanBInClient) PutMessageBatch(ctx thrift.Context, request *cherami.PutMessageBatchRequest) (*cherami.PutMessageBatchResult_, error) { + args := m.Called(ctx, request) + + var result *cherami.PutMessageBatchResult_ + if args.Get(0) != nil { + result = args.Get(0).(*cherami.PutMessageBatchResult_) + } + + return result, args.Error(1) +} diff --git a/mocks/clients/cherami/MockTChanBOutClient.go b/mocks/clients/cherami/MockTChanBOutClient.go new file mode 100644 index 0000000..018c1e7 --- /dev/null +++ b/mocks/clients/cherami/MockTChanBOutClient.go @@ -0,0 +1,79 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package cherami + +import ( + "github.com/uber/cherami-thrift/.generated/go/cherami" + "github.com/uber/cherami-client-go/stream" + + "github.com/stretchr/testify/mock" + "github.com/uber/tchannel-go/thrift" +) + +// MockTChanBOutClient is a mock for TChanBOutClient used for unit testing of cherami client +type MockTChanBOutClient struct { + mock.Mock +} + +func (m *MockTChanBOutClient) AckMessages(ctx thrift.Context, ackRequest *cherami.AckMessagesRequest) error { + args := m.Called(ctx, ackRequest) + + return args.Error(0) +} + +func (m *MockTChanBOutClient) OpenConsumerStream(ctx thrift.Context) (stream.BOutOpenConsumerStreamOutCall, error) { + args := m.Called(ctx) + + var r0 stream.BOutOpenConsumerStreamOutCall + if args.Get(0) != nil { + r0 = args.Get(0).(stream.BOutOpenConsumerStreamOutCall) + } + + return r0, args.Error(1) +} + +func (m *MockTChanBOutClient) OpenStreamingConsumerStream(ctx thrift.Context) (stream.BOutOpenStreamingConsumerStreamOutCall, error) { + args := m.Called(ctx) + + var r0 stream.BOutOpenStreamingConsumerStreamOutCall + if args.Get(0) != nil { + r0 = args.Get(0).(stream.BOutOpenStreamingConsumerStreamOutCall) + } + + return r0, args.Error(1) +} + +func (m *MockTChanBOutClient) SetConsumedMessages(ctx thrift.Context, request *cherami.SetConsumedMessagesRequest) error { + args := m.Called(ctx, request) + + return args.Error(0) +} + +func (m *MockTChanBOutClient) ReceiveMessageBatch(ctx thrift.Context, request *cherami.ReceiveMessageBatchRequest) (*cherami.ReceiveMessageBatchResult_, error) { + args := m.Called(ctx, request) + + var result *cherami.ReceiveMessageBatchResult_ + if args.Get(0) != nil { + result = args.Get(0).(*cherami.ReceiveMessageBatchResult_) + } + + return result, args.Error(1) +} diff --git a/mocks/clients/cherami/MockWSConnector.go b/mocks/clients/cherami/MockWSConnector.go new file mode 100644 index 0000000..78562a1 --- /dev/null +++ b/mocks/clients/cherami/MockWSConnector.go @@ -0,0 
+1,70 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package cherami + +import "net/http" +import "github.com/stretchr/testify/mock" + +import "github.com/uber/cherami-client-go/stream" + +// MockWSConnector is a mock for WSConnector used for unit testing +type MockWSConnector struct { + mock.Mock +} + +func (_m *MockWSConnector) OpenPublisherStream(hostPort string, requestHeader http.Header) (stream.BInOpenPublisherStreamOutCall, error) { + ret := _m.Called(hostPort, requestHeader) + + var r0 stream.BInOpenPublisherStreamOutCall + if rf, ok := ret.Get(0).(func(string, http.Header) stream.BInOpenPublisherStreamOutCall); ok { + r0 = rf(hostPort, requestHeader) + } else { + r0 = ret.Get(0).(stream.BInOpenPublisherStreamOutCall) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, http.Header) error); ok { + r1 = rf(hostPort, requestHeader) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} +func (_m *MockWSConnector) OpenConsumerStream(hostPort string, requestHeader http.Header) (stream.BOutOpenConsumerStreamOutCall, error) { + ret := _m.Called(hostPort, requestHeader) + + var r0 stream.BOutOpenConsumerStreamOutCall + if rf, ok := ret.Get(0).(func(string, http.Header) stream.BOutOpenConsumerStreamOutCall); ok { + r0 = rf(hostPort, requestHeader) + } else { + r0 = ret.Get(0).(stream.BOutOpenConsumerStreamOutCall) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, http.Header) error); ok { + r1 = rf(hostPort, requestHeader) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/mocks/common/websocket/MockWebsocketConn.go b/mocks/common/websocket/MockWebsocketConn.go new file mode 100644 index 0000000..f1578da --- /dev/null +++ b/mocks/common/websocket/MockWebsocketConn.go @@ -0,0 +1,114 @@ +// Copyright (c) 2016 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package websocket + +import "github.com/stretchr/testify/mock" + +import "time" + +// MockWebsocketConn is a mock for WebsocketConn used for unit testing +type MockWebsocketConn struct { + mock.Mock +} + +// ReadMessage reads a message +func (_m *MockWebsocketConn) ReadMessage() (int, []byte, error) { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + var r1 []byte + if rf, ok := ret.Get(1).(func() []byte); ok { + r1 = rf() + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]byte) + } + } + + var r2 error + if rf, ok := ret.Get(2).(func() error); ok { + r2 = rf() + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// WriteMessage writes a message +func (_m *MockWebsocketConn) WriteMessage(messageType int, data []byte) error { + ret := _m.Called(messageType, data) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []byte) error); ok { + r0 = rf(messageType, data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// WriteControl writes a control message +func (_m *MockWebsocketConn) WriteControl(messageType int, data []byte, deadline time.Time) error { + ret := _m.Called(messageType, data, deadline) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []byte, time.Time) error); ok { + r0 = rf(messageType, data, deadline) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetCloseHandler sets a close handler +func (_m *MockWebsocketConn) SetCloseHandler(h func(code int, text string) error) { + + ret := _m.Called(h) + + if rf, ok := ret.Get(0).(func(h func(code int, text string) error) error); ok { + rf(h) + } + + return +} + +// Close closes the connection +func (_m *MockWebsocketConn) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/stream/stream.go b/stream/stream.go new file mode 100644 index 0000000..d2ab72f --- 
/dev/null +++ b/stream/stream.go @@ -0,0 +1,91 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package stream + +import ( + "github.com/uber/cherami-thrift/.generated/go/cherami" +) + +// BInOpenPublisherStreamOutCall is the object used to stream arguments/results and +// read response headers for outgoing calls. +type BInOpenPublisherStreamOutCall interface { + // Write writes an argument to the request stream. The written items may not + // be sent till Flush or Done is called. + Write(arg *cherami.PutMessage) error + + // Flush flushes all written arguments. + Flush() error + + // Done closes the request stream and should be called after all arguments have been written. + Done() error + + // Read returns the next result, if any is available. If there are no more + // results left, it will return io.EOF. 
+ Read() (*cherami.InputHostCommand, error) + + // ResponseHeaders returns the response headers sent from the server. This will + // block until server headers have been received. + ResponseHeaders() (map[string]string, error) +} + +// BOutOpenConsumerStreamOutCall is the object used to stream arguments/results and +// read response headers for outgoing calls. +type BOutOpenConsumerStreamOutCall interface { + // Write writes an argument to the request stream. The written items may not + // be sent till Flush or Done is called. + Write(arg *cherami.ControlFlow) error + + // Flush flushes all written arguments. + Flush() error + + // Done closes the request stream and should be called after all arguments have been written. + Done() error + + // Read returns the next result, if any is available. If there are no more + // results left, it will return io.EOF. + Read() (*cherami.OutputHostCommand, error) + + // ResponseHeaders returns the response headers sent from the server. This will + // block until server headers have been received. + ResponseHeaders() (map[string]string, error) +} + +// BOutOpenStreamingConsumerStreamOutCall is the object used to stream arguments/results and +// read response headers for outgoing calls. +type BOutOpenStreamingConsumerStreamOutCall interface { + // Write writes an argument to the request stream. The written items may not + // be sent till Flush or Done is called. + Write(arg *cherami.ControlFlow) error + + // Flush flushes all written arguments. + Flush() error + + // Done closes the request stream and should be called after all arguments have been written. + Done() error + + // Read returns the next result, if any is available. If there are no more + // results left, it will return io.EOF. + Read() (*cherami.OutputHostCommand, error) + + // ResponseHeaders returns the response headers sent from the server. This will + // block until server headers have been received. + ResponseHeaders() (map[string]string, error) +}