Skip to content

Commit

Permalink
Update flake.lock (juanfont#2195)
Browse files Browse the repository at this point in the history
Flake lock file updates:

• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/e2f08f4d8b3ecb5cf5c9fd9cb2d53bb3c71807da?narHash=sha256-CAZF2NRuHmqTtRTNAruWpHA43Gg2UvuCNEIzabP0l6M%3D' (2024-10-05)
  → 'github:NixOS/nixpkgs/41dea55321e5a999b17033296ac05fe8a8b5a257?narHash=sha256-WvLXzNNnnw%2BqpFOmgaM3JUlNEH%2BT4s22b5i2oyyCpXE%3D' (2024-10-25)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

refresh token: high-level implementation

add implementation details

oidc: fix log in refreshJob

remove oauth.AccessTypeOffline from extras

fix logic
  • Loading branch information
github-actions[bot] authored and Ilya Zyabirov committed Nov 12, 2024
1 parent e2d5ee0 commit e8cfd01
Show file tree
Hide file tree
Showing 8 changed files with 302 additions and 30 deletions.
7 changes: 6 additions & 1 deletion docs/ref/oidc.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ oidc:

# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
scope: ["openid", "profile", "email", "custom"]
# Note that offline_access is enabled in order to issue a refresh token as well.
scope: ["openid", "profile", "email", "custom", "offline_access"]
# Optional: Passed on to the browser login request – used to tweak behaviour for the OIDC provider
extra_params:
domain_hint: example.com
Expand All @@ -50,6 +51,10 @@ oidc:
# If `strip_email_domain` is set to `false`, the domain part will NOT be removed, resulting in the following
# user: `first-name.last-name.example.com`
strip_email_domain: true

# Set how often ID tokens and node expiries are refreshed in the background using the corresponding refresh tokens.
# If set to '0', background refreshing is disabled.
force_refresh_period: 1h
```
## Azure AD example
Expand Down
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 44 additions & 0 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"os/signal"
"path/filepath"
"runtime"
"slices"
"strings"
"sync"
"syscall"
Expand Down Expand Up @@ -258,13 +259,48 @@ func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration)
if changed {
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")

if ap, ok := h.authProvider.(*AuthProviderOIDC); ok {
refreshed, err := ap.refreshIfPossible(ctx, nodesFromPatches(update.ChangePatches))
if err != nil {
log.Error().
Err(err).
Uints64("refreshed", nodeIDsAsUint64s(refreshed)).
Msg("error refreshing tokens")
}

// consider session alive if we could refresh token successfully
// no need to notify other nodes about expiry
cps := slices.DeleteFunc(update.ChangePatches, func(p *tailcfg.PeerChange) bool {
return slices.Contains(refreshed, types.NodeID(p.NodeID))
})

update.ChangePatches = cps
}

ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update)
}
}
}
}

// nodesFromPatches extracts the node ID from every peer change patch,
// converting each tailcfg.NodeID into the internal types.NodeID.
func nodesFromPatches(cps []*tailcfg.PeerChange) []types.NodeID {
	nodeIDs := make([]types.NodeID, 0, len(cps))
	for _, patch := range cps {
		nodeIDs = append(nodeIDs, types.NodeID(patch.NodeID))
	}

	return nodeIDs
}

// nodeIDsAsUint64s converts node IDs to their raw uint64 representation,
// e.g. for structured log fields.
func nodeIDsAsUint64s(ids []types.NodeID) []uint64 {
	raw := make([]uint64, 0, len(ids))
	for _, nodeID := range ids {
		raw = append(raw, nodeID.Uint64())
	}

	return raw
}

// scheduledDERPMapUpdateWorker refreshes the DERPMap stored on the global object
// at a set interval.
func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
Expand Down Expand Up @@ -546,6 +582,14 @@ func (h *Headscale) Serve() error {
defer expireNodeCancel()
go h.expireExpiredNodes(expireNodeCtx, updateInterval)

if ap, ok := h.authProvider.(*AuthProviderOIDC); ok {
// TODO: does it need its own ctx?
refreshJobCtx, refreshJobCancel := context.WithCancel(context.Background())
defer refreshJobCancel()

go ap.refreshJob(refreshJobCtx)
}

if zl.GlobalLevel() == zl.TraceLevel {
zerolog.RespLog = true
} else {
Expand Down
12 changes: 12 additions & 0 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,18 @@ func NewHeadscaleDatabase(
},
Rollback: func(db *gorm.DB) error { return nil },
},
{
ID: "202411191627",
Migrate: func(tx *gorm.DB) error {
err := tx.AutoMigrate(&types.RefreshToken{})
if err != nil {
return err
}

return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
},
)

Expand Down
44 changes: 44 additions & 0 deletions hscontrol/db/refresh_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package db

import (
"github.com/juanfont/headscale/hscontrol/types"
"gorm.io/gorm"
)

// GetRefreshTokens fetches the stored refresh tokens for the given nodes
// (all nodes when none are given) inside a read transaction.
func (hsdb *HSDatabase) GetRefreshTokens(nodeIDs ...types.NodeID) (map[types.NodeID]*types.RefreshToken, error) {
	read := func(tx *gorm.DB) (map[types.NodeID]*types.RefreshToken, error) {
		return GetRefreshTokens(tx, nodeIDs...)
	}

	return Read(hsdb.DB, read)
}

// GetRefreshTokens loads the stored OIDC refresh tokens for the given
// nodes, keyed by node ID. When no nodeIDs are passed, the tokens of all
// nodes are returned. Nodes without a stored token are simply absent
// from the map.
func GetRefreshTokens(tx *gorm.DB, nodeIDs ...types.NodeID) (map[types.NodeID]*types.RefreshToken, error) {
	// A nil slice is the idiomatic "empty" in Go; Find allocates as needed.
	var tokens []*types.RefreshToken

	// Restrict the query only when specific nodes were requested;
	// otherwise fetch every row.
	if len(nodeIDs) > 0 {
		tx = tx.Where("node_id IN ?", nodeIDs)
	}

	if err := tx.Find(&tokens).Error; err != nil {
		return nil, err
	}

	// Size the map from the actual result set to avoid rehashing.
	result := make(map[types.NodeID]*types.RefreshToken, len(tokens))
	for _, token := range tokens {
		result[token.NodeID] = token
	}

	return result, nil
}

// SaveRefreshToken persists the given refresh token inside a write
// transaction and returns it on success.
func (hsdb *HSDatabase) SaveRefreshToken(token *types.RefreshToken) (*types.RefreshToken, error) {
	write := func(tx *gorm.DB) (*types.RefreshToken, error) {
		return SaveRefreshToken(tx, token)
	}

	return Write(hsdb.DB, write)
}

// SaveRefreshToken upserts token within tx (gorm Save inserts or
// updates depending on the primary key) and returns it unchanged.
func SaveRefreshToken(tx *gorm.DB, token *types.RefreshToken) (*types.RefreshToken, error) {
	err := tx.Save(token).Error
	if err != nil {
		return nil, err
	}

	return token, nil
}
Loading

0 comments on commit e8cfd01

Please sign in to comment.