diff --git a/api/api.go b/api/api.go index ce6d910fe17..09c000a4a9b 100644 --- a/api/api.go +++ b/api/api.go @@ -125,7 +125,9 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config { WaitTime: c.WaitTime, TLSConfig: c.TLSConfig.Copy(), } - config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", c.Region) + if tlsEnabled && config.TLSConfig != nil { + config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", region) + } return config } @@ -221,6 +223,9 @@ func DefaultConfig() *Config { // ConfigureTLS applies a set of TLS configurations to the the HTTP client. func (c *Config) ConfigureTLS() error { + if c.TLSConfig == nil { + return nil + } if c.HttpClient == nil { return fmt.Errorf("config HTTP Client must be set") } @@ -300,7 +305,17 @@ func (c *Client) SetRegion(region string) { // GetNodeClient returns a new Client that will dial the specified node. If the // QueryOptions is set, its region will be used. func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error) { - node, _, err := c.Nodes().Info(nodeID, q) + return c.getNodeClientImpl(nodeID, q, c.Nodes().Info) +} + +// nodeLookup is used to lookup a node +type nodeLookup func(nodeID string, q *QueryOptions) (*Node, *QueryMeta, error) + +// getNodeClientImpl is the implementation of creating a API client for +// contacting a node. It is takes a function to lookup the node such that it can +// be mocked during tests. +func (c *Client) getNodeClientImpl(nodeID string, q *QueryOptions, lookup nodeLookup) (*Client, error) { + node, _, err := lookup(nodeID, q) if err != nil { return nil, err } @@ -316,6 +331,10 @@ func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error) region = q.Region } + if region == "" { + region = "global" + } + // Get an API client for the node conf := c.config.ClientConfig(region, node.HTTPAddr, node.TLSEnabled) return NewClient(conf) diff --git a/api/api_test.go b/api/api_test.go index 06eba83b3c9..ae79fee4075 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -2,6 +2,7 @@ package api import ( "encoding/json" + "fmt" "net/http" "net/http/httptest" "os" @@ -9,7 +10,9 @@ import ( "testing" "time" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" + "github.com/stretchr/testify/assert" ) type configCallback func(c *Config) @@ -243,3 +246,126 @@ func TestQueryString(t *testing.T) { t.Fatalf("bad uri: %q", uri) } } + +func TestClient_NodeClient(t *testing.T) { + http := "testdomain:4646" + tlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) { + return &Node{ + ID: structs.GenerateUUID(), + Status: "ready", + HTTPAddr: http, + TLSEnabled: true, + }, nil, nil + } + noTlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) { + return &Node{ + ID: structs.GenerateUUID(), + Status: "ready", + HTTPAddr: http, + TLSEnabled: false, + }, nil, nil + } + + optionNoRegion := &QueryOptions{} + optionRegion := &QueryOptions{ + Region: "foo", + } + + clientNoRegion, err := NewClient(DefaultConfig()) + assert.Nil(t, err) + + regionConfig := DefaultConfig() + regionConfig.Region = "bar" + clientRegion, err := NewClient(regionConfig) + assert.Nil(t, err) + + expectedTLSAddr := fmt.Sprintf("https://%s", http) + expectedNoTLSAddr := fmt.Sprintf("http://%s", http) + + cases := []struct { + Node nodeLookup + QueryOptions *QueryOptions + Client *Client + ExpectedAddr string + ExpectedRegion string + ExpectedTLSServerName string + }{ + { + Node: tlsNode, + QueryOptions: optionNoRegion, + Client: clientNoRegion, + ExpectedAddr: expectedTLSAddr, + ExpectedRegion: "global", + ExpectedTLSServerName: "client.global.nomad", + }, + { + Node: tlsNode, + QueryOptions: optionRegion, + Client: clientNoRegion, + ExpectedAddr: expectedTLSAddr, + ExpectedRegion: "foo", + ExpectedTLSServerName: "client.foo.nomad", + }, + { + Node: tlsNode, + QueryOptions: optionRegion, + Client: clientRegion, + ExpectedAddr: expectedTLSAddr, + ExpectedRegion: "foo", + ExpectedTLSServerName: "client.foo.nomad", + }, + { + Node: tlsNode, + QueryOptions: optionNoRegion, + Client: clientRegion, + ExpectedAddr: expectedTLSAddr, + ExpectedRegion: "bar", + ExpectedTLSServerName: "client.bar.nomad", + }, + { + Node: noTlsNode, + QueryOptions: optionNoRegion, + Client: clientNoRegion, + ExpectedAddr: expectedNoTLSAddr, + ExpectedRegion: "global", + ExpectedTLSServerName: "", + }, + { + Node: noTlsNode, + QueryOptions: optionRegion, + Client: clientNoRegion, + ExpectedAddr: expectedNoTLSAddr, + ExpectedRegion: "foo", + ExpectedTLSServerName: "", + }, + { + Node: noTlsNode, + QueryOptions: optionRegion, + Client: clientRegion, + ExpectedAddr: expectedNoTLSAddr, + ExpectedRegion: "foo", + ExpectedTLSServerName: "", + }, + { + Node: noTlsNode, + QueryOptions: optionNoRegion, + Client: clientRegion, + ExpectedAddr: expectedNoTLSAddr, + ExpectedRegion: "bar", + ExpectedTLSServerName: "", + }, + } + + for _, c := range cases { + name := fmt.Sprintf("%s__%s__%s", c.ExpectedAddr, c.ExpectedRegion, c.ExpectedTLSServerName) + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + nodeClient, err := c.Client.getNodeClientImpl("testID", c.QueryOptions, c.Node) + assert.Nil(err) + assert.Equal(c.ExpectedRegion, nodeClient.config.Region) + assert.Equal(c.ExpectedAddr, nodeClient.config.Address) + assert.NotNil(nodeClient.config.TLSConfig) + assert.Equal(c.ExpectedTLSServerName, nodeClient.config.TLSConfig.TLSServerName) + }) + } +}