diff --git a/drivers/exec/driver.go b/drivers/exec/driver.go index 8c5cfbeabf2..b5f723de280 100644 --- a/drivers/exec/driver.go +++ b/drivers/exec/driver.go @@ -1,6 +1,7 @@ package exec import ( + "context" "fmt" "os" "path/filepath" @@ -21,7 +22,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared" "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" - "golang.org/x/net/context" ) const ( diff --git a/drivers/java/driver.go b/drivers/java/driver.go index dce0467f89c..5a68b5ef93f 100644 --- a/drivers/java/driver.go +++ b/drivers/java/driver.go @@ -1,6 +1,7 @@ package java import ( + "context" "fmt" "os" "os/exec" @@ -23,7 +24,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared" "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" - "golang.org/x/net/context" ) const ( diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go index 734ae61f8da..6fb2d1b5277 100644 --- a/drivers/mock/driver.go +++ b/drivers/mock/driver.go @@ -16,7 +16,6 @@ import ( "github.com/hashicorp/nomad/plugins/drivers" "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" - netctx "golang.org/x/net/context" ) const ( @@ -232,7 +231,7 @@ func (d *Driver) Capabilities() (*drivers.Capabilities, error) { return capabilities, nil } -func (d *Driver) Fingerprint(ctx netctx.Context) (<-chan *drivers.Fingerprint, error) { +func (d *Driver) Fingerprint(ctx context.Context) (<-chan *drivers.Fingerprint, error) { ch := make(chan *drivers.Fingerprint) go d.handleFingerprint(ctx, ch) return ch, nil @@ -365,7 +364,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru } -func (d *Driver) WaitTask(ctx netctx.Context, taskID string) (<-chan *drivers.ExitResult, error) { +func (d *Driver) WaitTask(ctx context.Context, taskID string) (<-chan *drivers.ExitResult, error) { handle, ok := d.tasks.Get(taskID) if !ok { return nil, drivers.ErrTaskNotFound @@ -430,7 +429,7 @@ func (d *Driver) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) { return nil, nil } -func (d *Driver) TaskEvents(ctx netctx.Context) (<-chan *drivers.TaskEvent, error) { +func (d *Driver) TaskEvents(ctx context.Context) (<-chan *drivers.TaskEvent, error) { return d.eventer.TaskEvents(ctx) } diff --git a/drivers/qemu/driver.go b/drivers/qemu/driver.go index 363771515be..c8a7ac50cab 100644 --- a/drivers/qemu/driver.go +++ b/drivers/qemu/driver.go @@ -1,6 +1,7 @@ package qemu import ( + "context" "errors" "fmt" "net" @@ -25,7 +26,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared" "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" - "golang.org/x/net/context" ) const ( diff --git a/drivers/rawexec/driver.go b/drivers/rawexec/driver.go index 0e73a27ab7a..afac6c550b9 100644 --- a/drivers/rawexec/driver.go +++ b/drivers/rawexec/driver.go @@ -1,6 +1,7 @@ package rawexec import ( + "context" "fmt" "os" "path/filepath" @@ -22,7 +23,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared" "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" - "golang.org/x/net/context" ) const ( diff --git a/drivers/rkt/driver.go b/drivers/rkt/driver.go index 4ffa69f8d76..74cac57b34c 100644 --- a/drivers/rkt/driver.go +++ b/drivers/rkt/driver.go @@ -4,6 +4,7 @@ package rkt import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" @@ -36,7 +37,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" rktv1 "github.com/rkt/rkt/api/v1" - "golang.org/x/net/context" ) const ( diff --git a/drivers/rkt/driver_test.go b/drivers/rkt/driver_test.go index 7847cb72a25..0edfec535ad 100644 --- a/drivers/rkt/driver_test.go +++ b/drivers/rkt/driver_test.go @@ -3,17 +3,16 @@ package rkt import ( + "bytes" + "context" "fmt" "io/ioutil" + "os" "path/filepath" "sync" "testing" "time" - "os" - - "bytes" - "github.com/hashicorp/hcl2/hcl" ctestutil "github.com/hashicorp/nomad/client/testutil" "github.com/hashicorp/nomad/helper/testlog" @@ -26,7 +25,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/testutil" "github.com/stretchr/testify/require" - "golang.org/x/net/context" ) var _ drivers.DriverPlugin = (*Driver)(nil) diff --git a/drivers/shared/eventer/eventer.go b/drivers/shared/eventer/eventer.go index a68a2016240..1e7674ee4b1 100644 --- a/drivers/shared/eventer/eventer.go +++ b/drivers/shared/eventer/eventer.go @@ -1,12 +1,12 @@ package eventer import ( + "context" "sync" "time" hclog "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/plugins/drivers" - "golang.org/x/net/context" ) var ( diff --git a/plugins/base/client.go b/plugins/base/client.go index f5476cef70f..6baf9a07d91 100644 --- a/plugins/base/client.go +++ b/plugins/base/client.go @@ -12,10 +12,13 @@ import ( // gRPC to communicate to the remote plugin. type BasePluginClient struct { Client proto.BasePluginClient + + // DoneCtx is closed when the plugin exits + DoneCtx context.Context } func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) { - presp, err := b.Client.PluginInfo(context.Background(), &proto.PluginInfoRequest{}) + presp, err := b.Client.PluginInfo(b.DoneCtx, &proto.PluginInfoRequest{}) if err != nil { return nil, err } @@ -41,7 +44,7 @@ func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) { } func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) { - presp, err := b.Client.ConfigSchema(context.Background(), &proto.ConfigSchemaRequest{}) + presp, err := b.Client.ConfigSchema(b.DoneCtx, &proto.ConfigSchemaRequest{}) if err != nil { return nil, err } @@ -51,7 +54,7 @@ func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) { func (b *BasePluginClient) SetConfig(data []byte, config *ClientAgentConfig) error { // Send the config - _, err := b.Client.SetConfig(context.Background(), &proto.SetConfigRequest{ + _, err := b.Client.SetConfig(b.DoneCtx, &proto.SetConfigRequest{ MsgpackConfig: data, NomadConfig: config.toProto(), }) diff --git a/plugins/base/plugin.go b/plugins/base/plugin.go index a386d2c4512..411c796629f 100644 --- a/plugins/base/plugin.go +++ b/plugins/base/plugin.go @@ -51,7 +51,10 @@ func (p *PluginBase) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error } func (p *PluginBase) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { - return &BasePluginClient{Client: proto.NewBasePluginClient(c)}, nil + return &BasePluginClient{ + Client: proto.NewBasePluginClient(c), + DoneCtx: ctx, + }, nil } // MsgpackHandle is a shared handle for encoding/decoding of structs diff --git a/plugins/device/client.go b/plugins/device/client.go index d20146e7501..ffbb80166fa 100644 --- a/plugins/device/client.go +++ b/plugins/device/client.go @@ -9,9 +9,7 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/device/proto" - netctx "golang.org/x/net/context" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/hashicorp/nomad/plugins/shared" ) // devicePluginClient implements the client side of a remote device plugin, using @@ -49,28 +47,33 @@ func (d *devicePluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerpri // the gRPC stream to a channel. Exits either when context is cancelled or the // stream has an error. func (d *devicePluginClient) handleFingerprint( - ctx netctx.Context, + ctx context.Context, stream proto.DevicePlugin_FingerprintClient, out chan *FingerprintResponse) { + defer close(out) for { resp, err := stream.Recv() if err != nil { if err != io.EOF { out <- &FingerprintResponse{ - Error: d.handleStreamErr(err, ctx), + Error: shared.HandleStreamErr(err, ctx, d.doneCtx), } } // End the stream - close(out) return } // Send the response - out <- &FingerprintResponse{ + f := &FingerprintResponse{ Devices: convertProtoDeviceGroups(resp.GetDeviceGroup()), } + select { + case <-ctx.Done(): + return + case out <- f: + } } } @@ -116,69 +119,32 @@ func (d *devicePluginClient) Stats(ctx context.Context, interval time.Duration) // the gRPC stream to a channel. Exits either when context is cancelled or the // stream has an error. func (d *devicePluginClient) handleStats( - ctx netctx.Context, + ctx context.Context, stream proto.DevicePlugin_StatsClient, out chan *StatsResponse) { + defer close(out) for { resp, err := stream.Recv() if err != nil { if err != io.EOF { out <- &StatsResponse{ - Error: d.handleStreamErr(err, ctx), + Error: shared.HandleStreamErr(err, ctx, d.doneCtx), } } // End the stream - close(out) return } // Send the response - out <- &StatsResponse{ + s := &StatsResponse{ Groups: convertProtoDeviceGroupsStats(resp.GetGroups()), } - } -} - -// handleStreamErr is used to handle a non io.EOF error in a stream. It handles -// detecting if the plugin has shutdown -func (d *devicePluginClient) handleStreamErr(err error, ctx context.Context) error { - if err == nil { - return nil - } - - // Determine if the error is because the plugin shutdown - if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable { - // Potentially wait a little before returning an error so we can detect - // the exit select { - case <-d.doneCtx.Done(): - err = base.ErrPluginShutdown case <-ctx.Done(): - err = ctx.Err() - - // There is no guarantee that the select will choose the - // doneCtx first so we have to double check - select { - case <-d.doneCtx.Done(): - err = base.ErrPluginShutdown - default: - } - case <-time.After(3 * time.Second): - // Its okay to wait a while since the connection isn't available and - // on local host it is likely shutting down. It is not expected for - // this to ever reach even close to 3 seconds. + return + case out <- s: } - - // It is an error we don't know how to handle, so return it - return err } - - // Context was cancelled - if errStatus := status.FromContextError(ctx.Err()); errStatus.Code() == codes.Canceled { - return context.Canceled - } - - return err } diff --git a/plugins/device/plugin.go b/plugins/device/plugin.go index f0373385737..65ec19540cf 100644 --- a/plugins/device/plugin.go +++ b/plugins/device/plugin.go @@ -31,7 +31,8 @@ func (p *PluginDevice) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker doneCtx: ctx, client: proto.NewDevicePluginClient(c), BasePluginClient: &base.BasePluginClient{ - Client: bproto.NewBasePluginClient(c), + Client: bproto.NewBasePluginClient(c), + DoneCtx: ctx, }, }, nil } diff --git a/plugins/drivers/client.go b/plugins/drivers/client.go index 1fb60ccde57..6974f8dcfb0 100644 --- a/plugins/drivers/client.go +++ b/plugins/drivers/client.go @@ -1,18 +1,19 @@ package drivers import ( + "context" "errors" - "fmt" "io" "time" + "github.com/LK4D4/joincontext" "github.com/golang/protobuf/ptypes" hclog "github.com/hashicorp/go-hclog" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/drivers/proto" + "github.com/hashicorp/nomad/plugins/shared" "github.com/hashicorp/nomad/plugins/shared/hclspec" - "golang.org/x/net/context" ) var _ DriverPlugin = &driverPluginClient{} @@ -22,12 +23,15 @@ type driverPluginClient struct { client proto.DriverClient logger hclog.Logger + + // doneCtx is closed when the plugin exits + doneCtx context.Context } func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) { req := &proto.TaskConfigSchemaRequest{} - resp, err := d.client.TaskConfigSchema(context.Background(), req) + resp, err := d.client.TaskConfigSchema(d.doneCtx, req) if err != nil { return nil, err } @@ -38,7 +42,7 @@ func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) { func (d *driverPluginClient) Capabilities() (*Capabilities, error) { req := &proto.CapabilitiesRequest{} - resp, err := d.client.Capabilities(context.Background(), req) + resp, err := d.client.Capabilities(d.doneCtx, req) if err != nil { return nil, err } @@ -67,12 +71,15 @@ func (d *driverPluginClient) Capabilities() (*Capabilities, error) { func (d *driverPluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerprint, error) { req := &proto.FingerprintRequest{} + // Join the passed context and the shutdown context + ctx, _ = joincontext.Join(ctx, d.doneCtx) + stream, err := d.client.Fingerprint(ctx, req) if err != nil { return nil, err } - ch := make(chan *Fingerprint) + ch := make(chan *Fingerprint, 1) go d.handleFingerprint(ctx, ch, stream) return ch, nil @@ -82,17 +89,18 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin defer close(ch) for { pb, err := stream.Recv() - if err == io.EOF { - return - } if err != nil { - select { - case <-ctx.Done(): - case ch <- &Fingerprint{Err: fmt.Errorf("error from RPC stream: %v", err)}: + if err != io.EOF { d.logger.Error("error receiving stream from Fingerprint driver RPC", "error", err) + ch <- &Fingerprint{ + Err: shared.HandleStreamErr(err, ctx, d.doneCtx), + } } + + // End the stream return } + f := &Fingerprint{ Attributes: pb.Attributes, Health: healthStateFromProto(pb.Health), @@ -112,7 +120,7 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin func (d *driverPluginClient) RecoverTask(h *TaskHandle) error { req := &proto.RecoverTaskRequest{Handle: taskHandleToProto(h)} - _, err := d.client.RecoverTask(context.Background(), req) + _, err := d.client.RecoverTask(d.doneCtx, req) return err } @@ -124,7 +132,7 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr Task: taskConfigToProto(c), } - resp, err := d.client.StartTask(context.Background(), req) + resp, err := d.client.StartTask(d.doneCtx, req) if err != nil { return nil, nil, err } @@ -150,6 +158,10 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr // the same task without issue. func (d *driverPluginClient) WaitTask(ctx context.Context, id string) (<-chan *ExitResult, error) { ch := make(chan *ExitResult) + + // Join the passed context and the shutdown context + ctx, _ = joincontext.Join(ctx, d.doneCtx) + go d.handleWaitTask(ctx, id, ch) return ch, nil } @@ -186,7 +198,7 @@ func (d *driverPluginClient) StopTask(taskID string, timeout time.Duration, sign Signal: signal, } - _, err := d.client.StopTask(context.Background(), req) + _, err := d.client.StopTask(d.doneCtx, req) return err } @@ -199,7 +211,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error { Force: force, } - _, err := d.client.DestroyTask(context.Background(), req) + _, err := d.client.DestroyTask(d.doneCtx, req) return err } @@ -207,7 +219,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error { func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) { req := &proto.InspectTaskRequest{TaskId: taskID} - resp, err := d.client.InspectTask(context.Background(), req) + resp, err := d.client.InspectTask(d.doneCtx, req) if err != nil { return nil, err } @@ -238,7 +250,7 @@ func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) { func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) { req := &proto.TaskStatsRequest{TaskId: taskID} - resp, err := d.client.TaskStats(context.Background(), req) + resp, err := d.client.TaskStats(d.doneCtx, req) if err != nil { return nil, err } @@ -255,28 +267,36 @@ func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsa // tasks such as lifecycle events, terminal errors, etc. func (d *driverPluginClient) TaskEvents(ctx context.Context) (<-chan *TaskEvent, error) { req := &proto.TaskEventsRequest{} + + // Join the passed context and the shutdown context + ctx, _ = joincontext.Join(ctx, d.doneCtx) + stream, err := d.client.TaskEvents(ctx, req) if err != nil { return nil, err } - ch := make(chan *TaskEvent) - go d.handleTaskEvents(ch, stream) + ch := make(chan *TaskEvent, 1) + go d.handleTaskEvents(ctx, ch, stream) return ch, nil } -func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) { +func (d *driverPluginClient) handleTaskEvents(ctx context.Context, ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) { defer close(ch) for { ev, err := stream.Recv() - if err == io.EOF { - break - } if err != nil { - d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err) - ch <- &TaskEvent{Err: err} - break + if err != io.EOF { + d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err) + ch <- &TaskEvent{ + Err: shared.HandleStreamErr(err, ctx, d.doneCtx), + } + } + + // End the stream + return } + timestamp, _ := ptypes.Timestamp(ev.Timestamp) event := &TaskEvent{ TaskID: ev.TaskId, @@ -284,7 +304,11 @@ func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.D Message: ev.Message, Timestamp: timestamp, } - ch <- event + select { + case <-ctx.Done(): + return + case ch <- event: + } } } @@ -294,7 +318,7 @@ func (d *driverPluginClient) SignalTask(taskID string, signal string) error { TaskId: taskID, Signal: signal, } - _, err := d.client.SignalTask(context.Background(), req) + _, err := d.client.SignalTask(d.doneCtx, req) return err } @@ -309,7 +333,7 @@ func (d *driverPluginClient) ExecTask(taskID string, cmd []string, timeout time. Timeout: ptypes.DurationProto(timeout), } - resp, err := d.client.ExecTask(context.Background(), req) + resp, err := d.client.ExecTask(d.doneCtx, req) if err != nil { return nil, err } diff --git a/plugins/drivers/driver.go b/plugins/drivers/driver.go index 2bb7267c4b1..458635f6d61 100644 --- a/plugins/drivers/driver.go +++ b/plugins/drivers/driver.go @@ -1,6 +1,7 @@ package drivers import ( + "context" "fmt" "path/filepath" "sort" @@ -14,7 +15,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/msgpack" - "golang.org/x/net/context" ) // DriverPlugin is the interface with drivers will implement. It is also diff --git a/plugins/drivers/plugin.go b/plugins/drivers/plugin.go index b485c883651..67165cb8a37 100644 --- a/plugins/drivers/plugin.go +++ b/plugins/drivers/plugin.go @@ -38,9 +38,11 @@ func (p *PluginDriver) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) err func (p *PluginDriver) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { return &driverPluginClient{ BasePluginClient: &base.BasePluginClient{ - Client: baseproto.NewBasePluginClient(c), + DoneCtx: ctx, + Client: baseproto.NewBasePluginClient(c), }, - client: proto.NewDriverClient(c), - logger: p.logger, + client: proto.NewDriverClient(c), + logger: p.logger, + doneCtx: ctx, }, nil } diff --git a/plugins/drivers/plugin_test.go b/plugins/drivers/plugin_test.go index 3409124e8cb..0bb01ed9f24 100644 --- a/plugins/drivers/plugin_test.go +++ b/plugins/drivers/plugin_test.go @@ -2,6 +2,7 @@ package drivers import ( "bytes" + "context" "sync" "testing" "time" @@ -10,7 +11,6 @@ import ( "github.com/hashicorp/nomad/nomad/structs" "github.com/stretchr/testify/require" "github.com/ugorji/go/codec" - "golang.org/x/net/context" ) type testDriverState struct { diff --git a/plugins/drivers/server.go b/plugins/drivers/server.go index 4ad385e24e9..bbe73e73d88 100644 --- a/plugins/drivers/server.go +++ b/plugins/drivers/server.go @@ -4,13 +4,12 @@ import ( "fmt" "io" - "golang.org/x/net/context" - "github.com/golang/protobuf/ptypes" hclog "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/plugins/drivers/proto" + context "golang.org/x/net/context" ) type driverPluginServer struct { diff --git a/plugins/drivers/testing.go b/plugins/drivers/testing.go index 000c81b92db..3bace5fb141 100644 --- a/plugins/drivers/testing.go +++ b/plugins/drivers/testing.go @@ -1,16 +1,13 @@ package drivers import ( + "context" "fmt" "io/ioutil" "path/filepath" "runtime" "time" - "github.com/mitchellh/go-testing-interface" - "github.com/stretchr/testify/require" - "golang.org/x/net/context" - hclog "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/nomad/client/allocdir" @@ -21,6 +18,8 @@ import ( "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/shared/hclspec" + "github.com/mitchellh/go-testing-interface" + "github.com/stretchr/testify/require" ) type DriverHarness struct { diff --git a/plugins/shared/grpc_utils.go b/plugins/shared/grpc_utils.go new file mode 100644 index 00000000000..34fb33a870c --- /dev/null +++ b/plugins/shared/grpc_utils.go @@ -0,0 +1,61 @@ +package shared + +import ( + "context" + "time" + + "github.com/hashicorp/nomad/plugins/base" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// HandleStreamErr is used to handle a non io.EOF error in a stream. It handles +// detecting if the plugin has shutdown via the passeed pluginCtx. The +// parameters are: +// - err: the error returned from the streaming RPC +// - reqCtx: the context passed to the streaming request +// - pluginCtx: the plugins done ctx used to detect the plugin dying +// +// The return values are: +// - base.ErrPluginShutdown if the error is because the plugin shutdown +// - context.Canceled if the reqCtx is canceled +// - The original error +func HandleStreamErr(err error, reqCtx, pluginCtx context.Context) error { + if err == nil { + return nil + } + + // Determine if the error is because the plugin shutdown + if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable { + // Potentially wait a little before returning an error so we can detect + // the exit + select { + case <-pluginCtx.Done(): + err = base.ErrPluginShutdown + case <-reqCtx.Done(): + err = reqCtx.Err() + + // There is no guarantee that the select will choose the + // doneCtx first so we have to double check + select { + case <-pluginCtx.Done(): + err = base.ErrPluginShutdown + default: + } + case <-time.After(3 * time.Second): + // Its okay to wait a while since the connection isn't available and + // on local host it is likely shutting down. It is not expected for + // this to ever reach even close to 3 seconds. + } + + // It is an error we don't know how to handle, so return it + return err + } + + // Context was cancelled + if errStatus := status.FromContextError(reqCtx.Err()); errStatus.Code() == codes.Canceled { + return context.Canceled + } + + return err +}