Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disco: OverrideHostDiscoveryURL #85

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 54 additions & 15 deletions disco/disco.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

const (
// Fixed path to the discovery manifest.
// Fixed path to a host's default discovery manifest.
discoPath = "/.well-known/terraform.json"

// Arbitrary-but-small number to prevent runaway redirect loops.
Expand All @@ -45,9 +45,10 @@ var httpTransport = defaultHTTPTransport()
// for the same information.
type Disco struct {
// must lock "mu" while interacting with these maps
aliases map[svchost.Hostname]svchost.Hostname
hostCache map[svchost.Hostname]*Host
mu sync.Mutex
aliases map[svchost.Hostname]svchost.Hostname
hostCache map[svchost.Hostname]*Host
urlOverride map[svchost.Hostname]*url.URL
mu sync.Mutex

credsSrc auth.CredentialsSource

Expand Down Expand Up @@ -136,24 +137,49 @@ func (d *Disco) CredentialsForHost(hostname svchost.Hostname) (auth.HostCredenti
// at the host's default discovery URL, though using absolute URLs is strongly
// recommended to make the configured behavior more explicit.
func (d *Disco) ForceHostServices(hostname svchost.Hostname, services map[string]interface{}) {
discoURL := d.discoveryURLForHost(hostname)
if services == nil {
services = map[string]interface{}{}
}

d.mu.Lock()
d.hostCache[hostname] = &Host{
discoURL: &url.URL{
Scheme: "https",
Host: string(hostname),
Path: discoPath,
},
discoURL: discoURL,
hostname: hostname.ForDisplay(),
services: services,
transport: d.Transport,
}
d.mu.Unlock()
}

// OverrideHostDiscoveryURL forces the use of the given URL as the discovery document location
// for the given hostname, overriding the default URL structure using the "https" scheme
// and the fixed path "/.well-known/terraform.json".
//
// Any future request for service discovery with that hostname will attempt to fetch
// service information from the given URL instead, and will use the results from that discovery
// as the service information for that hostname.
//
// The caller must not modify anything reachable through the given URL pointer after passing
// it to this function.
//
// If the same hostname is used with both this method and [Disco.ForceHostServices] then the
// latter "wins", because forcing service information for a particular host prevents making
// a service discovery request for that hostname over the network at all. However, any relative
// URLs in the metadata passed to ForceHostServices will be resolved relative to the overridden
// discovery URL instead of the default URL structure.
//
// All calls to this method should be made before performing any service discovery requests.
func (d *Disco) OverrideHostDiscoveryURL(hostname svchost.Hostname, discoveryURL *url.URL) {
d.mu.Lock()
if d.urlOverride == nil {
// Lazy allocation, because most Disco objects don't use URL overrides at all.
d.urlOverride = make(map[svchost.Hostname]*url.URL)
}
d.urlOverride[hostname] = discoveryURL
d.mu.Unlock()
}

// Alias accepts an alias and target Hostname. When service discovery is performed
// or credentials are requested for the alias hostname, the target will be consulted instead.
func (d *Disco) Alias(alias, target svchost.Hostname) {
Expand Down Expand Up @@ -225,12 +251,7 @@ func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) {
}
d.mu.Unlock()

discoURL := &url.URL{
Scheme: "https",
Host: hostname.String(),
Path: discoPath,
}

discoURL := d.discoveryURLForHost(hostname)
client := &http.Client{
Transport: d.Transport,
Timeout: discoTimeout,
Expand Down Expand Up @@ -323,6 +344,24 @@ func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) {
return host, nil
}

// discoveryURLForHost returns the URL to fetch to find the service discovery
// document (if any) relating to the given hostname.
func (d *Disco) discoveryURLForHost(hostname svchost.Hostname) *url.URL {
d.mu.Lock() // prevent concurrent access to d.urlOverride
defer d.mu.Unlock()
if override, ok := d.urlOverride[hostname]; ok {
return override
}
// Any hostname that doesn't have an override -- which is typically all of them --
// gets a systematically-generated discovery URL using the RFC8615 "well-known"
// path structure.
return &url.URL{
Scheme: "https",
Host: hostname.String(),
Path: discoPath,
}
}

// Forget invalidates any cached record of the given hostname. If the host
// has no cache entry then this is a no-op.
func (d *Disco) Forget(hostname svchost.Hostname) {
Expand Down
60 changes: 60 additions & 0 deletions disco/disco_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,66 @@ func TestDiscover(t *testing.T) {
t.Fatalf("wrong Authorization header\ngot: %s\nwant: %s", got, want)
}
})
t.Run("discovery document URL override", func(t *testing.T) {
portStr, cleanup := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`
{
"thingy.v1": "/foo",
"wotsit.v2": "http://example.net/bar"
}
`)
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
w.Write(resp)
})
defer cleanup()

host := svchost.Hostname("example.com")
overrideURL := &url.URL{
Scheme: "https",
Host: "localhost" + portStr,

// The following uses the well-known path just because that's what testServer
// expects. This is allowed to be any path that's valid as far as URL syntax is
// concerned.
Path: "/.well-known/terraform.json",
}
t.Logf("discovery URL for %s overridden as %s", host.ForDisplay(), overrideURL.String())

d := New()
d.OverrideHostDiscoveryURL(host, overrideURL)
discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}

{
gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil {
t.Fatalf("found no URL for thingy.v1")
}
// NOTE: relative URL intentionally resolved relative to the overridden
// discovery document URL, rather than the default document URL.
if got, want := gotURL.String(), "https://localhost"+portStr+"/foo"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
}
{
gotURL, err := discovered.ServiceURL("wotsit.v2")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil {
t.Fatalf("found no URL for wotsit.v2")
}
if got, want := gotURL.String(), "http://example.net/bar"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
}
})
t.Run("forced services override", func(t *testing.T) {
forced := map[string]interface{}{
"thingy.v1": "http://example.net/foo",
Expand Down