Skip to content

Commit

Permalink
Merge pull request #126 from unreality/main
Browse files Browse the repository at this point in the history
Initial work on OIDC (SSO) integration
  • Loading branch information
kradalby authored Oct 31, 2021
2 parents 7301d7e + 73d22cd commit fbdfa55
Show file tree
Hide file tree
Showing 13 changed files with 653 additions and 58 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ headscale implements this coordination server.
- [x] Taildrop (File Sharing)
- [x] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
- [x] DNS (passing DNS servers to nodes)
- [x] Single-Sign-On (via Open ID Connect)
- [x] Share nodes between namespaces
- [x] MagicDNS (see `docs/`)

Expand All @@ -49,7 +50,6 @@ headscale implements this coordination server.

Suggestions/PRs welcomed!


## Running headscale

Please have a look at the documentation under [`docs/`](docs/).
Expand Down
128 changes: 81 additions & 47 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -64,7 +65,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration").
Err(err).
Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc()
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Sad!")
return
}
Expand All @@ -75,45 +76,70 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration").
Err(err).
Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc()
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Very sad!")
return
}

now := time.Now().UTC()
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
m, err := h.GetMachineByMachineKey(mKey.HexString())
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
m = Machine{
Expiry: &req.Expiry,
MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname,
NodeKey: wgkey.Key(req.NodeKey).HexString(),
LastSuccessfulUpdate: &now,
newMachine := Machine{
Expiry: &time.Time{},
MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname,
}
if err := h.db.Create(&m).Error; err != nil {
if err := h.db.Create(&newMachine).Error; err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Could not create row")
machineRegistrations.WithLabelValues("unkown", "web", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc()
return
}
m = &newMachine
}

if !m.Registered && req.Auth.AuthKey != "" {
h.handleAuthKey(c, h.db, mKey, req, m)
h.handleAuthKey(c, h.db, mKey, req, *m)
return
}

resp := tailcfg.RegisterResponse{}

// We have the updated key!
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
if m.Registered {

// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
log.Info().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("Client requested logout")

m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
h.db.Save(&m)

resp.AuthURL = ""
resp.MachineAuthorized = false
resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
return
}
c.Data(200, "application/json; charset=utf-8", respBody)
return
}

if m.Registered && m.Expiry.UTC().After(now) {
// The machine registration is valid, respond with redirect to /map
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Expand All @@ -122,6 +148,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser()
resp.Login = *m.Namespace.toLogin()

respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil {
log.Error().
Expand All @@ -137,12 +165,30 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return
}

// The client has registered before, but has expired
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("Not registered and not NodeKey rotation. Sending a authurl to register")
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
h.cfg.ServerURL, mKey.HexString())
Msg("Machine registration has expired. Sending a authurl to register")

if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
}

// When a client connects, it may request a specific expiry time in its
// RegisterRequest (https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L634)
// RequestedExpiry is used to store the clients requested expiry time since the authentication flow is broken
// into two steps (which cant pass arbitrary data between them easily) and needs to be
// retrieved again after the user has authenticated. After the authentication flow
// completes, RequestedExpiry is copied into Expiry.
m.RequestedExpiry = &req.Expiry

h.db.Save(&m)

respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil {
log.Error().
Expand All @@ -158,8 +204,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return
}

// The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() {
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Expand All @@ -182,35 +228,23 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return
}

// We arrive here after a client is restarted without finalizing the authentication flow or
// when headscale is stopped in the middle of the auth process.
if m.Registered {
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
return
}
c.Data(200, "application/json; charset=utf-8", respBody)
return
}

// The machine registration is new, redirect the client to the registration URL
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("The node is sending us a new NodeKey, sending auth url")
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
h.cfg.ServerURL, mKey.HexString())
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
}

// save the requested expiry time for retrieval later in the authentication flow
m.RequestedExpiry = &req.Expiry
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
h.db.Save(&m)

respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil {
log.Error().
Expand Down
29 changes: 29 additions & 0 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import (
"sync"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"

"github.com/gin-gonic/gin"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
apiV1 "github.com/juanfont/headscale/gen/go/v1"
Expand Down Expand Up @@ -62,6 +66,18 @@ type Config struct {
ACMEEmail string

DNSConfig *tailcfg.DNSConfig

OIDC OIDCConfig

MaxMachineRegistrationDuration time.Duration
DefaultMachineRegistrationDuration time.Duration
}

type OIDCConfig struct {
Issuer string
ClientID string
ClientSecret string
MatchMap map[string]string
}

type DERPConfig struct {
Expand All @@ -87,6 +103,10 @@ type Headscale struct {
aclRules *[]tailcfg.FilterRule

lastStateChange sync.Map

oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
oidcStateCache *cache.Cache
}

// NewHeadscale returns the Headscale app.
Expand Down Expand Up @@ -127,6 +147,13 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, err
}

if cfg.OIDC.Issuer != "" {
err = h.initOIDC()
if err != nil {
return nil, err
}
}

if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
if err != nil {
Expand Down Expand Up @@ -255,6 +282,8 @@ func (h *Headscale) Serve() error {
r.GET("/register", h.RegisterWebAPI)
r.POST("/machine/:id/map", h.PollNetMapHandler)
r.POST("/machine/:id", h.RegistrationHandler)
r.GET("/oidc/register/:mkey", h.RegisterOIDC)
r.GET("/oidc/callback", h.OIDCCallback)
r.GET("/apple", h.AppleMobileConfig)
r.GET("/apple/:platform", h.ApplePlatformConfig)

Expand Down
3 changes: 3 additions & 0 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
return nil, errors.New("Machine not found")
}

h.updateMachineExpiry(&m) // update the machine's expiry before bailing if its already registered

if m.isAlreadyRegistered() {
return nil, errors.New("Machine already registered")
}
Expand All @@ -36,5 +38,6 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
m.Registered = true
m.RegisterMethod = "cli"
h.db.Save(&m)

return &m, nil
}
20 changes: 13 additions & 7 deletions cli_test.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
package headscale

import (
"time"

"gopkg.in/check.v1"
)

func (s *Suite) TestRegisterMachine(c *check.C) {
n, err := h.CreateNamespace("test")
c.Assert(err, check.IsNil)

now := time.Now().UTC()

m := Machine{
ID: 0,
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
IPAddress: "10.0.0.1",
ID: 0,
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
IPAddress: "10.0.0.1",
Expiry: &now,
RequestedExpiry: &now,
}
h.db.Save(&m)

Expand Down
Loading

0 comments on commit fbdfa55

Please sign in to comment.