Skip to content

Commit

Permalink
Merge pull request #4868 from hashicorp/b-plugin-ctx
Browse files Browse the repository at this point in the history
Plugin client's handle plugin dying
  • Loading branch information
dadgar authored Nov 13, 2018
2 parents 6702df2 + 9d42f4d commit 6d0cd01
Show file tree
Hide file tree
Showing 19 changed files with 165 additions and 110 deletions.
2 changes: 1 addition & 1 deletion drivers/exec/driver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package exec

import (
"context"
"fmt"
"os"
"path/filepath"
Expand All @@ -21,7 +22,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)

const (
Expand Down
2 changes: 1 addition & 1 deletion drivers/java/driver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package java

import (
"context"
"fmt"
"os"
"os/exec"
Expand All @@ -23,7 +24,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)

const (
Expand Down
7 changes: 3 additions & 4 deletions drivers/mock/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/hashicorp/nomad/plugins/drivers"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
netctx "golang.org/x/net/context"
)

const (
Expand Down Expand Up @@ -232,7 +231,7 @@ func (d *Driver) Capabilities() (*drivers.Capabilities, error) {
return capabilities, nil
}

func (d *Driver) Fingerprint(ctx netctx.Context) (<-chan *drivers.Fingerprint, error) {
func (d *Driver) Fingerprint(ctx context.Context) (<-chan *drivers.Fingerprint, error) {
ch := make(chan *drivers.Fingerprint)
go d.handleFingerprint(ctx, ch)
return ch, nil
Expand Down Expand Up @@ -365,7 +364,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru

}

func (d *Driver) WaitTask(ctx netctx.Context, taskID string) (<-chan *drivers.ExitResult, error) {
func (d *Driver) WaitTask(ctx context.Context, taskID string) (<-chan *drivers.ExitResult, error) {
handle, ok := d.tasks.Get(taskID)
if !ok {
return nil, drivers.ErrTaskNotFound
Expand Down Expand Up @@ -430,7 +429,7 @@ func (d *Driver) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) {
return nil, nil
}

func (d *Driver) TaskEvents(ctx netctx.Context) (<-chan *drivers.TaskEvent, error) {
func (d *Driver) TaskEvents(ctx context.Context) (<-chan *drivers.TaskEvent, error) {
return d.eventer.TaskEvents(ctx)
}

Expand Down
2 changes: 1 addition & 1 deletion drivers/qemu/driver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package qemu

import (
"context"
"errors"
"fmt"
"net"
Expand All @@ -25,7 +26,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)

const (
Expand Down
2 changes: 1 addition & 1 deletion drivers/rawexec/driver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rawexec

import (
"context"
"fmt"
"os"
"path/filepath"
Expand All @@ -22,7 +23,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)

const (
Expand Down
2 changes: 1 addition & 1 deletion drivers/rkt/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package rkt

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -36,7 +37,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
rktv1 "github.com/rkt/rkt/api/v1"
"golang.org/x/net/context"
)

const (
Expand Down
8 changes: 3 additions & 5 deletions drivers/rkt/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
package rkt

import (
"bytes"
"context"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"testing"
"time"

"os"

"bytes"

"github.com/hashicorp/hcl2/hcl"
ctestutil "github.com/hashicorp/nomad/client/testutil"
"github.com/hashicorp/nomad/helper/testlog"
Expand All @@ -26,7 +25,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
)

var _ drivers.DriverPlugin = (*Driver)(nil)
Expand Down
2 changes: 1 addition & 1 deletion drivers/shared/eventer/eventer.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package eventer

import (
"context"
"sync"
"time"

hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/plugins/drivers"
"golang.org/x/net/context"
)

var (
Expand Down
9 changes: 6 additions & 3 deletions plugins/base/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ import (
// gRPC to communicate to the remote plugin.
type BasePluginClient struct {
Client proto.BasePluginClient

// DoneCtx is closed when the plugin exits
DoneCtx context.Context
}

func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
presp, err := b.Client.PluginInfo(context.Background(), &proto.PluginInfoRequest{})
presp, err := b.Client.PluginInfo(b.DoneCtx, &proto.PluginInfoRequest{})
if err != nil {
return nil, err
}
Expand All @@ -41,7 +44,7 @@ func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
}

func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {
presp, err := b.Client.ConfigSchema(context.Background(), &proto.ConfigSchemaRequest{})
presp, err := b.Client.ConfigSchema(b.DoneCtx, &proto.ConfigSchemaRequest{})
if err != nil {
return nil, err
}
Expand All @@ -51,7 +54,7 @@ func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {

func (b *BasePluginClient) SetConfig(data []byte, config *ClientAgentConfig) error {
// Send the config
_, err := b.Client.SetConfig(context.Background(), &proto.SetConfigRequest{
_, err := b.Client.SetConfig(b.DoneCtx, &proto.SetConfigRequest{
MsgpackConfig: data,
NomadConfig: config.toProto(),
})
Expand Down
5 changes: 4 additions & 1 deletion plugins/base/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ func (p *PluginBase) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error
}

func (p *PluginBase) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
return &BasePluginClient{Client: proto.NewBasePluginClient(c)}, nil
return &BasePluginClient{
Client: proto.NewBasePluginClient(c),
DoneCtx: ctx,
}, nil
}

// MsgpackHandle is a shared handle for encoding/decoding of structs
Expand Down
66 changes: 16 additions & 50 deletions plugins/device/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ import (
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/device/proto"
netctx "golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/nomad/plugins/shared"
)

// devicePluginClient implements the client side of a remote device plugin, using
Expand Down Expand Up @@ -49,28 +47,33 @@ func (d *devicePluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerpri
// the gRPC stream to a channel. Exits either when context is cancelled or the
// stream has an error.
func (d *devicePluginClient) handleFingerprint(
ctx netctx.Context,
ctx context.Context,
stream proto.DevicePlugin_FingerprintClient,
out chan *FingerprintResponse) {

defer close(out)
for {
resp, err := stream.Recv()
if err != nil {
if err != io.EOF {
out <- &FingerprintResponse{
Error: d.handleStreamErr(err, ctx),
Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}

// End the stream
close(out)
return
}

// Send the response
out <- &FingerprintResponse{
f := &FingerprintResponse{
Devices: convertProtoDeviceGroups(resp.GetDeviceGroup()),
}
select {
case <-ctx.Done():
return
case out <- f:
}
}
}

Expand Down Expand Up @@ -116,69 +119,32 @@ func (d *devicePluginClient) Stats(ctx context.Context, interval time.Duration)
// the gRPC stream to a channel. Exits either when context is cancelled or the
// stream has an error.
func (d *devicePluginClient) handleStats(
ctx netctx.Context,
ctx context.Context,
stream proto.DevicePlugin_StatsClient,
out chan *StatsResponse) {

defer close(out)
for {
resp, err := stream.Recv()
if err != nil {
if err != io.EOF {
out <- &StatsResponse{
Error: d.handleStreamErr(err, ctx),
Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}

// End the stream
close(out)
return
}

// Send the response
out <- &StatsResponse{
s := &StatsResponse{
Groups: convertProtoDeviceGroupsStats(resp.GetGroups()),
}
}
}

// handleStreamErr is used to handle a non io.EOF error in a stream. It handles
// detecting if the plugin has shutdown
func (d *devicePluginClient) handleStreamErr(err error, ctx context.Context) error {
if err == nil {
return nil
}

// Determine if the error is because the plugin shutdown
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
// Potentially wait a little before returning an error so we can detect
// the exit
select {
case <-d.doneCtx.Done():
err = base.ErrPluginShutdown
case <-ctx.Done():
err = ctx.Err()

// There is no guarantee that the select will choose the
// doneCtx first so we have to double check
select {
case <-d.doneCtx.Done():
err = base.ErrPluginShutdown
default:
}
case <-time.After(3 * time.Second):
// Its okay to wait a while since the connection isn't available and
// on local host it is likely shutting down. It is not expected for
// this to ever reach even close to 3 seconds.
return
case out <- s:
}

// It is an error we don't know how to handle, so return it
return err
}

// Context was cancelled
if errStatus := status.FromContextError(ctx.Err()); errStatus.Code() == codes.Canceled {
return context.Canceled
}

return err
}
3 changes: 2 additions & 1 deletion plugins/device/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ func (p *PluginDevice) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker
doneCtx: ctx,
client: proto.NewDevicePluginClient(c),
BasePluginClient: &base.BasePluginClient{
Client: bproto.NewBasePluginClient(c),
Client: bproto.NewBasePluginClient(c),
DoneCtx: ctx,
},
}, nil
}
Expand Down
Loading

0 comments on commit 6d0cd01

Please sign in to comment.