Skip to content

Commit

Permalink
Add host aliasing to service discovery
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonc committed Jan 27, 2023
1 parent 4f49ac2 commit 8cde370
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Terraform svchost package
# terraform-svchost

[![CI Tests](https://github.com/hashicorp/terraform-svchost/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/hashicorp/terraform-svchost/actions/workflows/ci.yml)
[![GitHub license](https://img.shields.io/github/license/hashicorp/terraform-svchost.svg)](https://github.com/hashicorp/terraform-svchost/blob/main/LICENSE)
Expand Down
32 changes: 28 additions & 4 deletions disco/disco.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"mime"
"net/http"
"net/url"
"time"

"github.com/hashicorp/terraform-svchost"
svchost "github.com/hashicorp/terraform-svchost"
"github.com/hashicorp/terraform-svchost/auth"
)

Expand All @@ -42,6 +41,7 @@ var httpTransport = defaultHttpTransport()
// hostnames and caches the results by hostname to avoid repeated requests
// for the same information.
type Disco struct {
aliases map[svchost.Hostname]svchost.Hostname
hostCache map[svchost.Hostname]*Host
credsSrc auth.CredentialsSource

Expand Down Expand Up @@ -69,6 +69,7 @@ func New() *Disco {
// the given credentials source.
func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco {
return &Disco{
aliases: make(map[svchost.Hostname]svchost.Hostname),
hostCache: make(map[svchost.Hostname]*Host),
credsSrc: credsSrc,
Transport: httpTransport,
Expand Down Expand Up @@ -104,11 +105,15 @@ func (d *Disco) CredentialsSource() auth.CredentialsSource {
}

// CredentialsForHost returns a non-nil HostCredentials if the embedded source has
// credentials available for the host, and a nil HostCredentials if it does not.
// credentials available for the host, or host alias, and a nil HostCredentials if it does not.
func (d *Disco) CredentialsForHost(hostname svchost.Hostname) (auth.HostCredentials, error) {
if d.credsSrc == nil {
return nil, nil
}
if aliasedHost, aliasExists := d.aliases[hostname]; aliasExists {
log.Printf("[DEBUG] CredentialsForHost found alias %s for %s", hostname, aliasedHost)
hostname = aliasedHost
}
return d.credsSrc.ForHost(hostname)
}

Expand Down Expand Up @@ -139,6 +144,13 @@ func (d *Disco) ForceHostServices(hostname svchost.Hostname, services map[string
}
}

// Alias accepts an alias and target Hostname. When service discovery is performed
// or credentials are requested for the alias hostname, the target will be consulted instead.
func (d *Disco) Alias(alias, target svchost.Hostname) {
log.Printf("[DEBUG] Service discovery for %s aliased as %s", target, alias)
d.aliases[alias] = target
}

// Discover runs the discovery protocol against the given hostname (which must
// already have been validated and prepared with svchost.ForComparison) and
// returns an object describing the services available at that host.
Expand Down Expand Up @@ -176,6 +188,11 @@ func (d *Disco) DiscoverServiceURL(hostname svchost.Hostname, serviceID string)
// discover implements the actual discovery process, with its result cached
// by the public-facing Discover method.
func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) {
if aliasedHost, aliasExists := d.aliases[hostname]; aliasExists {
log.Printf("[DEBUG] Discover found alias %s for %s", hostname, aliasedHost)
hostname = aliasedHost
}

discoURL := &url.URL{
Scheme: "https",
Host: hostname.String(),
Expand Down Expand Up @@ -259,7 +276,7 @@ func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) {
// size, but we'll at least prevent reading the entire thing into memory.
lr := io.LimitReader(resp.Body, maxDiscoDocBytes)

servicesBytes, err := ioutil.ReadAll(lr)
servicesBytes, err := io.ReadAll(lr)
if err != nil {
return nil, fmt.Errorf("error reading discovery document body: %v", err)
}
Expand All @@ -284,3 +301,10 @@ func (d *Disco) Forget(hostname svchost.Hostname) {
func (d *Disco) ForgetAll() {
d.hostCache = make(map[svchost.Hostname]*Host)
}

// ForgetAlias removes a previously aliased hostname as well as its cached entry, if any exist.
// If the alias has no target then this is a no-op.
func (d *Disco) ForgetAlias(alias svchost.Hostname) {
delete(d.aliases, alias)
d.Forget(alias)
}
70 changes: 69 additions & 1 deletion disco/disco_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"testing"
"time"

"github.com/hashicorp/terraform-svchost"
svchost "github.com/hashicorp/terraform-svchost"
"github.com/hashicorp/terraform-svchost/auth"
)

Expand Down Expand Up @@ -362,6 +362,74 @@ func TestDiscover(t *testing.T) {
}

})

t.Run("alias", func(t *testing.T) {
// The server will listen on localhost and we will expect this response
// by requesting discovery on the alias.
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`
{
"thingy.v1": "http://example.com/foo"
}
`)
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
w.Write(resp)
})
defer close()

target, err := svchost.ForComparison("localhost" + portStr)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
alias, err := svchost.ForComparison("not-a-real-host-dont-even-try.no")
if err != nil {
t.Fatalf("alias hostname is invalid: %s", err)
}

d := New()
d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]any{
target: {
"token": "hunter2",
},
}))

d.Alias(alias, target)

discovered, err := d.Discover(alias)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}

gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil {
t.Fatalf("found no URL for thingy.v1")
}
if got, want := gotURL.String(), "http://example.com/foo"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}

aliasCreds, err := d.CredentialsForHost(alias)
if err != nil {
t.Fatalf("unexpected credentials error: %s", err)
}
if aliasCreds.Token() != "hunter2" {
t.Fatalf("found no credentials for alias")
}

d.ForgetAlias(alias)

discovered, err = d.Discover(alias)
if err == nil {
t.Error("expected error, got none")
}
if discovered != nil {
t.Error("expected discovered to be nil, got non-nil")
}
})
}

func testServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, close func()) {
Expand Down

0 comments on commit 8cde370

Please sign in to comment.