diff --git a/.github/workflows/docs-test.yml b/.github/workflows/docs-test.yml index b0e601312c..a2b1532451 100644 --- a/.github/workflows/docs-test.yml +++ b/.github/workflows/docs-test.yml @@ -13,11 +13,11 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - name: Install python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.x - name: Setup cache - uses: actions/cache@v2 + uses: actions/cache@v4 with: key: ${{ github.ref }} path: .cache diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index c5cddef74b..565841db24 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -18,11 +18,11 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - name: Install python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.x - name: Setup cache - uses: actions/cache@v2 + uses: actions/cache@v4 with: key: ${{ github.ref }} path: .cache diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index f7df0fea87..a43e543bd3 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -22,6 +22,7 @@ jobs: - TestAuthNodeApproval - TestOIDCAuthenticationPingAll - TestOIDCExpireNodesBasedOnTokenExpiry + - TestOIDC024UserCreation - TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowLogoutAndRelogin - TestUserCommand @@ -32,8 +33,7 @@ jobs: - TestPreAuthKeyCorrectUserLoggedInCommand - TestApiKeyCommand - TestNodeTagCommand - - TestNodeAdvertiseTagNoACLCommand - - TestNodeAdvertiseTagWithACLCommand + - TestNodeAdvertiseTagCommand - TestNodeCommand - TestNodeApproveCommand - TestNodeExpireCommand @@ -41,6 +41,7 @@ jobs: - TestNodeMoveCommand - TestPolicyCommand - TestPolicyBrokenConfigCommand + - TestDERPVerifyEndpoint - TestResolveMagicDNS - TestValidateResolvConf - TestDERPServerScenario diff --git a/.golangci.yaml b/.golangci.yaml index cd41a4dfc6..0df9a637aa 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -27,6 +27,7 @@ linters: - nolintlint - musttag # causes issues with imported libs - depguard + - exportloopref # We should strive to enable these: - wrapcheck @@ -56,9 +57,14 @@ linters-settings: - ok - c - tt + - tx + - rx gocritic: disabled-checks: - appendAssign # TODO(kradalby): Remove this - ifElseChain + + nlreturn: + block-size: 4 diff --git a/CHANGELOG.md b/CHANGELOG.md index 617a7d354e..555e254ad9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,16 +2,82 @@ ## Next +### Security fix: OIDC changes in Headscale 0.24.0 + +_Headscale v0.23.0 and earlier_ identified OIDC users by the "username" part of their email address (when `strip_email_domain: true`, the default) or whole email address (when `strip_email_domain: false`). + +Depending on how Headscale and your Identity Provider (IdP) were configured, only using the `email` claim could allow a malicious user with an IdP account to take over another Headscale user's account, even when `strip_email_domain: false`. + +This would also cause a user to lose access to their Headscale account if they changed their email address. + +_Headscale v0.24.0_ now identifies OIDC users by the `iss` and `sub` claims. [These are guaranteed by the OIDC specification to be stable and unique](https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability), even if a user changes email address. A well-designed IdP will typically set `sub` to an opaque identifier like a UUID or numeric ID, which has no relation to the user's name or email address. 
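+
+As a small illustration of why this matters, here is a Go sketch (the `Claims` type and its fields are assumptions for the example, not Headscale's actual types) showing that an identity key built from `iss` and `sub` survives an email change, while an email-derived username does not:
+
+```go
+package main
+
+import "fmt"
+
+// Claims is the subset of OIDC ID-token claims discussed above.
+// The type and field names are illustrative only.
+type Claims struct {
+	Iss   string // issuer URL, stable per IdP
+	Sub   string // opaque subject identifier, stable per user
+	Email string // may change over time; unsafe as a primary key
+}
+
+// identityKey derives a stable identity key from issuer and subject.
+func identityKey(c Claims) string {
+	return c.Iss + "/" + c.Sub
+}
+
+func main() {
+	before := Claims{Iss: "https://idp.example.com", Sub: "248289761001", Email: "sam@example.com"}
+	after := Claims{Iss: "https://idp.example.com", Sub: "248289761001", Email: "sam.smith@example.com"}
+
+	// The email (and any username derived from it) changed,
+	// but the identity key did not.
+	fmt.Println(identityKey(before) == identityKey(after)) // true
+}
+```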
+
+This issue _only_ affects Headscale installations which authenticate with OIDC.
+
+Headscale v0.24.0 and later will also automatically update profile fields with OIDC data on login. This means that users can change those details in your IdP, and have them propagate to Headscale automatically the next time they log in. However, this may affect the way you reference users in policies.
+
+#### Migrating existing installations
+
+Headscale v0.23.0 and earlier never recorded the `iss` and `sub` fields, so all legacy (existing) OIDC accounts _need to be migrated_ to be properly secured.
+
+Headscale v0.24.0 has an automatic migration feature, which is enabled by default (`map_legacy_users: true`). **This will be disabled by default in a future version of Headscale – any unmigrated users will get new accounts.**
+
+Headscale v0.24.0 will ignore any `email` claim if the IdP does not provide an `email_verified` claim set to `true`. [What "verified" actually means is contextually dependent](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) – Headscale uses it as a signal that the contents of the `email` claim are reasonably trustworthy.
+
+Headscale v0.23.0 and earlier never checked the `email_verified` claim. This means even if an IdP explicitly indicated to Headscale that its `email` claim was untrustworthy, Headscale would have still accepted it.
+
+##### What does automatic migration do?
+
+When automatic migration is enabled (`map_legacy_users: true`), Headscale will first match an OIDC account to a Headscale account by `iss` and `sub`, and then fall back to matching OIDC users similarly to how Headscale v0.23.0 did:
+
+- If `strip_email_domain: true` (the default): the Headscale username matches the "username" part of their email address.
+- If `strip_email_domain: false`: the Headscale username matches the _whole_ email address.
+
+On migration, Headscale will change the account's username to their `preferred_username`. **This could break any ACLs or policies which are configured to match by username.**
+
+As with Headscale v0.23.0 and earlier, this migration only works for users who haven't changed their email address since their last Headscale login.
+
+A _successful_ automated migration should otherwise be transparent to users.
+
+Once a Headscale account has been migrated, it will be _unavailable_ to be matched by the legacy process. An OIDC login with a matching username but _non-matching_ `iss` and `sub` will instead get a _new_ Headscale account.
+
+Because of the way OIDC works, Headscale's automated migration process can _only_ work when a user tries to log in after the update. Mass updates would require Headscale to implement a protocol like SCIM, which is **extremely** complicated and not available in all identity providers.
+
+Administrators could also attempt to migrate users manually by editing the database, using their own mapping rules with known-good data sources.
+
+Legacy account migration should have no effect on new installations where all users have a recorded `sub` and `iss`.
+
+##### What happens when automatic migration is disabled?
+
+When automatic migration is disabled (`map_legacy_users: false`), Headscale will only try to match an OIDC account to a Headscale account by `iss` and `sub`.
+
+If there is no match, it will get a _new_ Headscale account – even if there was a legacy account which _could_ have matched and migrated.
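+
+As a rough sketch of the lookup order described above (names and types here are hypothetical, assuming `map_legacy_users` simply toggles the legacy fallback; the real logic lives in Headscale's OIDC handler):
+
+```go
+// User is a simplified stand-in for Headscale's user record.
+type User struct {
+	Name       string // Headscale username
+	ProviderID string // recorded "iss/sub" pair; empty for unmigrated legacy accounts
+}
+
+// matchOIDCUser resolves an OIDC login to an existing user, or returns nil
+// when a new account should be created. username is the legacy-derived name
+// (subject to strip_email_domain); mapLegacy mirrors map_legacy_users.
+func matchOIDCUser(users []User, iss, sub, username string, mapLegacy bool) *User {
+	for i := range users {
+		if users[i].ProviderID == iss+"/"+sub {
+			return &users[i] // stable match always wins
+		}
+	}
+
+	if !mapLegacy {
+		return nil // no stable match: a new account is created
+	}
+
+	for i := range users {
+		// Accounts that have already been migrated (ProviderID set)
+		// are no longer eligible for legacy matching.
+		if users[i].ProviderID == "" && users[i].Name == username {
+			return &users[i] // legacy match: iss and sub are recorded now
+		}
+	}
+
+	return nil
+}
+```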
+ +We recommend new Headscale users explicitly disable automatic migration – but it should otherwise have no effect if every account has a recorded `iss` and `sub`. + +When automatic migration is disabled, the `strip_email_domain` setting will have no effect. + +Special thanks to @micolous for reviewing, proposing and working with us on these changes. + +#### Other OIDC changes + +Headscale now uses [the standard OIDC claims](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) to populate and update user information every time they log in: + +| Headscale profile field | OIDC claim | Notes / examples | +| ----------------------- | -------------------- | --------------------------------------------------------------------------------------------------------- | +| email address | `email` | Only used when `"email_verified": true` | +| display name | `name` | eg: `Sam Smith` | +| username | `preferred_username` | Varies depending on IdP and configuration, eg: `ssmith`, `ssmith@idp.example.com`, `\\example.com\ssmith` | +| profile picture | `picture` | URL to a profile picture or avatar | + +These should show up nicely in the Tailscale client. + +This will also affect the way you [reference users in policies](https://github.com/juanfont/headscale/pull/2205). + ### BREAKING - Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020) - Having usernames in magic DNS is no longer possible. -- Redo OpenID Connect configuration [#2020](https://github.com/juanfont/headscale/pull/2020) - - `strip_email_domain` has been removed, domain is _always_ part of the username for OIDC. - - Users are now identified by `sub` claim in the ID token instead of username, allowing the username, name and email to be updated. - - User has been extended to store username, display name, profile picture url and email. - - These fields are forwarded to the client, and shows up nicely in the user switcher. - - These fields can be made available via the API/CLI for non-OIDC users in the future. - Remove versions older than 1.56 [#2149](https://github.com/juanfont/headscale/pull/2149) - Clean up old code required by old versions @@ -23,6 +89,7 @@ - Added conversion of 'Hostname' to 'givenName' in a node with FQDN rules applied [#2198](https://github.com/juanfont/headscale/pull/2198) - Fixed updating of hostname and givenName when it is updated in HostInfo [#2199](https://github.com/juanfont/headscale/pull/2199) - Fixed missing `stable-debug` container tag [#2232](https://github.com/juanfont/headscale/pr/2232) +- Loosened up `server_url` and `base_domain` check. It was overly strict in some cases. 
- Added manual approval of nodes in the network [#2245](https://github.com/juanfont/headscale/pr/2245) ## 0.23.0 (2024-09-18) diff --git a/Dockerfile.derper b/Dockerfile.derper new file mode 100644 index 0000000000..62adc7cf0c --- /dev/null +++ b/Dockerfile.derper @@ -0,0 +1,19 @@ +# For testing purposes only + +FROM golang:alpine AS build-env + +WORKDIR /go/src + +RUN apk add --no-cache git +ARG VERSION_BRANCH=main +RUN git clone https://github.com/tailscale/tailscale.git --branch=$VERSION_BRANCH --depth=1 +WORKDIR /go/src/tailscale + +ARG TARGETARCH +RUN GOARCH=$TARGETARCH go install -v ./cmd/derper + +FROM alpine:3.18 +RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables curl + +COPY --from=build-env /go/bin/* /usr/local/bin/ +ENTRYPOINT [ "/usr/local/bin/derper" ] diff --git a/Dockerfile.debug b/Dockerfile.integration similarity index 100% rename from Dockerfile.debug rename to Dockerfile.integration diff --git a/Dockerfile.tailscale-HEAD b/Dockerfile.tailscale-HEAD index 92b0cae570..82f7a8d920 100644 --- a/Dockerfile.tailscale-HEAD +++ b/Dockerfile.tailscale-HEAD @@ -28,7 +28,9 @@ ARG VERSION_GIT_HASH="" ENV VERSION_GIT_HASH=$VERSION_GIT_HASH ARG TARGETARCH -RUN GOARCH=$TARGETARCH go install -ldflags="\ +ARG BUILD_TAGS="" + +RUN GOARCH=$TARGETARCH go install -tags="${BUILD_TAGS}" -ldflags="\ -X tailscale.com/version.longStamp=$VERSION_LONG \ -X tailscale.com/version.shortStamp=$VERSION_SHORT \ -X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \ diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 568a2a03e8..309ad67df2 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -1,8 +1,10 @@ package cli import ( + "encoding/json" "fmt" "net" + "net/http" "os" "strconv" "time" @@ -64,6 +66,19 @@ func mockOIDC() error { accessTTL = newTTL } + userStr := os.Getenv("MOCKOIDC_USERS") + if userStr == "" { + return fmt.Errorf("MOCKOIDC_USERS not defined") + } + + var users []mockoidc.MockUser + err := json.Unmarshal([]byte(userStr), &users) + if err != nil { + return fmt.Errorf("unmarshalling users: %w", err) + } + + log.Info().Interface("users", users).Msg("loading users from JSON") + log.Info().Msgf("Access token TTL: %s", accessTTL) port, err := strconv.Atoi(portStr) @@ -71,7 +86,7 @@ func mockOIDC() error { return err } - mock, err := getMockOIDC(clientID, clientSecret) + mock, err := getMockOIDC(clientID, clientSecret, users) if err != nil { return err } @@ -93,12 +108,18 @@ func mockOIDC() error { return nil } -func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) { +func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) { keypair, err := mockoidc.NewKeypair(nil) if err != nil { return nil, err } + userQueue := mockoidc.UserQueue{} + + for _, user := range users { + userQueue.Push(&user) + } + mock := mockoidc.MockOIDC{ ClientID: clientID, ClientSecret: clientSecret, @@ -107,9 +128,19 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro CodeChallengeMethodsSupported: []string{"plain", "S256"}, Keypair: keypair, SessionStore: mockoidc.NewSessionStore(), - UserQueue: &mockoidc.UserQueue{}, + UserQueue: &userQueue, ErrorQueue: &mockoidc.ErrorQueue{}, } + mock.AddMiddleware(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Info().Msgf("Request: %+v", r) + h.ServeHTTP(w, r) + if r.Response != nil { + log.Info().Msgf("Response: 
%+v", r.Response)
+			}
+		})
+	})
+
	return &mock, nil
}
diff --git a/config-example.yaml b/config-example.yaml
index 60b6f0245d..6cbbbd7437 100644
--- a/config-example.yaml
+++ b/config-example.yaml
@@ -168,6 +168,11 @@ database:
  #   https://www.sqlite.org/wal.html
  write_ahead_log: true

+  # Maximum number of WAL file frames before the WAL file is automatically checkpointed.
+  # https://www.sqlite.org/c3ref/wal_autocheckpoint.html
+  # Set to 0 to disable automatic checkpointing.
+  wal_autocheckpoint: 1000
+
# # Postgres config
# Please note that using Postgres is highly discouraged as it is only supported for legacy reasons.
# See database.type for more information.
@@ -364,12 +369,18 @@ unix_socket_permission: "0770"
# allowed_users:
#   - alice@example.com
#
-# # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
-# # This will transform `first-name.last-name@example.com` to the user `first-name.last-name`
-# # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following
-# user: `first-name.last-name.example.com`
-#
-# strip_email_domain: true
+# # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users
+# # by taking the username from the legacy user and matching it with the username
+# # provided by the OIDC provider. This is useful when migrating from legacy users
+# # to OIDC, to force them to use the unique identifier from the OIDC provider and
+# # to give them a proper display name and picture if available.
+# # Note that this will only work if the username from the legacy user is the same,
+# # and there is a possibility of account takeover should a username have changed
+# # with the provider.
+# # Disabling this feature will cause all new logins to be created as new users.
+# # Note that this option will be removed in the future and should be set to false
+# # on all new installations, or once all users have logged in with OIDC.
+# map_legacy_users: true

# Logtail configuration
# Logtail is Tailscales logging and auditing infrastructure, it allows the control panel
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 0c70d5fbf8..d375747b83 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,7 +1,5 @@
-cairosvg~=2.7.1
-mkdocs-include-markdown-plugin~=6.2.2
-mkdocs-macros-plugin~=1.2.0
-mkdocs-material~=9.5.18
-mkdocs-minify-plugin~=0.7.1
-mkdocs-redirects~=1.2.1
-pillow~=10.1.0
+mkdocs-include-markdown-plugin~=7.1
+mkdocs-macros-plugin~=1.3
+mkdocs-material[imaging]~=9.5
+mkdocs-minify-plugin~=0.7
+mkdocs-redirects~=1.2
diff --git a/flake.lock b/flake.lock
index 9a85828ef5..aaddd6a5b1 100644
--- a/flake.lock
+++ b/flake.lock
@@ -20,11 +20,11 @@
    },
    "nixpkgs": {
      "locked": {
-        "lastModified": 1731763621,
-        "narHash": "sha256-ddcX4lQL0X05AYkrkV2LMFgGdRvgap7Ho8kgon3iWZk=",
+        "lastModified": 1731890469,
+        "narHash": "sha256-D1FNZ70NmQEwNxpSSdTXCSklBH1z2isPR84J6DQrJGs=",
        "owner": "NixOS",
        "repo": "nixpkgs",
-        "rev": "c69a9bffbecde46b4b939465422ddc59493d3e4d",
+        "rev": "5083ec887760adfe12af64830a66807423a859a7",
        "type": "github"
      },
      "original": {
diff --git a/flake.nix b/flake.nix
index df1b7e120d..90a2aad825 100644
--- a/flake.nix
+++ b/flake.nix
@@ -32,7 +32,7 @@
        # When updating go.mod or go.sum, a new sha will need to be calculated,
        # update this if you have a mismatch after doing a change to thos files.
- vendorHash = "sha256-CMkYTRjmhvTTrB7JbLj0cj9VEyzpG0iUWXkaOagwYTk="; + vendorHash = "sha256-4VNiHUblvtcl9UetwiL6ZeVYb0h2e9zhYVsirhAkvOg="; subPackages = ["cmd/headscale"]; @@ -102,6 +102,7 @@ ko yq-go ripgrep + postgresql # 'dot' is needed for pprof graphs # go tool pprof -http=: diff --git a/go.mod b/go.mod index 7eac4652e2..8d51fc6a8c 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( gorm.io/gorm v1.25.11 tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 zgo.at/zcache/v2 v2.1.0 + zombiezen.com/go/postgrestest v1.0.1 ) require ( @@ -134,6 +135,7 @@ require ( github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/lithammer/fuzzysearch v1.1.8 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/go.sum b/go.sum index cc15ef6ce3..9315dbb6cf 100644 --- a/go.sum +++ b/go.sum @@ -311,6 +311,7 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= +github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4= @@ -731,3 +732,5 @@ tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 h1:nfRWV6ECxwNvvXKtbqSVs tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7/go.mod h1:xKxYf3B3PuezFlRaMT+VhuVu8XTFUTLy+VCzLPMJVmg= zgo.at/zcache/v2 v2.1.0 h1:USo+ubK+R4vtjw4viGzTe/zjXyPw6R7SK/RL3epBBxs= zgo.at/zcache/v2 v2.1.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk= +zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4= +zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ= diff --git a/hscontrol/app.go b/hscontrol/app.go index 3cf0f5d450..7090874e83 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -88,7 +88,8 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - ACLPolicy *policy.ACLPolicy + polManOnce sync.Once + polMan policy.PolicyManager mapper *mapper.Mapper nodeNotifier *notifier.Notifier @@ -154,6 +155,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } }) + if err = app.loadPolicyManager(); err != nil { + return nil, fmt.Errorf("failed to load ACL policy: %w", err) + } + var authProvider AuthProvider authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { @@ -166,6 +171,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app.db, app.nodeNotifier, app.ipAlloc, + app.polMan, ) if err != nil { if cfg.OIDC.OnlyStartIfOIDCIsAvailable { @@ -458,6 +464,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1). 
Methods(http.MethodGet) + router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost) + if h.cfg.DERP.ServerEnabled { router.HandleFunc("/derp", h.DERPServer.DERPHandler) router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler) @@ -474,6 +482,52 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { return router } +// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. +// Maybe we should attempt a new in memory state and not go via the DB? +func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + users, err := db.ListUsers() + if err != nil { + return err + } + + changed, err := polMan.SetUsers(users) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + +// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. +// Maybe we should attempt a new in memory state and not go via the DB? +func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + nodes, err := db.ListNodes() + if err != nil { + return err + } + + changed, err := polMan.SetNodes(nodes) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + // Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { if profilingEnabled { @@ -489,19 +543,13 @@ func (h *Headscale) Serve() error { } } - var err error - - if err = h.loadACLPolicy(); err != nil { - return fmt.Errorf("failed to load ACL policy: %w", err) - } - if dumpConfig { spew.Dump(h.cfg) } // Fetch an initial DERP Map before we start serving h.DERPMap = derp.GetDERPMap(h.cfg.DERP) - h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier) + h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan) if h.cfg.DERP.ServerEnabled { // When embedded DERP is enabled we always need a STUN server @@ -771,12 +819,21 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received SIGHUP, reloading ACL and Config") - // TODO(kradalby): Reload config on SIGHUP - if err := h.loadACLPolicy(); err != nil { - log.Error().Err(err).Msg("failed to reload ACL policy") + if err := h.loadPolicyManager(); err != nil { + log.Error().Err(err).Msg("failed to reload Policy") } - if h.ACLPolicy != nil { + pol, err := h.policyBytes() + if err != nil { + log.Error().Err(err).Msg("failed to get policy blob") + } + + changed, err := h.polMan.SetPolicy(pol) + if err != nil { + log.Error().Err(err).Msg("failed to set new policy") + } + + if changed { log.Info(). Msg("ACL policy successfully reloaded, notifying nodes of change") @@ -995,27 +1052,46 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { return &machineKey, nil } -func (h *Headscale) loadACLPolicy() error { - var ( - pol *policy.ACLPolicy - err error - ) - +// policyBytes returns the appropriate policy for the +// current configuration as a []byte array. +func (h *Headscale) policyBytes() ([]byte, error) { switch h.cfg.Policy.Mode { case types.PolicyModeFile: path := h.cfg.Policy.Path // It is fine to start headscale without a policy file. 
if len(path) == 0 { - return nil + return nil, nil } absPath := util.AbsolutePathFromConfigPath(path) - pol, err = policy.LoadACLPolicyFromPath(absPath) + policyFile, err := os.Open(absPath) + if err != nil { + return nil, err + } + defer policyFile.Close() + + return io.ReadAll(policyFile) + + case types.PolicyModeDB: + p, err := h.db.GetPolicy() if err != nil { - return fmt.Errorf("failed to load ACL policy from file: %w", err) + if errors.Is(err, types.ErrPolicyNotFound) { + return nil, nil + } + + return nil, err } + return []byte(p.Data), err + } + + return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode) +} + +func (h *Headscale) loadPolicyManager() error { + var errOut error + h.polManOnce.Do(func() { // Validate and reject configuration that would error when applied // when creating a map response. This requires nodes, so there is still // a scenario where they might be allowed if the server has no nodes @@ -1026,42 +1102,35 @@ func (h *Headscale) loadACLPolicy() error { // allowed to be written to the database. nodes, err := h.db.ListNodes() if err != nil { - return fmt.Errorf("loading nodes from database to validate policy: %w", err) + errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err) + return } - - _, err = pol.CompileFilterRules(nodes) + users, err := h.db.ListUsers() if err != nil { - return fmt.Errorf("verifying policy rules: %w", err) - } - - if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) - if err != nil { - return fmt.Errorf("verifying SSH rules: %w", err) - } + errOut = fmt.Errorf("loading users from database to validate policy: %w", err) + return } - case types.PolicyModeDB: - p, err := h.db.GetPolicy() + pol, err := h.policyBytes() if err != nil { - if errors.Is(err, types.ErrPolicyNotFound) { - return nil - } - - return fmt.Errorf("failed to get policy from database: %w", err) + errOut = fmt.Errorf("loading policy bytes: %w", err) + return } - pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data)) + h.polMan, err = policy.NewPolicyManager(pol, users, nodes) if err != nil { - return fmt.Errorf("failed to parse policy: %w", err) + errOut = fmt.Errorf("creating policy manager: %w", err) + return } - default: - log.Fatal(). - Str("mode", string(h.cfg.Policy.Mode)). - Msg("Unknown ACL policy mode") - } - h.ACLPolicy = pol + if len(nodes) > 0 { + _, err = h.polMan.SSHPolicy(nodes[0]) + if err != nil { + errOut = fmt.Errorf("verifying SSH rules: %w", err) + return + } + } + }) - return nil + return errOut } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index d25fbd6139..c981a8e31e 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -394,6 +394,13 @@ func (h *Headscale) handleAuthKey( return } + + err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier) + if err != nil { + http.Error(writer, "Internal server error", http.StatusInternalServerError) + return + } + } err = h.db.Write(func(tx *gorm.DB) error { diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index d61a3f5b48..de3869677a 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -479,6 +479,8 @@ func NewHeadscaleDatabase( Rollback: func(db *gorm.DB) error { return nil }, }, { + // Pick up new user fields used for OIDC and to + // populate the user with more interesting information. 
ID: "202407191627",
			Migrate: func(tx *gorm.DB) error {
				err := tx.AutoMigrate(&types.User{})
				if err != nil {
@@ -494,11 +496,11 @@
			ID: "202410071005",
			Migrate: func(db *gorm.DB) error {
				err = db.AutoMigrate(&types.PreAuthKey{})
-			if err != nil {
+				if err != nil {
					return err
				}

-			err = db.AutoMigrate(&types.Node{})
+				err = db.AutoMigrate(&types.Node{})
				if err != nil {
					return err
				}
@@ -529,6 +531,40 @@
				}

				return errNoNodeApprovedColumnInDatabase
+			},
+			Rollback: func(db *gorm.DB) error { return nil },
+		},
+		{
+			// The unique constraint of Name has been dropped
+			// in favour of a unique constraint on the combination
+			// of name and provider identity.
+			ID: "202408181235",
+			Migrate: func(tx *gorm.DB) error {
+				err := tx.AutoMigrate(&types.User{})
+				if err != nil {
+					return err
+				}
+
+				// Set up indexes and unique constraints outside of GORM,
+				// as it does not support conditional unique constraints.
+				// This ensures the following:
+				// - The combination of user name and provider_identifier is unique
+				// - A provider_identifier is unique
+				// - A user name is unique when no provider_identifier is set
+				for _, idx := range []string{
+					"DROP INDEX IF EXISTS idx_provider_identifier",
+					"DROP INDEX IF EXISTS idx_name_provider_identifier",
+					"CREATE UNIQUE INDEX IF NOT EXISTS idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL;",
+					"CREATE UNIQUE INDEX IF NOT EXISTS idx_name_provider_identifier ON users (name,provider_identifier);",
+					"CREATE UNIQUE INDEX IF NOT EXISTS idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL;",
+				} {
+					err = tx.Exec(idx).Error
+					if err != nil {
+						return fmt.Errorf("creating username index: %w", err)
+					}
+				}
+
+				return nil
			},
			Rollback: func(db *gorm.DB) error { return nil },
		},
@@ -591,10 +627,10 @@
	}

	if cfg.Sqlite.WriteAheadLog {
-		if err := db.Exec(`
+		if err := db.Exec(fmt.Sprintf(`
			PRAGMA journal_mode=WAL;
-			PRAGMA wal_autocheckpoint=0;
-			`).Error; err != nil {
+			PRAGMA wal_autocheckpoint=%d;
+			`, cfg.Sqlite.WALAutoCheckPoint)).Error; err != nil {
			return nil, fmt.Errorf("setting WAL mode: %w", err)
		}
	}
diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go
index 48a93d459f..8b7d3888a8 100644
--- a/hscontrol/db/db_test.go
+++ b/hscontrol/db/db_test.go
@@ -1,6 +1,7 @@
package db

import (
+	"database/sql"
	"fmt"
	"io"
	"net/netip"
@@ -8,6 +9,7 @@
	"path/filepath"
	"slices"
	"sort"
+	"strings"
	"testing"
	"time"

@@ -16,6 +18,7 @@
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
	"gorm.io/gorm"
	"zgo.at/zcache/v2"
)
@@ -44,7 +47,7 @@ func TestMigrations(t *testing.T) {
			routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
				return GetRoutes(rx)
			})
-			assert.NoError(t, err)
+			require.NoError(t, err)
			assert.Len(t, routes, 10)

			want := types.Routes{
@@ -70,7 +73,7 @@
			routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
				return GetRoutes(rx)
			})
-			assert.NoError(t, err)
+			require.NoError(t, err)
			assert.Len(t, routes, 4)

			want := types.Routes{
@@ -120,19 +123,19 @@
			dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite",
			wantFunc: func(t *testing.T, h *HSDatabase) {
				keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
-					kratest, err := ListPreAuthKeys(rx, "kratest")
+					kratest, err := 
ListPreAuthKeysByUser(rx, 1) // kratest if err != nil { return nil, err } - testkra, err := ListPreAuthKeys(rx, "testkra") + testkra, err := ListPreAuthKeysByUser(rx, 2) // testkra if err != nil { return nil, err } return append(kratest, testkra...), nil }) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, keys, 5) want := []types.PreAuthKey{ @@ -177,7 +180,7 @@ func TestMigrations(t *testing.T) { nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) { return ListNodes(rx) }) - assert.NoError(t, err) + require.NoError(t, err) for _, node := range nodes { assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey") @@ -261,3 +264,120 @@ func testCopyOfDatabase(src string) (string, error) { func emptyCache() *zcache.Cache[string, types.Node] { return zcache.New[string, types.Node](time.Minute, time.Hour) } + +// requireConstraintFailed checks if the error is a constraint failure with +// either SQLite and PostgreSQL error messages. +func requireConstraintFailed(t *testing.T, err error) { + t.Helper() + require.Error(t, err) + if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") { + require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error()) + } +} + +func TestConstraints(t *testing.T) { + tests := []struct { + name string + run func(*testing.T, *gorm.DB) + }{ + { + name: "no-duplicate-username-if-no-oidc", + run: func(t *testing.T, db *gorm.DB) { + _, err := CreateUser(db, "user1") + require.NoError(t, err) + _, err = CreateUser(db, "user1") + requireConstraintFailed(t, err) + }, + }, + { + name: "no-oidc-duplicate-username-and-id", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err := db.Save(&user).Error + require.NoError(t, err) + + user = types.User{ + Model: gorm.Model{ID: 2}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err = db.Save(&user).Error + requireConstraintFailed(t, err) + }, + }, + { + name: "no-oidc-duplicate-id", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err := db.Save(&user).Error + require.NoError(t, err) + + user = types.User{ + Model: gorm.Model{ID: 2}, + Name: "user1.1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err = db.Save(&user).Error + requireConstraintFailed(t, err) + }, + }, + { + name: "allow-duplicate-username-cli-then-oidc", + run: func(t *testing.T, db *gorm.DB) { + _, err := CreateUser(db, "user1") // Create CLI username + require.NoError(t, err) + + user := types.User{ + Name: "user1", + ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true}, + } + + err = db.Save(&user).Error + require.NoError(t, err) + }, + }, + { + name: "allow-duplicate-username-oidc-then-cli", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Name: "user1", + ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true}, + } + + err := db.Save(&user).Error + require.NoError(t, err) + + _, err = CreateUser(db, "user1") // Create CLI username + require.NoError(t, err) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name+"-postgres", func(t 
*testing.T) { + db := newPostgresTestDB(t) + tt.run(t, db.DB.Debug()) + }) + t.Run(tt.name+"-sqlite", func(t *testing.T) { + db, err := newSQLiteTestDB() + if err != nil { + t.Fatalf("creating database: %s", err) + } + + tt.run(t, db.DB.Debug()) + }) + + } +} diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index b9a7582327..0e5b6ad43a 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -12,6 +12,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "tailscale.com/net/tsaddr" "tailscale.com/types/ptr" ) @@ -457,7 +458,12 @@ func TestBackfillIPAddresses(t *testing.T) { t.Run(tt.name, func(t *testing.T) { db := tt.dbFunc() - alloc, err := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategySequential) + alloc, err := NewIPAllocator( + db, + tt.prefix4, + tt.prefix6, + types.IPAllocationStrategySequential, + ) if err != nil { t.Fatalf("failed to set up ip alloc: %s", err) } @@ -482,24 +488,29 @@ func TestBackfillIPAddresses(t *testing.T) { } func TestIPAllocatorNextNoReservedIPs(t *testing.T) { - alloc, err := NewIPAllocator(db, ptr.To(tsaddr.CGNATRange()), ptr.To(tsaddr.TailscaleULARange()), types.IPAllocationStrategySequential) + alloc, err := NewIPAllocator( + db, + ptr.To(tsaddr.CGNATRange()), + ptr.To(tsaddr.TailscaleULARange()), + types.IPAllocationStrategySequential, + ) if err != nil { t.Fatalf("failed to set up ip alloc: %s", err) } // Validate that we do not give out 100.100.100.100 nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange())) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, na("100.100.100.101"), *nextQuad100) // Validate that we do not give out fd7a:115c:a1e0::53 nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange())) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6) // Validate that we do not give out fd7a:115c:a1e0::53 nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange())) t.Logf("chrome: %s", nextChrome.String()) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, na("100.115.94.0"), *nextChrome) } diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 52030bfbd0..180a01a6db 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -92,15 +92,15 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { }) } -func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { +func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return getNode(rx, user, name) + return getNode(rx, uid, name) }) } // getNode finds a Node by name and user and returns the Node struct. 
-func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) { - nodes, err := ListNodesByUser(tx, user) +func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) { + nodes, err := ListNodesByUser(tx, uid) if err != nil { return nil, err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 005ff56e4c..2c61886f46 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -17,6 +17,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/puzpuzpuz/xsync/v3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/check.v1" "gorm.io/gorm" "tailscale.com/net/tsaddr" @@ -29,10 +30,10 @@ func (s *Suite) TestGetNode(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -50,7 +51,7 @@ func (s *Suite) TestGetNode(c *check.C) { trx := db.DB.Save(node) c.Assert(trx.Error, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) } @@ -58,7 +59,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -87,7 +88,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -135,7 +136,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]()) c.Assert(err, check.IsNil) - _, err = db.getNode(user.Name, "testnode3") + _, err = db.getNode(types.UserID(user.ID), "testnode3") c.Assert(err, check.NotNil) } @@ -143,7 +144,7 @@ func (s *Suite) TestListPeers(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -237,7 +238,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { for _, name := range []string{"test", "admin"} { user, err := db.CreateUser(name) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) stor = append(stor, base{user, pak}) } @@ -304,10 +306,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(testPeers), check.Equals, 9) - adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) + adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, 
check.IsNil) - testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) + testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, check.IsNil) peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) @@ -370,10 +372,10 @@ func (s *Suite) TestExpireNode(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -391,7 +393,7 @@ func (s *Suite) TestExpireNode(c *check.C) { } db.DB.Save(node) - nodeFromDB, err := db.getNode("test", "testnode") + nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB, check.NotNil) @@ -401,7 +403,7 @@ func (s *Suite) TestExpireNode(c *check.C) { err = db.NodeSetExpiry(nodeFromDB.ID, now) c.Assert(err, check.IsNil) - nodeFromDB, err = db.getNode("test", "testnode") + nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB.IsExpired(), check.Equals, true) @@ -411,10 +413,10 @@ func (s *Suite) TestSetTags(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -437,7 +439,7 @@ func (s *Suite) TestSetTags(c *check.C) { sTags := []string{"tag:test", "tag:foo"} err = db.SetTags(node.ID, sTags) c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") + node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(node.ForcedTags, check.DeepEquals, sTags) @@ -445,7 +447,7 @@ func (s *Suite) TestSetTags(c *check.C) { eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} err = db.SetTags(node.ID, eTags) c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") + node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert( node.ForcedTags, @@ -456,7 +458,7 @@ func (s *Suite) TestSetTags(c *check.C) { // test removing tags err = db.SetTags(node.ID, []string{}) c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") + node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(node.ForcedTags, check.DeepEquals, []string{}) } @@ -646,18 +648,18 @@ func TestAutoApproveRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - adb, err := newTestDB() - assert.NoError(t, err) + adb, err := newSQLiteTestDB() + require.NoError(t, err) pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl)) - assert.NoError(t, err) - assert.NotNil(t, pol) + require.NoError(t, err) + require.NotNil(t, pol) user, err := adb.CreateUser("test") - assert.NoError(t, err) + require.NoError(t, err) - pak, err := adb.CreatePreAuthKey(user.Name, false, false, false, nil, nil) - assert.NoError(t, err) + pak, err := 
adb.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) + require.NoError(t, err) nodeKey := key.NewNode() machineKey := key.NewMachine() @@ -679,21 +681,30 @@ func TestAutoApproveRoutes(t *testing.T) { } trx := adb.DB.Save(&node) - assert.NoError(t, trx.Error) + require.NoError(t, trx.Error) sendUpdate, err := adb.SaveNodeRoutes(&node) - assert.NoError(t, err) + require.NoError(t, err) assert.False(t, sendUpdate) node0ByID, err := adb.GetNodeByID(0) + require.NoError(t, err) + + users, err := adb.ListUsers() assert.NoError(t, err) - // TODO(kradalby): Check state update - err = adb.EnableAutoApprovedRoutes(pol, node0ByID) + nodes, err := adb.ListNodes() assert.NoError(t, err) - enabledRoutes, err := adb.GetEnabledRoutes(node0ByID) + pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes) assert.NoError(t, err) + + // TODO(kradalby): Check state update + err = adb.EnableAutoApprovedRoutes(pm, node0ByID) + require.NoError(t, err) + + enabledRoutes, err := adb.GetEnabledRoutes(node0ByID) + require.NoError(t, err) assert.Len(t, enabledRoutes, len(tt.want)) tsaddr.SortPrefixes(enabledRoutes) @@ -780,19 +791,19 @@ func generateRandomNumber(t *testing.T, max int64) int64 { } func TestListEphemeralNodes(t *testing.T) { - db, err := newTestDB() + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating db: %s", err) } user, err := db.CreateUser("test") - assert.NoError(t, err) + require.NoError(t, err) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) - assert.NoError(t, err) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) + require.NoError(t, err) - pakEph, err := db.CreatePreAuthKey(user.Name, false, false, true, nil, nil) - assert.NoError(t, err) + pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, true, nil, nil) + require.NoError(t, err) node := types.Node{ ID: 0, @@ -815,16 +826,16 @@ func TestListEphemeralNodes(t *testing.T) { } err = db.DB.Save(&node).Error - assert.NoError(t, err) + require.NoError(t, err) err = db.DB.Save(&nodeEph).Error - assert.NoError(t, err) + require.NoError(t, err) nodes, err := db.ListNodes() - assert.NoError(t, err) + require.NoError(t, err) ephemeralNodes, err := db.ListEphemeralNodes() - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, nodes, 2) assert.Len(t, ephemeralNodes, 1) @@ -836,16 +847,16 @@ func TestListEphemeralNodes(t *testing.T) { } func TestRenameNode(t *testing.T) { - db, err := newTestDB() + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating db: %s", err) } user, err := db.CreateUser("test") - assert.NoError(t, err) + require.NoError(t, err) user2, err := db.CreateUser("test2") - assert.NoError(t, err) + require.NoError(t, err) node := types.Node{ ID: 0, @@ -866,10 +877,10 @@ func TestRenameNode(t *testing.T) { } err = db.DB.Save(&node).Error - assert.NoError(t, err) + require.NoError(t, err) err = db.DB.Save(&node2).Error - assert.NoError(t, err) + require.NoError(t, err) err = db.DB.Transaction(func(tx *gorm.DB) error { _, err := RegisterNode(tx, node, nil, nil) @@ -879,10 +890,10 @@ func TestRenameNode(t *testing.T) { _, err = RegisterNode(tx, node2, nil, nil) return err }) - assert.NoError(t, err) + require.NoError(t, err) nodes, err := db.ListNodes() - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, nodes, 2) @@ -904,26 +915,26 @@ func TestRenameNode(t *testing.T) { err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[0].ID, "newname") }) - assert.NoError(t, err) + 
require.NoError(t, err) nodes, err = db.ListNodes() - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, nodes, 2) - assert.Equal(t, nodes[0].Hostname, "test") - assert.Equal(t, nodes[0].GivenName, "newname") + assert.Equal(t, "test", nodes[0].Hostname) + assert.Equal(t, "newname", nodes[0].GivenName) // Nodes can reuse name that is no longer used err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[1].ID, "test") }) - assert.NoError(t, err) + require.NoError(t, err) nodes, err = db.ListNodes() - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, nodes, 2) - assert.Equal(t, nodes[0].Hostname, "test") - assert.Equal(t, nodes[0].GivenName, "newname") - assert.Equal(t, nodes[1].GivenName, "test") + assert.Equal(t, "test", nodes[0].Hostname) + assert.Equal(t, "newname", nodes[0].GivenName) + assert.Equal(t, "test", nodes[1].GivenName) // Nodes cannot be renamed to used names err = db.Write(func(tx *gorm.DB) error { diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 26a2131b07..e24bd86e08 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -23,8 +23,7 @@ var ( ) func (hsdb *HSDatabase) CreatePreAuthKey( - // TODO(kradalby): Should be ID, not name - userName string, + uid types.UserID, reusable bool, preApproved bool, ephemeral bool, @@ -32,22 +31,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey( aclTags []string, ) (*types.PreAuthKey, error) { return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { - return CreatePreAuthKey(tx, userName, reusable, preApproved, ephemeral, expiration, aclTags) + return CreatePreAuthKey(tx, uid, reusable, preApproved, ephemeral, expiration, aclTags) }) } // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func CreatePreAuthKey( tx *gorm.DB, - // TODO(kradalby): Should be ID, not name - userName string, + uid types.UserID, reusable bool, preApproved bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { - user, err := GetUserByUsername(tx, userName) + user, err := GetUserByID(tx, uid) if err != nil { return nil, err } @@ -92,15 +90,15 @@ func CreatePreAuthKey( return &key, nil } -func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { +func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { - return ListPreAuthKeys(rx, userName) + return ListPreAuthKeysByUser(rx, uid) }) } -// ListPreAuthKeys returns the list of PreAuthKeys for a user. -func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) { - user, err := GetUserByUsername(tx, userName) +// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user. 
+func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) { + user, err := GetUserByID(tx, uid) if err != nil { return nil, err } diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index e655186765..2884cb6cd5 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -11,14 +11,14 @@ import ( ) func (*Suite) TestCreatePreAuthKey(c *check.C) { - _, err := db.CreatePreAuthKey("bogus", true, false, false, nil, nil) - + // ID does not exist + _, err := db.CreatePreAuthKey(12345, true, false, false, nil, nil) c.Assert(err, check.NotNil) user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - key, err := db.CreatePreAuthKey(user.Name, true, false, false, nil, nil) + key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, false, nil, nil) c.Assert(err, check.IsNil) // Did we get a valid key? @@ -26,17 +26,18 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { c.Assert(len(key.Key), check.Equals, 48) // Make sure the User association is populated - c.Assert(key.User.Name, check.Equals, user.Name) + c.Assert(key.User.ID, check.Equals, user.ID) - _, err = db.ListPreAuthKeys("bogus") + // ID does not exist + _, err = db.ListPreAuthKeys(1000000) c.Assert(err, check.NotNil) - keys, err := db.ListPreAuthKeys(user.Name) + keys, err := db.ListPreAuthKeys(types.UserID(user.ID)) c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) // Make sure the User association is populated - c.Assert((keys)[0].User.Name, check.Equals, user.Name) + c.Assert((keys)[0].User.ID, check.Equals, user.ID) } func (*Suite) TestExpiredPreAuthKey(c *check.C) { @@ -44,7 +45,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) { c.Assert(err, check.IsNil) now := time.Now().Add(-5 * time.Second) - pak, err := db.CreatePreAuthKey(user.Name, true, false, false, &now, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, false, &now, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -72,7 +73,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, false, nil, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -84,7 +85,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { user, err := db.CreateUser("test4") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -106,7 +107,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { user, err := db.CreateUser("test5") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -128,7 +129,7 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -140,7 +141,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) { user, err := 
db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, false, nil, nil) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.IsNil) @@ -157,7 +158,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true db.DB.Save(&pak) @@ -170,15 +171,15 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) { user, err := db.CreateUser("test8") c.Assert(err, check.IsNil) - _, err = db.CreatePreAuthKey(user.Name, false, false, false, nil, []string{"badtag"}) + _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, []string{"badtag"}) c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected tags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = db.CreatePreAuthKey(user.Name, false, false, false, nil, tagsWithDuplicate) + _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, tagsWithDuplicate) c.Assert(err, check.IsNil) - listedPaks, err := db.ListPreAuthKeys("test8") + listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) c.Assert(err, check.IsNil) gotTags := listedPaks[0].Proto().GetAclTags() sort.Sort(sort.StringSlice(gotTags)) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 086261aa78..0a72c4278e 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -598,18 +598,18 @@ func failoverRoute( } func (hsdb *HSDatabase) EnableAutoApprovedRoutes( - aclPolicy *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, ) error { return hsdb.Write(func(tx *gorm.DB) error { - return EnableAutoApprovedRoutes(tx, aclPolicy, node) + return EnableAutoApprovedRoutes(tx, polMan, node) }) } // EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. func EnableAutoApprovedRoutes( tx *gorm.DB, - aclPolicy *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, ) error { if node.IPv4 == nil && node.IPv6 == nil { @@ -630,16 +630,11 @@ func EnableAutoApprovedRoutes( continue } - routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( - netip.Prefix(advertisedRoute.Prefix), - ) - if err != nil { - return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err) - } + routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix)) log.Trace(). Str("node", node.Hostname). - Str("user", node.User.Name). + Uint("user.id", node.User.ID). Strs("routeApprovers", routeApprovers). Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()). 
Msg("looking up route for autoapproving") @@ -649,7 +644,7 @@ func EnableAutoApprovedRoutes( approvedRoutes = append(approvedRoutes, advertisedRoute) } else { // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) + approvedIps, err := polMan.ExpandAlias(approvedAlias) if err != nil { return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index ccb9799c06..412a256ebe 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -35,10 +35,10 @@ func (s *Suite) TestGetRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_get_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_get_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix("10.0.0.0/24") @@ -79,10 +79,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_enable_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -153,10 +153,10 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_enable_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -234,10 +234,10 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_enable_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 237ae290c6..cfced36ea3 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -1,12 +1,17 @@ package db import ( + "context" "log" + "net/url" "os" + "strconv" + "strings" "testing" "github.com/juanfont/headscale/hscontrol/types" "gopkg.in/check.v1" + "zombiezen.com/go/postgrestest" ) func Test(t *testing.T) { @@ -36,13 +41,15 @@ func (s *Suite) ResetDB(c *check.C) { // } var err error - db, err = newTestDB() + db, err = newSQLiteTestDB() if err != nil { c.Fatal(err) } } -func newTestDB() (*HSDatabase, error) { +// TODO(kradalby): make this a t.Helper when we dont depend +// on check test framework. 
+func newSQLiteTestDB() (*HSDatabase, error) { var err error tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") if err != nil { @@ -53,7 +60,7 @@ func newTestDB() (*HSDatabase, error) { db, err = NewHeadscaleDatabase( types.DatabaseConfig{ - Type: "sqlite3", + Type: types.DatabaseSqlite, Sqlite: types.SqliteConfig{ Path: tmpDir + "/headscale_test.db", }, @@ -68,3 +75,53 @@ func newTestDB() (*HSDatabase, error) { return db, nil } + +func newPostgresTestDB(t *testing.T) *HSDatabase { + t.Helper() + + var err error + tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") + if err != nil { + t.Fatal(err) + } + + log.Printf("database path: %s", tmpDir+"/headscale_test.db") + + ctx := context.Background() + srv, err := postgrestest.Start(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(srv.Cleanup) + + u, err := srv.CreateDatabase(ctx) + if err != nil { + t.Fatal(err) + } + t.Logf("created local postgres: %s", u) + pu, _ := url.Parse(u) + + pass, _ := pu.User.Password() + port, _ := strconv.Atoi(pu.Port()) + + db, err = NewHeadscaleDatabase( + types.DatabaseConfig{ + Type: types.DatabasePostgres, + Postgres: types.PostgresConfig{ + Host: pu.Hostname(), + User: pu.User.Username(), + Name: strings.TrimLeft(pu.Path, "/"), + Pass: pass, + Port: port, + Ssl: "disable", + }, + }, + "", + emptyCache(), + ) + if err != nil { + t.Fatal(err) + } + + return db +} diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 135276c76b..0eaa9ea348 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -28,11 +28,9 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) { if err != nil { return nil, err } - user := types.User{} - if err := tx.Where("name = ?", name).First(&user).Error; err == nil { - return nil, ErrUserExists + user := types.User{ + Name: name, } - user.Name = name if err := tx.Create(&user).Error; err != nil { return nil, fmt.Errorf("creating user: %w", err) } @@ -40,21 +38,21 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) { return &user, nil } -func (hsdb *HSDatabase) DestroyUser(name string) error { +func (hsdb *HSDatabase) DestroyUser(uid types.UserID) error { return hsdb.Write(func(tx *gorm.DB) error { - return DestroyUser(tx, name) + return DestroyUser(tx, uid) }) } // DestroyUser destroys a User. Returns error if the User does // not exist or if there are nodes associated with it. -func DestroyUser(tx *gorm.DB, name string) error { - user, err := GetUserByUsername(tx, name) +func DestroyUser(tx *gorm.DB, uid types.UserID) error { + user, err := GetUserByID(tx, uid) if err != nil { - return ErrUserNotFound + return err } - nodes, err := ListNodesByUser(tx, name) + nodes, err := ListNodesByUser(tx, uid) if err != nil { return err } @@ -62,7 +60,7 @@ func DestroyUser(tx *gorm.DB, name string) error { return ErrUserStillHasNodes } - keys, err := ListPreAuthKeys(tx, name) + keys, err := ListPreAuthKeysByUser(tx, uid) if err != nil { return err } @@ -80,17 +78,17 @@ func DestroyUser(tx *gorm.DB, name string) error { return nil } -func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { +func (hsdb *HSDatabase) RenameUser(uid types.UserID, newName string) error { return hsdb.Write(func(tx *gorm.DB) error { - return RenameUser(tx, oldName, newName) + return RenameUser(tx, uid, newName) }) } // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. 
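The `newPostgresTestDB` helper added above lets the data-layer tests run against a throwaway Postgres server instead of SQLite. A hypothetical caller might look like this (sketch only; assumes the environment can start Postgres via `postgrestest`):

```go
package db // hypothetical test file in this package, not part of the diff

import "testing"

func TestCreateUserPostgres(t *testing.T) {
	if testing.Short() {
		t.Skip("postgrestest starts a real postgres server")
	}

	// newPostgresTestDB (above) starts a throwaway server via
	// zombiezen.com/go/postgrestest and registers cleanup with t.Cleanup.
	db := newPostgresTestDB(t)

	user, err := db.CreateUser("alice") // invented name
	if err != nil {
		t.Fatal(err)
	}
	if user.Name != "alice" {
		t.Fatalf("want user %q, got %q", "alice", user.Name)
	}
}
```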
-func RenameUser(tx *gorm.DB, oldName, newName string) error { +func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { var err error - oldUser, err := GetUserByUsername(tx, oldName) + oldUser, err := GetUserByID(tx, uid) if err != nil { return err } @@ -98,50 +96,25 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error { if err != nil { return err } - _, err = GetUserByUsername(tx, newName) - if err == nil { - return ErrUserExists - } - if !errors.Is(err, ErrUserNotFound) { - return err - } oldUser.Name = newName - if result := tx.Save(&oldUser); result.Error != nil { - return result.Error + if err := tx.Save(&oldUser).Error; err != nil { + return err } return nil } -func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { +func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { - return GetUserByUsername(rx, name) + return GetUserByID(rx, uid) }) } -func GetUserByUsername(tx *gorm.DB, name string) (*types.User, error) { +func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) { user := types.User{} - if result := tx.First(&user, "name = ?", name); errors.Is( - result.Error, - gorm.ErrRecordNotFound, - ) { - return nil, ErrUserNotFound - } - - return &user, nil -} - -func (hsdb *HSDatabase) GetUserByID(id types.UserID) (*types.User, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { - return GetUserByID(rx, id) - }) -} - -func GetUserByID(tx *gorm.DB, id types.UserID) (*types.User, error) { - user := types.User{} - if result := tx.First(&user, "id = ?", id); errors.Is( + if result := tx.First(&user, "id = ?", uid); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -169,54 +142,69 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) { return &user, nil } -func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { +func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) { - return ListUsers(rx) + return ListUsers(rx, where...) }) } // ListUsers gets all the existing users. -func ListUsers(tx *gorm.DB) ([]types.User, error) { +func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { + if len(where) > 1 { + return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where)) + } + + var user *types.User + if len(where) == 1 { + user = where[0] + } + users := []types.User{} - if err := tx.Find(&users).Error; err != nil { + if err := tx.Where(user).Find(&users).Error; err != nil { return nil, err } return users, nil } -// ListNodesByUser gets all the nodes in a given user. -func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) { - err := util.CheckForFQDNRules(name) +// GetUserByName returns a user if the provided username is +// unique, and otherwise an error. +func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { + users, err := hsdb.ListUsers(&types.User{Name: name}) if err != nil { return nil, err } - user, err := GetUserByUsername(tx, name) - if err != nil { - return nil, err + + if len(users) == 0 { + return nil, ErrUserNotFound } + if len(users) != 1 { + return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) + } + + return &users[0], nil +} + +// ListNodesByUser gets all the nodes in a given user. 
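`ListUsers` now takes an optional `*types.User` that acts as a GORM where-template, and `GetUserByName` is rebuilt on top of it, failing unless exactly one user matches. A sketch of the intended semantics (illustrative helper and name, not part of the patch):

```go
package db // illustrative sketch, not part of the diff

import "github.com/juanfont/headscale/hscontrol/types"

// findUserCalledAlice demonstrates the optional where-struct; the helper
// name and "alice" are invented for the example.
func findUserCalledAlice(hsdb *HSDatabase) (*types.User, error) {
	// Without an argument every user is returned.
	all, err := hsdb.ListUsers()
	if err != nil {
		return nil, err
	}
	_ = all

	// With a template struct only matching rows come back; usernames are no
	// longer guaranteed unique, so zero, one or many users may be returned.
	users, err := hsdb.ListUsers(&types.User{Name: "alice"})
	if err != nil {
		return nil, err
	}
	if len(users) != 1 {
		// GetUserByName wraps exactly this check, distinguishing
		// "not found" from "ambiguous".
		return nil, ErrUserNotFound
	}

	return &users[0], nil
}
```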
+func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) { nodes := types.Nodes{} - if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { + if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil { return nil, err } return nodes, nil } -func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error { +func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, uid types.UserID) error { return hsdb.Write(func(tx *gorm.DB) error { - return AssignNodeToUser(tx, node, username) + return AssignNodeToUser(tx, node, uid) }) } // AssignNodeToUser assigns a Node to a user. -func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error { - err := util.CheckForFQDNRules(username) - if err != nil { - return err - } - user, err := GetUserByUsername(tx, username) +func AssignNodeToUser(tx *gorm.DB, node *types.Node, uid types.UserID) error { + user, err := GetUserByID(tx, uid) if err != nil { return err } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 18edf89d56..e34111bdcb 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -1,6 +1,8 @@ package db import ( + "strings" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" @@ -17,24 +19,24 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(users), check.Equals, 1) - err = db.DestroyUser("test") + err = db.DestroyUser(types.UserID(user.ID)) c.Assert(err, check.IsNil) - _, err = db.GetUserByName("test") + _, err = db.GetUserByID(types.UserID(user.ID)) c.Assert(err, check.NotNil) } func (s *Suite) TestDestroyUserErrors(c *check.C) { - err := db.DestroyUser("test") + err := db.DestroyUser(9998) c.Assert(err, check.Equals, ErrUserNotFound) user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) - err = db.DestroyUser("test") + err = db.DestroyUser(types.UserID(user.ID)) c.Assert(err, check.IsNil) result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key) @@ -44,7 +46,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { user, err = db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err = db.CreatePreAuthKey(user.Name, false, false, false, nil, nil) + pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -57,7 +59,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) - err = db.DestroyUser("test") + err = db.DestroyUser(types.UserID(user.ID)) c.Assert(err, check.Equals, ErrUserStillHasNodes) } @@ -70,24 +72,29 @@ func (s *Suite) TestRenameUser(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(users), check.Equals, 1) - err = db.RenameUser("test", "test-renamed") + err = db.RenameUser(types.UserID(userTest.ID), "test-renamed") c.Assert(err, check.IsNil) - _, err = db.GetUserByName("test") - c.Assert(err, check.Equals, ErrUserNotFound) + users, err = db.ListUsers(&types.User{Name: "test"}) + c.Assert(err, check.Equals, nil) + c.Assert(len(users), check.Equals, 0) - _, err = db.GetUserByName("test-renamed") + users, err = 
db.ListUsers(&types.User{Name: "test-renamed"}) c.Assert(err, check.IsNil) + c.Assert(len(users), check.Equals, 1) - err = db.RenameUser("test-does-not-exit", "test") + err = db.RenameUser(99988, "test") c.Assert(err, check.Equals, ErrUserNotFound) userTest2, err := db.CreateUser("test2") c.Assert(err, check.IsNil) c.Assert(userTest2.Name, check.Equals, "test2") - err = db.RenameUser("test2", "test-renamed") - c.Assert(err, check.Equals, ErrUserExists) + want := "UNIQUE constraint failed" + err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed") + if err == nil || !strings.Contains(err.Error(), want) { + c.Fatalf("expected failure with unique constraint, want: %q got: %q", want, err) + } } func (s *Suite) TestSetMachineUser(c *check.C) { @@ -97,7 +104,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) { newUser, err := db.CreateUser("new") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -111,15 +118,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) { c.Assert(trx.Error, check.IsNil) c.Assert(node.UserID, check.Equals, oldUser.ID) - err = db.AssignNodeToUser(&node, newUser.Name) + err = db.AssignNodeToUser(&node, types.UserID(newUser.ID)) c.Assert(err, check.IsNil) c.Assert(node.UserID, check.Equals, newUser.ID) c.Assert(node.User.Name, check.Equals, newUser.Name) - err = db.AssignNodeToUser(&node, "non-existing-user") + err = db.AssignNodeToUser(&node, 9584849) c.Assert(err, check.Equals, ErrUserNotFound) - err = db.AssignNodeToUser(&node, newUser.Name) + err = db.AssignNodeToUser(&node, types.UserID(newUser.ID)) c.Assert(err, check.IsNil) c.Assert(node.UserID, check.Equals, newUser.ID) c.Assert(node.User.Name, check.Equals, newUser.Name) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 5b7aa1ebc6..5df9791b09 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -21,7 +21,6 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/db" - "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" ) @@ -58,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.CreateUserResponse{User: user.Proto()}, nil } @@ -65,28 +69,43 @@ func (api headscaleV1APIServer) RenameUser( ctx context.Context, request *v1.RenameUserRequest, ) (*v1.RenameUserResponse, error) { - err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName()) + oldUser, err := api.h.db.GetUserByName(request.GetOldName()) + if err != nil { + return nil, err + } + + err = api.h.db.RenameUser(types.UserID(oldUser.ID), request.GetNewName()) if err != nil { return nil, err } - user, err := api.h.db.GetUserByName(request.GetNewName()) + newUser, err := api.h.db.GetUserByName(request.GetNewName()) if err != nil { return nil, err } - return &v1.RenameUserResponse{User: user.Proto()}, nil + return &v1.RenameUserResponse{User: newUser.Proto()}, nil } func (api headscaleV1APIServer) DeleteUser( ctx context.Context, request *v1.DeleteUserRequest, ) (*v1.DeleteUserResponse, error) { - err := api.h.db.DestroyUser(request.GetName()) + user, err := api.h.db.GetUserByName(request.GetName()) + if 
err != nil { + return nil, err + } + + err = api.h.db.DestroyUser(types.UserID(user.ID)) if err != nil { return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.DeleteUserResponse{}, nil } @@ -131,8 +150,13 @@ func (api headscaleV1APIServer) CreatePreAuthKey( } } + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + preAuthKey, err := api.h.db.CreatePreAuthKey( - request.GetUser(), + types.UserID(user.ID), request.GetReusable(), request.GetPreApproved(), request.GetEphemeral(), @@ -169,7 +193,12 @@ func (api headscaleV1APIServer) ListPreAuthKeys( ctx context.Context, request *v1.ListPreAuthKeysRequest, ) (*v1.ListPreAuthKeysResponse, error) { - preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser()) + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + + preAuthKeys, err := api.h.db.ListPreAuthKeys(types.UserID(user.ID)) if err != nil { return nil, err } @@ -222,6 +251,11 @@ func (api headscaleV1APIServer) RegisterNode( return nil, err } + err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using node: %w", err) + } + return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } @@ -455,10 +489,20 @@ func (api headscaleV1APIServer) ListNodes( ctx context.Context, request *v1.ListNodesRequest, ) (*v1.ListNodesResponse, error) { + // TODO(kradalby): it looks like this can be simplified a lot, + // the filtering of nodes by user, vs nodes as a whole can + // probably be done once. + // TODO(kradalby): This should be done in one tx. + isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap() if request.GetUser() != "" { + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) { - return db.ListNodesByUser(rx, request.GetUser()) + return db.ListNodesByUser(rx, types.UserID(user.ID)) }) if err != nil { return nil, err @@ -499,10 +543,7 @@ func (api headscaleV1APIServer) ListNodes( resp.Online = true } - validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( - node, - ) - resp.InvalidTags = invalidTags + validTags := api.h.polMan.Tags(node) resp.ValidTags = validTags response[index] = resp } @@ -514,12 +555,18 @@ func (api headscaleV1APIServer) MoveNode( ctx context.Context, request *v1.MoveNodeRequest, ) (*v1.MoveNodeResponse, error) { + // TODO(kradalby): This should be done in one tx. node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) if err != nil { return nil, err } - err = api.h.db.AssignNodeToUser(node, request.GetUser()) + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + + err = api.h.db.AssignNodeToUser(node, types.UserID(user.ID)) if err != nil { return nil, err } @@ -772,11 +819,6 @@ func (api headscaleV1APIServer) SetPolicy( p := request.GetPolicy() - pol, err := policy.LoadACLPolicyFromBytes([]byte(p)) - if err != nil { - return nil, fmt.Errorf("loading ACL policy file: %w", err) - } - // Validate and reject configuration that would error when applied // when creating a map response. 
This requires nodes, so there is still // a scenario where they might be allowed if the server has no nodes @@ -786,14 +828,13 @@ func (api headscaleV1APIServer) SetPolicy( if err != nil { return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) } - - _, err = pol.CompileFilterRules(nodes) + changed, err := api.h.polMan.SetPolicy([]byte(p)) if err != nil { - return nil, fmt.Errorf("verifying policy rules: %w", err) + return nil, fmt.Errorf("setting policy: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = api.h.polMan.SSHPolicy(nodes[0]) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } @@ -804,12 +845,13 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - api.h.ACLPolicy = pol - - ctx := types.NotifyCtx(context.Background(), "acl-update", "na") - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateFullUpdate, - }) + // Only send update if the packet filter has changed. + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-update", "na") + api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } response := &v1.SetPolicyResponse{ Policy: updated.Data, diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 72ec4e4235..3858df9339 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "strconv" "strings" @@ -56,6 +57,65 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error) return tailcfg.CapabilityVersion(clientCapabilityVersion), nil } +func (h *Headscale) handleVerifyRequest( + req *http.Request, +) (bool, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return false, fmt.Errorf("cannot read request body: %w", err) + } + + var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest + if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { + return false, fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err) + } + + nodes, err := h.db.ListNodes() + if err != nil { + return false, fmt.Errorf("cannot list nodes: %w", err) + } + + return nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), nil +} + +// See https://github.com/tailscale/tailscale/blob/964282d34f06ecc06ce644769c66b0b31d118340/derp/derp_server.go#L1159: DERP uses verifyClientsURL to verify whether a client is allowed to connect to the DERP server. +func (h *Headscale) VerifyHandler( + writer http.ResponseWriter, + req *http.Request, +) { + if req.Method != http.MethodPost { + http.Error(writer, "Wrong method", http.StatusMethodNotAllowed) + + return + } + log.Debug(). + Str("handler", "/verify"). + Msg("verify client") + + allow, err := h.handleVerifyRequest(req) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to verify client") + http.Error(writer, "Internal error", http.StatusInternalServerError) + + return + } + + resp := tailcfg.DERPAdmitClientResponse{ + Allow: allow, + } + + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) + err = json.NewEncoder(writer).Encode(resp) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } +} + // KeyHandler provides the Headscale pub key // Listens in /key.
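The new `/verify` endpoint speaks JSON: the DERP server POSTs a `tailcfg.DERPAdmitClientRequest` and receives a `tailcfg.DERPAdmitClientResponse` back. A client-side sketch of the exchange (hypothetical function; the server URL and node key are placeholders):

```go
package main // hypothetical client, not part of the diff

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"

	"tailscale.com/tailcfg"
	"tailscale.com/types/key"
)

// verifyNode asks a Headscale server whether a node key may connect to the
// embedded DERP server; serverURL is a placeholder like "https://hs.example.com".
func verifyNode(serverURL string, nodeKey key.NodePublic) (bool, error) {
	body, err := json.Marshal(tailcfg.DERPAdmitClientRequest{NodePublic: nodeKey})
	if err != nil {
		return false, err
	}

	// VerifyHandler only accepts POST; anything else yields 405.
	resp, err := http.Post(serverURL+"/verify", "application/json", bytes.NewReader(body))
	if err != nil {
		return false, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return false, fmt.Errorf("verify returned %s", resp.Status)
	}

	var admit tailcfg.DERPAdmitClientResponse
	if err := json.NewDecoder(resp.Body).Decode(&admit); err != nil {
		return false, err
	}

	return admit.Allow, nil
}
```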
func (h *Headscale) KeyHandler( diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 3db1e159bd..51c96f8c87 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -55,6 +55,7 @@ type Mapper struct { cfg *types.Config derpMap *tailcfg.DERPMap notif *notifier.Notifier + polMan policy.PolicyManager uid string created time.Time @@ -71,6 +72,7 @@ func NewMapper( cfg *types.Config, derpMap *tailcfg.DERPMap, notif *notifier.Notifier, + polMan policy.PolicyManager, ) *Mapper { uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) @@ -79,6 +81,7 @@ func NewMapper( cfg: cfg, derpMap: derpMap, notif: notif, + polMan: polMan, uid: uid, created: time.Now(), @@ -153,10 +156,9 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, - pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, capVer) + resp, err := m.baseWithConfigMapResponse(node, capVer) if err != nil { return nil, err } @@ -164,11 +166,10 @@ func (m *Mapper) fullMapResponse( err = appendPeerChanges( resp, true, // full change - pol, + m.polMan, node, capVer, peers, - peers, m.cfg, ) if err != nil { @@ -182,7 +183,6 @@ func (m *Mapper) fullMapResponse( func (m *Mapper) FullMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { peers, err := m.ListPeers(node.ID) @@ -190,7 +190,7 @@ func (m *Mapper) FullMapResponse( return nil, err } - resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, mapRequest.Version) if err != nil { return nil, err } @@ -204,10 +204,9 @@ func (m *Mapper) FullMapResponse( func (m *Mapper) ReadOnlyMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) + resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version) if err != nil { return nil, err } @@ -243,7 +242,6 @@ func (m *Mapper) PeerChangedResponse( node *types.Node, changed map[types.NodeID]bool, patches []*tailcfg.PeerChange, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { resp := m.baseMapResponse() @@ -273,10 +271,9 @@ func (m *Mapper) PeerChangedResponse( err = appendPeerChanges( &resp, false, // partial change - pol, + m.polMan, node, mapRequest.Version, - peers, changedNodes, m.cfg, ) @@ -303,7 +300,7 @@ func (m *Mapper) PeerChangedResponse( // Add the node itself, it might have changed, and particularly // if there are no patches or changes, this is a self update. - tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg) + tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg) if err != nil { return nil, err } @@ -318,7 +315,6 @@ func (m *Mapper) PeerChangedPatchResponse( mapRequest tailcfg.MapRequest, node *types.Node, changed []*tailcfg.PeerChange, - pol *policy.ACLPolicy, ) ([]byte, error) { resp := m.baseMapResponse() resp.PeersChangedPatch = changed @@ -447,12 +443,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse { // incremental. 
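Throughout the mapper the `*policy.ACLPolicy` argument is replaced by a `policy.PolicyManager`. The interface itself is not part of this diff; reconstructed from its call sites (`Filter`, `SSHPolicy`, `Tags`, `SetPolicy`, `ExpandAlias`, `ApproversForRoute`) it plausibly looks like the sketch below, with return types inferred:

```go
package policy // reconstruction from call sites, not the actual definition

import (
	"net/netip"

	"go4.org/netipx"
	"tailscale.com/tailcfg"

	"github.com/juanfont/headscale/hscontrol/types"
)

// PolicyManager as inferred from this diff; the real interface in
// hscontrol/policy may have more methods or differ in detail.
type PolicyManager interface {
	// Filter returns the compiled packet filter for the tailnet.
	Filter() []tailcfg.FilterRule
	// SSHPolicy compiles the SSH policy as seen from one node.
	SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error)
	// Tags returns the valid tags of a node under the current policy.
	Tags(node *types.Node) []string
	// ApproversForRoute resolves the autoApprovers for a route.
	ApproversForRoute(route netip.Prefix) []string
	// ExpandAlias expands a policy alias into a set of IPs.
	ExpandAlias(alias string) (*netipx.IPSet, error)
	// SetPolicy swaps in a new policy, reporting whether the packet
	// filter changed as a result.
	SetPolicy(pol []byte) (bool, error)
}
```

The tests later in this diff also rely on a `NewPolicyManagerForTest(pol, users, nodes)` constructor and a concrete `PolicyManagerV1` implementation.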
func (m *Mapper) baseWithConfigMapResponse( node *types.Node, - pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { resp := m.baseMapResponse() - tailnode, err := tailNode(node, capVer, pol, m.cfg) + tailnode, err := tailNode(node, capVer, m.polMan, m.cfg) if err != nil { return nil, err } @@ -505,34 +500,30 @@ func appendPeerChanges( resp *tailcfg.MapResponse, fullChange bool, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, capVer tailcfg.CapabilityVersion, - peers types.Nodes, changed types.Nodes, cfg *types.Config, ) error { - packetFilter, err := pol.CompileFilterRules(append(peers, node)) - if err != nil { - return err - } + filter := polMan.Filter() - sshPolicy, err := pol.CompileSSHPolicy(node, peers) + sshPolicy, err := polMan.SSHPolicy(node) if err != nil { return err } // If there are filter rules present, see if there are any nodes that cannot // access each other at all and remove them from the peers. - if len(packetFilter) > 0 { - changed = policy.FilterNodesByACL(node, changed, packetFilter) + if len(filter) > 0 { + changed = policy.FilterNodesByACL(node, changed, filter) } profiles := generateUserProfiles(node, changed) dnsConfig := generateDNSConfig(cfg, node) - tailPeers, err := tailNodes(changed, capVer, pol, cfg) + tailPeers, err := tailNodes(changed, capVer, polMan, cfg) if err != nil { return err } @@ -557,7 +548,7 @@ func appendPeerChanges( // new PacketFilters field and "base" allows us to send a full update when we // have to send an empty list, avoiding the hack in the else block. resp.PacketFilters = map[string][]tailcfg.FilterRule{ - "base": policy.ReduceFilterRules(node, packetFilter), + "base": policy.ReduceFilterRules(node, filter), } } else { // This is a hack to avoid sending an empty list of packet filters. // If the PacketFilter is empty, it would // be omitted, causing the client to consider it unchanged, keeping the // previous packet filter. Worst case, this can cause a node that previously // had access to a node to _not_ lose access if an empty (allow none) is sent.
- reduced := policy.ReduceFilterRules(node, packetFilter) + reduced := policy.ReduceFilterRules(node, filter) if len(reduced) > 0 { resp.PacketFilter = reduced } else { - resp.PacketFilter = packetFilter + resp.PacketFilter = filter } } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index d5f0128421..b1d61c52ef 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) { lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) + user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"} + user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"} + mini := &types.Node{ ID: 0, MachineKey: mustMK( @@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, @@ -254,8 +257,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.2"), Hostname: "peer1", GivenName: "peer1", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, LastSeen: &lastSeen, Approved: true, @@ -310,8 +313,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.3"), Hostname: "peer2", GivenName: "peer2", - UserID: 1, - User: types.User{Name: "peer2"}, + UserID: user2.ID, + User: user2, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -461,17 +464,19 @@ func Test_fullMapResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node)) + mappy := NewMapper( nil, tt.cfg, tt.derpMap, nil, + polMan, ) got, err := mappy.fullMapResponse( tt.node, tt.peers, - tt.pol, 0, ) diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index a093ea99fa..cbe05aa752 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -14,7 +14,7 @@ import ( func tailNodes( nodes types.Nodes, capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, cfg *types.Config, ) ([]*tailcfg.Node, error) { tNodes := make([]*tailcfg.Node, len(nodes)) @@ -23,7 +23,7 @@ func tailNodes( node, err := tailNode( node, capVer, - pol, + polMan, cfg, ) if err != nil { @@ -40,7 +40,7 @@ func tailNodes( func tailNode( node *types.Node, capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, cfg *types.Config, ) (*tailcfg.Node, error) { addrs := node.Prefixes() @@ -81,7 +81,7 @@ func tailNode( return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) } - tags, _ := pol.TagsOfNode(node) + tags := polMan.Tags(node) tags = lo.Uniq(append(tags, node.ForcedTags...)) tNode := tailcfg.Node{ diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 59c8b6463e..a49e046994 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -185,6 +185,7 @@ func TestTailNode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node}) cfg := &types.Config{ BaseDomain: tt.baseDomain, DNSConfig: tt.dnsConfig, @@ -193,7 +194,7 @@ func TestTailNode(t *testing.T) { got, err := tailNode( tt.node, 0, - tt.pol, + polMan, cfg, 
) @@ -246,7 +247,7 @@ func TestNodeExpiry(t *testing.T) { tn, err := tailNode( node, 0, - &policy.ACLPolicy{}, + &policy.PolicyManagerV1{}, &types.Config{}, ) if err != nil { diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 487fdd266b..b4251f0a8e 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -17,6 +17,7 @@ import ( "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -52,6 +53,7 @@ type AuthProviderOIDC struct { registrationCache *zcache.Cache[string, key.MachinePublic] notifier *notifier.Notifier ipAlloc *db.IPAllocator + polMan policy.PolicyManager oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -64,6 +66,7 @@ func NewAuthProviderOIDC( db *db.HSDatabase, notif *notifier.Notifier, ipAlloc *db.IPAllocator, + polMan policy.PolicyManager, ) (*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already @@ -95,6 +98,7 @@ func NewAuthProviderOIDC( registrationCache: registrationCache, notifier: notif, ipAlloc: ipAlloc, + polMan: polMan, oidcProvider: oidcProvider, oauth2Config: oauth2Config, @@ -412,31 +416,53 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( ) (*types.User, error) { var user *types.User var err error - user, err = a.db.GetUserByOIDCIdentifier(claims.Sub) + user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { return nil, fmt.Errorf("creating or updating user: %w", err) } // This check is for legacy users: if the user cannot be found by the OIDC identifier, // look it up by username. This should only be needed once. - if user == nil { - user, err = a.db.GetUserByName(claims.Username) - if err != nil && !errors.Is(err, db.ErrUserNotFound) { - return nil, fmt.Errorf("creating or updating user: %w", err) - } + // This branch will persist for a number of versions after the OIDC migration and + // then be removed following a deprecation. + // TODO(kradalby): Remove when strip_email_domain and the migration are removed + // after #2170 is cleaned up. + if a.cfg.MapLegacyUsers && user == nil { + log.Trace().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user not found by OIDC identifier, looking up by username") + if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil { + log.Trace().Str("old_username", oldUsername).Str("sub", claims.Sub).Msg("found username") + user, err = a.db.GetUserByName(oldUsername) + if err != nil && !errors.Is(err, db.ErrUserNotFound) { + return nil, fmt.Errorf("getting user: %w", err) + } - // if the user is still not found, create a new empty user. - if user == nil { - user = &types.User{} + // If the user exists but already has a provider identifier (OIDC sub), create a new user. + // This prevents users that have already been migrated to the new OIDC format + // from being implicitly updated with a new OIDC identifier, which could + // enable an account takeover. + if user != nil && user.ProviderIdentifier.Valid { + log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.") + user = &types.User{} + } } } + // if the user is still not found, create a new empty user.
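Condensed, the lookup order implemented by `createOrUpdateUserFromClaim` is: (1) match on the stable `iss`+`sub` identifier, (2) if `map_legacy_users` is enabled, fall back to the legacy username match but never adopt an account that already carries a provider identifier, (3) otherwise start a fresh user. A stripped-down sketch, not the verbatim function:

```go
package hscontrol // condensed sketch of the flow above, not part of the diff

import (
	"errors"

	"github.com/juanfont/headscale/hscontrol/db"
	"github.com/juanfont/headscale/hscontrol/types"
)

func resolveUser(a *AuthProviderOIDC, claims *types.OIDCClaims) (*types.User, error) {
	// 1. Stable match on iss+sub; the only lookup new installations need.
	user, err := a.db.GetUserByOIDCIdentifier(claims.Identifier())
	if err != nil && !errors.Is(err, db.ErrUserNotFound) {
		return nil, err
	}

	// 2. Legacy fallback, gated on map_legacy_users.
	if a.cfg.MapLegacyUsers && user == nil {
		if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil {
			user, err = a.db.GetUserByName(oldUsername)
			if err != nil && !errors.Is(err, db.ErrUserNotFound) {
				return nil, err
			}
			// Never re-map an already migrated account: a colliding
			// username must not take over an existing user.
			if user != nil && user.ProviderIdentifier.Valid {
				user = nil
			}
		}
	}

	// 3. Otherwise this is a brand-new user.
	if user == nil {
		user = &types.User{}
	}

	user.FromClaim(claims)

	return user, nil
}
```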
+ if user == nil { + user = &types.User{} + } + user.FromClaim(claims) err = a.db.DB.Save(user).Error if err != nil { return nil, fmt.Errorf("creating or updating user: %w", err) } + err = usersChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return user, nil } @@ -459,7 +485,33 @@ func (a *AuthProviderOIDC) registerNode( ) if err != nil { return nil, fmt.Errorf("could not register node: %w", err) + } + + err = nodesChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return nil, fmt.Errorf("updating resources using node: %w", err) } return node, nil } + +// TODO(kradalby): Reintroduce when strip_email_domain is removed +// after #2170 is cleaned up +// DEPRECATED: DO NOT USE +func getUserName( + claims *types.OIDCClaims, + stripEmaildomain bool, +) (string, error) { + if !claims.EmailVerified { + return "", fmt.Errorf("email not verified") + } + userName, err := util.NormalizeToFQDNRules( + claims.Email, + stripEmaildomain, + ) + if err != nil { + return "", err + } + + return userName, nil +} diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index ff73985ba3..5848ec333c 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests( policy *ACLPolicy, node *types.Node, peers types.Nodes, + users []types.User, ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { // If there is no policy defined, we default to allow all if policy == nil { return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil } - rules, err := policy.CompileFilterRules(append(peers, node)) + rules, err := policy.CompileFilterRules(users, append(peers, node)) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") - sshPolicy, err := policy.CompileSSHPolicy(node, peers) + sshPolicy, err := policy.CompileSSHPolicy(node, users, peers) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } @@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests( // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *ACLPolicy) CompileFilterRules( + users []types.User, nodes types.Nodes, ) ([]tailcfg.FilterRule, error) { if pol == nil { @@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules( var srcIPs []string for srcIndex, src := range acl.Sources { - srcs, err := pol.expandSource(src, nodes) + srcs, err := pol.expandSource(src, users, nodes) if err != nil { - return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) + return nil, fmt.Errorf( + "parsing policy, acl index: %d->%d: %w", + index, + srcIndex, + err, + ) } srcIPs = append(srcIPs, srcs...) 
} @@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules( expanded, err := pol.ExpandAlias( nodes, + users, alias, ) if err != nil { @@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F func (pol *ACLPolicy) CompileSSHPolicy( node *types.Node, + users []types.User, peers types.Nodes, ) (*tailcfg.SSHPolicy, error) { if pol == nil { @@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( for index, sshACL := range pol.SSHs { var dest netipx.IPSetBuilder for _, src := range sshACL.Destinations { - expanded, err := pol.ExpandAlias(append(peers, node), src) + expanded, err := pol.ExpandAlias(append(peers, node), users, src) if err != nil { return nil, err } @@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy( case "check": checkAction, err := sshCheckAction(sshACL.CheckPeriod) if err != nil { - return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err) + return nil, fmt.Errorf( + "parsing SSH policy, parsing check duration, index: %d: %w", + index, + err, + ) } else { action = *checkAction } default: - return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err) + return nil, fmt.Errorf( + "parsing SSH policy, unknown action %q, index: %d: %w", + sshACL.Action, + index, + err, + ) } principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) @@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( } else { expandedSrcs, err := pol.ExpandAlias( peers, + users, rawSrc, ) if err != nil { @@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) { // with the given src alias. func (pol *ACLPolicy) expandSource( src string, + users []types.User, nodes types.Nodes, ) ([]string, error) { - ipSet, err := pol.ExpandAlias(nodes, src) + ipSet, err := pol.ExpandAlias(nodes, users, src) if err != nil { return []string{}, err } @@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource( // and transform these in IPAddresses. func (pol *ACLPolicy) ExpandAlias( nodes types.Nodes, + users []types.User, alias string, ) (*netipx.IPSet, error) { if isWildcard(alias) { @@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias( // if alias is a group if isGroup(alias) { - return pol.expandIPsFromGroup(alias, nodes) + return pol.expandIPsFromGroup(alias, users, nodes) } // if alias is a tag if isTag(alias) { - return pol.expandIPsFromTag(alias, nodes) + return pol.expandIPsFromTag(alias, users, nodes) } if isAutoGroup(alias) { @@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias( } // if alias is a user - if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { + if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil { return ips, err } @@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias( if h, ok := pol.Hosts[alias]; ok { log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - return pol.ExpandAlias(nodes, h.String()) + return pol.ExpandAlias(nodes, users, h.String()) } // if alias is an IP @@ -599,7 +620,7 @@ func (pol *ACLPolicy) ExpandAlias( // TODO(kradalby): It is quite hard to understand what this function is doing, // it seems like it trying to ensure that we dont include nodes that are tagged // when we look up the nodes owned by a user. -// This should be refactored to be more clear as part of the Tags work in #1369 +// This should be refactored to be more clear as part of the Tags work in #1369. 
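The new `filterNodesByUser` (below) resolves a policy user token against the user table before touching any nodes. A worked example with invented users and IDs:

```go
package policy // hypothetical example with invented users, not part of the diff

import (
	"database/sql"

	"github.com/juanfont/headscale/hscontrol/types"
	"gorm.io/gorm"
)

func illustrateFilterNodesByUser() {
	users := []types.User{
		{Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"},
		{
			Model:              gorm.Model{ID: 2},
			Name:               "bob",
			ProviderIdentifier: sql.NullString{String: "http://oidc.example/42", Valid: true},
		},
		{Model: gorm.Model{ID: 3}, Name: "carol", Email: "alice@example.com"},
	}
	nodes := types.Nodes{
		&types.Node{ID: 1, User: users[0]},
		&types.Node{ID: 2, User: users[1]},
		&types.Node{ID: 3, User: users[2]},
	}

	_ = filterNodesByUser(nodes, users, "http://oidc.example/42") // node 2: provider ID is unique
	_ = filterNodesByUser(nodes, users, "bob")                    // node 2: username is unique
	_ = filterNodesByUser(nodes, users, "alice@example.com")      // nil: two users share this email
}
```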
func excludeCorrectlyTaggedNodes( aclPolicy *ACLPolicy, nodes types.Nodes, @@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup( func (pol *ACLPolicy) expandIPsFromGroup( group string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - users, err := pol.expandUsersFromGroup(group) + userTokens, err := pol.expandUsersFromGroup(group) if err != nil { return &netipx.IPSet{}, err } - for _, user := range users { - filteredNodes := filterNodesByUser(nodes, user) + for _, user := range userTokens { + filteredNodes := filterNodesByUser(nodes, users, user) for _, node := range filteredNodes { node.AppendToIPSet(&build) } @@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup( func (pol *ACLPolicy) expandIPsFromTag( alias string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder @@ -803,7 +826,7 @@ // filter out nodes per tag owner for _, user := range owners { - nodes := filterNodesByUser(nodes, user) + nodes := filterNodesByUser(nodes, users, user) for _, node := range nodes { if node.Hostinfo == nil { continue @@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromUser( user string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - filteredNodes := filterNodesByUser(nodes, user) + filteredNodes := filterNodesByUser(nodes, users, user) filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) // short-circuit if we have no nodes to get IPs from. @@ -953,10 +977,43 @@ func (pol *ACLPolicy) TagsOfNode( return validTags, invalidTags } -func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { +// filterNodesByUser returns a list of nodes that match the given userToken from a +// policy. +// Matching nodes are determined by first matching the user token to a user by checking: +// - If it is an ID that matches the user database ID +// - If it is the Provider Identifier from OIDC +// - If it matches the username or email of a user +// +// If the token matches more than one user, zero nodes will be returned. +func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes { var out types.Nodes + + var potentialUsers []types.User + for _, user := range users { + if user.ProviderIdentifier.Valid && user.ProviderIdentifier.String == userToken { + // If a user matches a known unique field, + // discard all other users and only keep the current + // user.
+ potentialUsers = []types.User{user} + + break + } + if user.Email == userToken { + potentialUsers = append(potentialUsers, user) + } + if user.Name == userToken { + potentialUsers = append(potentialUsers, user) + } + } + + if len(potentialUsers) != 1 { + return nil + } + + user := potentialUsers[0] + for _, node := range nodes { - if node.User.Username() == user { + if node.User.ID == user.ID { out = append(out, node) } } @@ -977,10 +1034,7 @@ func FilterNodesByACL( continue } - log.Printf("Checking if %s can access %s", node.Hostname, peer.Hostname) - if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) { - log.Printf("CAN ACCESS %s can access %s", node.Hostname, peer.Hostname) result = append(result, peer) } } diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 1c6e4de863..b00cec12df 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -1,9 +1,12 @@ package policy import ( + "database/sql" "errors" + "math/rand/v2" "net/netip" "slices" + "sort" "testing" "github.com/google/go-cmp/cmp" @@ -11,9 +14,10 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/spf13/viper" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go4.org/netipx" "gopkg.in/check.v1" + "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -375,18 +379,24 @@ func TestParsing(t *testing.T) { return } - rules, err := pol.CompileFilterRules(types.Nodes{ - &types.Node{ - IPv4: iap("100.100.100.100"), + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser", + } + rules, err := pol.CompileFilterRules( + []types.User{ + user, }, - &types.Node{ - IPv4: iap("200.200.200.200"), - User: types.User{ - Name: "testuser", + types.Nodes{ + &types.Node{ + IPv4: iap("100.100.100.100"), }, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }) + &types.Node{ + IPv4: iap("200.200.200.200"), + User: user, + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) @@ -533,7 +543,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(err, check.IsNil) - rules, err := pol.CompileFilterRules(types.Nodes{}) + rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{}) c.Assert(err, check.NotNil) c.Assert(rules, check.IsNil) } @@ -549,7 +559,12 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests( + pol, + &types.Node{}, + types.Nodes{}, + []types.User{}, + ) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } @@ -568,7 +583,12 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests( + pol, + &types.Node{}, + types.Nodes{}, + []types.User{}, + ) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } @@ -584,7 +604,12 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests( + pol, + &types.Node{}, + types.Nodes{}, + []types.User{}, + ) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) } @@ -860,7 +885,25 @@ func Test_expandPorts(t *testing.T) { } } -func Test_listNodesInUser(t *testing.T) { +func 
Test_filterNodesByUser(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "marc"}, + {Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, + { + Model: gorm.Model{ID: 3}, + Name: "mikael", + Email: "mikael@headscale.net", + ProviderIdentifier: sql.NullString{String: "http://oidc.org/1234", Valid: true}, + }, + {Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, + {Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"}, + {Model: gorm.Model{ID: 6}, Name: "http://oidc.org/1234", Email: "mikael@headscale.net"}, + {Model: gorm.Model{ID: 7}, Name: "1"}, + {Model: gorm.Model{ID: 8}, Name: "alex", Email: "alex@headscale.net"}, + {Model: gorm.Model{ID: 9}, Name: "alex@headscale.net"}, + {Model: gorm.Model{ID: 10}, Email: "http://oidc.org/1234"}, + } + type args struct { nodes types.Nodes user string @@ -874,50 +917,258 @@ func Test_listNodesInUser(t *testing.T) { name: "1 node in user", args: args{ nodes: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, user: "joe", }, want: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, }, { name: "3 nodes, 2 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, user: "marc", }, want: types.Nodes{ - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, }, { name: "5 nodes, 0 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, - &types.Node{ID: 4, User: types.User{Name: "marc"}}, - &types.Node{ID: 5, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, + &types.Node{ID: 4, User: users[0]}, + &types.Node{ID: 5, User: users[0]}, }, user: "mickael", }, want: nil, }, + { + name: "match-by-provider-ident", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[2]}, + }, + }, + { + name: "match-by-email", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + &types.Node{ID: 8, User: users[7]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + }, + }, + { + name: "multi-match-is-zero", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + &types.Node{ID: 3, User: users[3]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "multi-email-first-match-is-zero", + args: args{ + nodes: types.Nodes{ + // First match email, then provider id + &types.Node{ID: 3, User: users[3]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "multi-username-first-match-is-zero", + args: args{ + nodes: types.Nodes{ + // First match username, then provider id + &types.Node{ID: 4, User: users[3]}, + 
&types.Node{ID: 2, User: users[2]}, + }, + user: "mikael", + }, + want: nil, + }, + { + name: "all-users-duplicate-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "mikael", + }, + want: nil, + }, + { + name: "all-users-unique-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "marc", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + }, + }, + { + name: "all-users-no-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "not-working", + }, + want: nil, + }, + { + name: "all-users-duplicate-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "all-users-duplicate-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + &types.Node{ID: 8, User: users[7]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[1]}, + }, + }, + { + name: "email-as-username-duplicate", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[7]}, + &types.Node{ID: 2, User: users[8]}, + }, + user: "alex@headscale.net", + }, + want: nil, + }, + { + name: "all-users-no-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "not-working@headscale.net", + }, + want: nil, + }, + { + name: "all-users-provider-id-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + &types.Node{ID: 6, User: users[5]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 3, User: users[2]}, + }, + }, + { + name: "all-users-no-provider-id-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + &types.Node{ID: 6, User: users[5]}, + }, + user: "http://oidc.org/4321", + }, + want: nil, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got := filterNodesByUser(test.args.nodes, test.args.user) - - if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { - t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) + for range 1000 { + ns := test.args.nodes + 
rand.Shuffle(len(ns), func(i, j int) { + ns[i], ns[j] = ns[j], ns[i] + }) + us := users + rand.Shuffle(len(us), func(i, j int) { + us[i], us[j] = us[j], us[i] + }) + got := filterNodesByUser(ns, us, test.args.user) + sort.Slice(got, func(i, j int) bool { + return got[i].ID < got[j].ID + }) + + if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { + t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff) + } } }) } @@ -940,6 +1191,12 @@ func Test_expandAlias(t *testing.T) { return s } + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "joe"}, + {Model: gorm.Model{ID: 2}, Name: "marc"}, + {Model: gorm.Model{ID: 3}, Name: "mickael"}, + } + type field struct { pol ACLPolicy } @@ -989,19 +1246,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1022,19 +1279,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1185,7 +1442,7 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1194,7 +1451,7 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1203,11 +1460,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], }, }, }, @@ -1260,21 +1517,21 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1295,12 +1552,12 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1309,11 +1566,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1350,12 +1607,12 @@ func Test_expandAlias(t *testing.T) 
{ }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{}, }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{}, }, }, @@ -1368,6 +1625,7 @@ func Test_expandAlias(t *testing.T) { t.Run(test.name, func(t *testing.T) { got, err := test.field.pol.ExpandAlias( test.args.nodes, + users, test.args.alias, ) if (err != nil) != test.wantErr { @@ -1715,6 +1973,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.field.pol.CompileFilterRules( + []types.User{}, tt.args.nodes, ) if (err != nil) != tt.wantErr { @@ -1824,16 +2083,31 @@ func TestTheInternet(t *testing.T) { for i := range internetPrefs { if internetPrefs[i].String() != hsExitNodeDest[i].IP { - t.Errorf("prefix from internet set %q != hsExit list %q", internetPrefs[i].String(), hsExitNodeDest[i].IP) + t.Errorf( + "prefix from internet set %q != hsExit list %q", + internetPrefs[i].String(), + hsExitNodeDest[i].IP, + ) } } if len(internetPrefs) != len(hsExitNodeDest) { - t.Fatalf("expected same length of prefixes, internet: %d, hsExit: %d", len(internetPrefs), len(hsExitNodeDest)) + t.Fatalf( + "expected same length of prefixes, internet: %d, hsExit: %d", + len(internetPrefs), + len(hsExitNodeDest), + ) } } func TestReduceFilterRules(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "mickael"}, + {Model: gorm.Model{ID: 2}, Name: "user1"}, + {Model: gorm.Model{ID: 3}, Name: "user2"}, + {Model: gorm.Model{ID: 4}, Name: "user100"}, + } + tests := []struct { name string node *types.Node @@ -1855,13 +2129,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: types.User{Name: "mickael"}, + User: users[0], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, + User: users[0], }, }, want: []tailcfg.FilterRule{}, @@ -1888,7 +2162,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -1899,7 +2173,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -1967,19 +2241,19 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, // "internal" exit node &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2026,17 +2300,22 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ 
{ - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, DstPorts: []tailcfg.NetPortRange{ { IP: "100.64.0.100/32", @@ -2049,7 +2328,12 @@ func TestReduceFilterRules(t *testing.T) { }, }, { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, DstPorts: hsExitNodeDest, }, }, @@ -2113,7 +2397,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2122,17 +2406,22 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, DstPorts: []tailcfg.NetPortRange{ { IP: "100.64.0.100/32", @@ -2145,7 +2434,12 @@ func TestReduceFilterRules(t *testing.T) { }, }, { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, DstPorts: []tailcfg.NetPortRange{ {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, @@ -2215,26 +2509,34 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, + RoutableIPs: []netip.Prefix{ + netip.MustParsePrefix("8.0.0.0/16"), + netip.MustParsePrefix("16.0.0.0/16"), + }, }, }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, DstPorts: []tailcfg.NetPortRange{ { IP: "100.64.0.100/32", @@ -2247,7 +2549,12 @@ func TestReduceFilterRules(t *testing.T) { }, }, { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, DstPorts: []tailcfg.NetPortRange{ { IP: "8.0.0.0/8", @@ -2292,26 +2599,34 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ 
- RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, + RoutableIPs: []netip.Prefix{ + netip.MustParsePrefix("8.0.0.0/8"), + netip.MustParsePrefix("16.0.0.0/8"), + }, }, }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, DstPorts: []tailcfg.NetPortRange{ { IP: "100.64.0.100/32", @@ -2324,7 +2639,12 @@ func TestReduceFilterRules(t *testing.T) { }, }, { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, DstPorts: []tailcfg.NetPortRange{ { IP: "8.0.0.0/16", @@ -2362,7 +2682,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, @@ -2372,7 +2692,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2400,6 +2720,7 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, _ := tt.pol.CompileFilterRules( + users, append(tt.peers, tt.node), ) @@ -3299,7 +3620,11 @@ func TestSSHRules(t *testing.T) { SSHUsers: map[string]string{ "autogroup:nonroot": "=", }, - Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true, AllowLocalPortForwarding: true}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + }, }, { SSHUsers: map[string]string{ @@ -3310,7 +3635,11 @@ func TestSSHRules(t *testing.T) { Any: true, }, }, - Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true, AllowLocalPortForwarding: true}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + }, }, { Principals: []*tailcfg.SSHPrincipal{ @@ -3321,7 +3650,11 @@ func TestSSHRules(t *testing.T) { SSHUsers: map[string]string{ "autogroup:nonroot": "=", }, - Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true, AllowLocalPortForwarding: true}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + }, }, { SSHUsers: map[string]string{ @@ -3332,7 +3665,11 @@ func TestSSHRules(t *testing.T) { Any: true, }, }, - Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true, AllowLocalPortForwarding: true}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + }, }, }}, }, @@ -3391,8 +3728,8 @@ func TestSSHRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) - assert.NoError(t, err) + got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers) + require.NoError(t, 
err) if diff := cmp.Diff(tt.want, got); diff != "" { t.Errorf("TestSSHRules() unexpected result (-want +got):\n%s", diff) @@ -3474,14 +3811,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { RequestTags: []string{"tag:test"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + node := &types.Node{ - ID: 0, - Hostname: "testnodes", - IPv4: iap("100.64.0.1"), - UserID: 0, - User: types.User{ - Name: "user1", - }, + ID: 0, + Hostname: "testnodes", + IPv4: iap("100.64.0.1"), + UserID: 0, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3498,8 +3838,8 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) - assert.NoError(t, err) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user}) + require.NoError(t, err) want := []tailcfg.FilterRule{ { @@ -3532,7 +3872,8 @@ func TestInvalidTagValidUser(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3549,8 +3890,13 @@ func TestInvalidTagValidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) - assert.NoError(t, err) + got, _, err := GenerateFilterAndSSHRulesForTests( + pol, + node, + types.Nodes{}, + []types.User{node.User}, + ) + require.NoError(t, err) want := []tailcfg.FilterRule{ { @@ -3583,7 +3929,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3608,8 +3955,13 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { // c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) - assert.NoError(t, err) + got, _, err := GenerateFilterAndSSHRulesForTests( + pol, + node, + types.Nodes{}, + []types.User{node.User}, + ) + require.NoError(t, err) want := []tailcfg.FilterRule{ { @@ -3637,15 +3989,17 @@ func TestValidTagInvalidUser(t *testing.T) { Hostname: "webserver", RequestTags: []string{"tag:webapp"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } node := &types.Node{ - ID: 1, - Hostname: "webserver", - IPv4: iap("100.64.0.1"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 1, + Hostname: "webserver", + IPv4: iap("100.64.0.1"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3656,13 +4010,11 @@ func TestValidTagInvalidUser(t *testing.T) { } nodes2 := &types.Node{ - ID: 2, - Hostname: "user", - IPv4: iap("100.64.0.2"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 2, + Hostname: "user", + IPv4: iap("100.64.0.2"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo2, } @@ -3678,8 +4030,13 @@ func TestValidTagInvalidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) - assert.NoError(t, err) + got, _, err := GenerateFilterAndSSHRulesForTests( + pol, + node, + types.Nodes{nodes2}, + []types.User{user}, + ) + require.NoError(t, err) want := []tailcfg.FilterRule{ { diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go new 
file mode 100644 index 0000000000..7dbaed33c9 --- /dev/null +++ b/hscontrol/policy/pm.go @@ -0,0 +1,181 @@ +package policy + +import ( + "fmt" + "io" + "net/netip" + "os" + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "go4.org/netipx" + "tailscale.com/tailcfg" + "tailscale.com/util/deephash" +) + +type PolicyManager interface { + Filter() []tailcfg.FilterRule + SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) + Tags(*types.Node) []string + ApproversForRoute(netip.Prefix) []string + ExpandAlias(string) (*netipx.IPSet, error) + SetPolicy([]byte) (bool, error) + SetUsers(users []types.User) (bool, error) + SetNodes(nodes types.Nodes) (bool, error) +} + +func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) { + policyFile, err := os.Open(path) + if err != nil { + return nil, err + } + defer policyFile.Close() + + policyBytes, err := io.ReadAll(policyFile) + if err != nil { + return nil, err + } + + return NewPolicyManager(policyBytes, users, nodes) +} + +func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { + var pol *ACLPolicy + var err error + if polB != nil && len(polB) > 0 { + pol, err = LoadACLPolicyFromBytes(polB) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } + } + + pm := PolicyManagerV1{ + pol: pol, + users: users, + nodes: nodes, + } + + _, err = pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) { + pm := PolicyManagerV1{ + pol: pol, + users: users, + nodes: nodes, + } + + _, err := pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +type PolicyManagerV1 struct { + mu sync.Mutex + pol *ACLPolicy + + users []types.User + nodes types.Nodes + + filterHash deephash.Sum + filter []tailcfg.FilterRule +} + +// updateLocked updates the filter rules based on the current policy and nodes. +// It must be called with the lock held. +func (pm *PolicyManagerV1) updateLocked() (bool, error) { + filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) + if err != nil { + return false, fmt.Errorf("compiling filter rules: %w", err) + } + + filterHash := deephash.Hash(&filter) + if filterHash == pm.filterHash { + return false, nil + } + + pm.filter = filter + pm.filterHash = filterHash + + return true, nil +} + +func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule { + pm.mu.Lock() + defer pm.mu.Unlock() + return pm.filter +} + +func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes) +} + +func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) { + pol, err := LoadACLPolicyFromBytes(polB) + if err != nil { + return false, fmt.Errorf("parsing policy: %w", err) + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.pol = pol + + return pm.updateLocked() +} + +// SetUsers updates the users in the policy manager and updates the filter rules. +func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.users = users + return pm.updateLocked() +} + +// SetNodes updates the nodes in the policy manager and updates the filter rules. 
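+// Like SetUsers and SetPolicy, it reports true only when the compiled
+// filter actually changed, so callers can skip notifying nodes otherwise.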
+func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.nodes = nodes + return pm.updateLocked() +} + +func (pm *PolicyManagerV1) Tags(node *types.Node) []string { + if pm == nil { + return nil + } + + tags, _ := pm.pol.TagsOfNode(node) + return tags +} + +func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string { + // TODO(kradalby): This can be a parse error of the address in the policy, + // in the new policy this will be typed and not a problem, in this policy + // we will just return empty list + if pm.pol == nil { + return nil + } + approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route) + return approvers +} + +func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) { + ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias) + if err != nil { + return nil, err + } + return ips, nil +} diff --git a/hscontrol/policy/pm_test.go b/hscontrol/policy/pm_test.go new file mode 100644 index 0000000000..24b78e4d28 --- /dev/null +++ b/hscontrol/policy/pm_test.go @@ -0,0 +1,158 @@ +package policy + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func TestPolicySetChange(t *testing.T) { + users := []types.User{ + { + Model: gorm.Model{ID: 1}, + Name: "testuser", + }, + } + tests := []struct { + name string + users []types.User + nodes types.Nodes + policy []byte + wantUsersChange bool + wantNodesChange bool + wantPolicyChange bool + wantFilter []tailcfg.FilterRule + }{ + { + name: "set-nodes", + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantNodesChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users", + users: users, + wantUsersChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users-and-node", + users: users, + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantUsersChange: false, + wantNodesChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-policy", + policy: []byte(` +{ +"acls": [ + { + "action": "accept", + "src": [ + "100.64.0.61", + ], + "dst": [ + "100.64.0.62:*", + ], + }, + ], +} + `), + wantPolicyChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.61/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol := ` +{ + "groups": { + "group:example": [ + "testuser", + ], + }, + + "hosts": { + "host-1": "100.64.0.1", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +` + pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{}) + require.NoError(t, err) + + if tt.policy != nil { + change, err := pm.SetPolicy(tt.policy) + require.NoError(t, err) + + assert.Equal(t, tt.wantPolicyChange, change) + } + + if tt.users != nil { + change, err := 
pm.SetUsers(tt.users) + require.NoError(t, err) + + assert.Equal(t, tt.wantUsersChange, change) + } + + if tt.nodes != nil { + change, err := pm.SetNodes(tt.nodes) + require.NoError(t, err) + + assert.Equal(t, tt.wantNodesChange, change) + } + + if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" { + t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/poll.go b/hscontrol/poll.go index a8ae01f44f..e6047d4550 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() { switch update.Type { case types.StateFullUpdate: m.tracef("Sending Full MapResponse") - data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) + data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) case types.StatePeerChanged: changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) @@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() { lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) updateType = "change" case types.StatePeerChangedPatch: m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy) + data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches) updateType = "patch" case types.StatePeerRemoved: changed := make(map[types.NodeID]bool, len(update.Removed)) @@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() { changed[nodeID] = false } m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) updateType = "remove" case types.StateSelfUpdate: lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) // create the map so an empty (self) update is sent - data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage) updateType = "remove" case types.StateDERPUpdated: m.tracef("Sending DERPUpdate MapResponse") @@ -488,9 +488,12 @@ func (m *mapSession) handleEndpointUpdate() { return } - if m.h.ACLPolicy != nil { + // TODO(kradalby): Only update the node that has actually changed + nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier) + + if m.h.polMan != nil { // update routes with peer information - err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) + err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node) if err != nil { m.errf(err, "Error running auto approved routes") mapResponseEndpointUpdates.WithLabelValues("error").Inc() @@ -544,7 +547,7 @@ func (m *mapSession) handleEndpointUpdate() { func (m *mapSession) handleReadOnlyRequest() { m.tracef("Client asked for a lite update, responding without peers") - mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, 
m.node, m.h.ACLPolicy) + mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node) if err != nil { m.errf(err, "Failed to create MapResponse") http.Error(m.w, "", http.StatusInternalServerError) diff --git a/hscontrol/templates/windows.go b/hscontrol/templates/windows.go index b233bac46f..680d66559c 100644 --- a/hscontrol/templates/windows.go +++ b/hscontrol/templates/windows.go @@ -13,7 +13,7 @@ func Windows(url string) *elem.Element { elem.Text("headscale - Windows"), ), elem.Body(attrs.Props{ - attrs.Style : bodyStyle.ToInline(), + attrs.Style: bodyStyle.ToInline(), }, headerOne("headscale: Windows configuration"), elem.P(nil, @@ -21,7 +21,8 @@ func Windows(url string) *elem.Element { elem.A(attrs.Props{ attrs.Href: "https://tailscale.com/download/windows", attrs.Rel: "noreferrer noopener", - attrs.Target: "_blank"}, + attrs.Target: "_blank", + }, elem.Text("Tailscale for Windows ")), elem.Text("and install it."), ), diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 65d7244422..665efe1525 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -28,8 +28,9 @@ const ( maxDuration time.Duration = 1<<63 - 1 ) -var errOidcMutuallyExclusive = errors.New( - "oidc_client_secret and oidc_client_secret_path are mutually exclusive", +var ( + errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") + errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") ) type IPAllocationStrategy string @@ -104,8 +105,9 @@ type Nameservers struct { } type SqliteConfig struct { - Path string - WriteAheadLog bool + Path string + WriteAheadLog bool + WALAutoCheckPoint int } type PostgresConfig struct { @@ -164,8 +166,10 @@ type OIDCConfig struct { AllowedDomains []string AllowedUsers []string AllowedGroups []string + StripEmaildomain bool Expiry time.Duration UseExpiryFromToken bool + MapLegacyUsers bool } type DERPConfig struct { @@ -276,11 +280,14 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600) viper.SetDefault("database.sqlite.write_ahead_log", true) + viper.SetDefault("database.sqlite.wal_autocheckpoint", 1000) // SQLite default viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) + viper.SetDefault("oidc.strip_email_domain", true) viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.use_expiry_from_token", false) + viper.SetDefault("oidc.map_legacy_users", true) viper.SetDefault("logtail.enabled", false) viper.SetDefault("randomize_client_port", false) @@ -326,14 +333,18 @@ func validateServerConfig() error { depr.warn("dns_config.use_username_in_magic_dns") depr.warn("dns.use_username_in_magic_dns") - depr.fatal("oidc.strip_email_domain") + // TODO(kradalby): Reintroduce when strip_email_domain is removed + // after #2170 is cleaned up + // depr.fatal("oidc.strip_email_domain") depr.fatal("dns.use_username_in_musername_in_magic_dns") depr.fatal("dns_config.use_username_in_musername_in_magic_dns") depr.Log() for _, removed := range []string{ - "oidc.strip_email_domain", + // TODO(kradalby): Reintroduce when strip_email_domain is removed + // after #2170 is cleaned up + // "oidc.strip_email_domain", "dns_config.use_username_in_musername_in_magic_dns", } { if viper.IsSet(removed) { @@ -550,7 +561,8 @@ func databaseConfig() DatabaseConfig { Path: 
util.AbsolutePathFromConfigPath( viper.GetString("database.sqlite.path"), ), - WriteAheadLog: viper.GetBool("database.sqlite.write_ahead_log"), + WriteAheadLog: viper.GetBool("database.sqlite.write_ahead_log"), + WALAutoCheckPoint: viper.GetInt("database.sqlite.wal_autocheckpoint"), }, Postgres: PostgresConfig{ Host: viper.GetString("database.postgres.host"), @@ -843,11 +855,10 @@ func LoadServerConfig() (*Config, error) { // - DERP run on their own domains // - Control plane runs on login.tailscale.com/controlplane.tailscale.com // - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net) - if dnsConfig.BaseDomain != "" && - strings.Contains(serverURL, dnsConfig.BaseDomain) { - return nil, errors.New( - "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.", - ) + if dnsConfig.BaseDomain != "" { + if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil { + return nil, err + } } nodeManagement := nodeManagementConfig() @@ -915,6 +926,10 @@ func LoadServerConfig() (*Config, error) { } }(), UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"), + // TODO(kradalby): Remove when strip_email_domain is removed + // after #2170 is cleaned up + StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), + MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"), }, LogTail: logTailConfig, @@ -944,6 +959,37 @@ func LoadServerConfig() (*Config, error) { }, nil } +// BaseDomain cannot be a suffix of the server URL. +// This is because Tailscale takes over the domain in BaseDomain, +// causing the headscale server and DERP to be unreachable. +// For Tailscale upstream, the following is true: +// - DERP run on their own domains. +// - Control plane runs on login.tailscale.com/controlplane.tailscale.com. +// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net). 
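+// Examples (see TestSafeServerURL):
+//   server_url https://headscale.com        with base_domain clients.headscale.com -> OK
+//   server_url https://server.headscale.com with base_domain headscale.com         -> errServerURLSuffix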
+func isSafeServerURL(serverURL, baseDomain string) error { + server, err := url.Parse(serverURL) + if err != nil { + return err + } + + serverDomainParts := strings.Split(server.Host, ".") + baseDomainParts := strings.Split(baseDomain, ".") + + if len(serverDomainParts) <= len(baseDomainParts) { + return nil + } + + s := len(serverDomainParts) + b := len(baseDomainParts) + for i := range len(baseDomainParts) { + if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] { + return nil + } + } + + return errServerURLSuffix +} + type deprecator struct { warns set.Set[string] fatals set.Set[string] diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index 70c0ce7a1b..58382ca5ab 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -1,6 +1,7 @@ package types import ( + "fmt" "os" "path/filepath" "testing" @@ -8,6 +9,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/spf13/viper" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" ) @@ -35,8 +37,17 @@ func TestReadConfig(t *testing.T) { MagicDNS: true, BaseDomain: "example.com", Nameservers: Nameservers{ - Global: []string{"1.1.1.1", "1.0.0.1", "2606:4700:4700::1111", "2606:4700:4700::1001", "https://dns.nextdns.io/abc123"}, - Split: map[string][]string{"darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, "foo.bar.com": {"1.1.1.1"}}, + Global: []string{ + "1.1.1.1", + "1.0.0.1", + "2606:4700:4700::1111", + "2606:4700:4700::1001", + "https://dns.nextdns.io/abc123", + }, + Split: map[string][]string{ + "darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, + "foo.bar.com": {"1.1.1.1"}, + }, }, ExtraRecords: []tailcfg.DNSRecord{ {Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"}, @@ -91,8 +102,17 @@ func TestReadConfig(t *testing.T) { MagicDNS: false, BaseDomain: "example.com", Nameservers: Nameservers{ - Global: []string{"1.1.1.1", "1.0.0.1", "2606:4700:4700::1111", "2606:4700:4700::1001", "https://dns.nextdns.io/abc123"}, - Split: map[string][]string{"darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, "foo.bar.com": {"1.1.1.1"}}, + Global: []string{ + "1.1.1.1", + "1.0.0.1", + "2606:4700:4700::1111", + "2606:4700:4700::1001", + "https://dns.nextdns.io/abc123", + }, + Split: map[string][]string{ + "darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, + "foo.bar.com": {"1.1.1.1"}, + }, }, ExtraRecords: []tailcfg.DNSRecord{ {Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"}, @@ -139,7 +159,7 @@ func TestReadConfig(t *testing.T) { return LoadServerConfig() }, want: nil, - wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.", + wantErr: errServerURLSuffix.Error(), }, { name: "base-domain-not-in-server-url", @@ -186,7 +206,7 @@ func TestReadConfig(t *testing.T) { t.Run(tt.name, func(t *testing.T) { viper.Reset() err := LoadConfig(tt.configPath, true) - assert.NoError(t, err) + require.NoError(t, err) conf, err := tt.setup(t) @@ -196,7 +216,7 @@ func TestReadConfig(t *testing.T) { return } - assert.NoError(t, err) + require.NoError(t, err) if diff := cmp.Diff(tt.want, conf); diff != "" { t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff) @@ -276,10 +296,10 @@ func TestReadConfigFromEnv(t *testing.T) { viper.Reset() err := LoadConfig("testdata/minimal.yaml", true) - assert.NoError(t, err) + require.NoError(t, err) conf, err := tt.setup(t) - assert.NoError(t, err) + require.NoError(t, err) if diff := 
cmp.Diff(tt.want, conf); diff != "" { t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff) @@ -310,13 +330,25 @@ noise: // Check configuration validation errors (1) err = LoadConfig(tmpDir, false) - assert.NoError(t, err) + require.NoError(t, err) err = validateServerConfig() - assert.Error(t, err) - assert.Contains(t, err.Error(), "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both") - assert.Contains(t, err.Error(), "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are") - assert.Contains(t, err.Error(), "Fatal config error: server_url must start with https:// or http://") + require.Error(t, err) + assert.Contains( + t, + err.Error(), + "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both", + ) + assert.Contains( + t, + err.Error(), + "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are", + ) + assert.Contains( + t, + err.Error(), + "Fatal config error: server_url must start with https:// or http://", + ) // Check configuration validation errors (2) configYaml = []byte(`--- @@ -331,5 +363,66 @@ tls_letsencrypt_challenge_type: TLS-ALPN-01 t.Fatalf("Couldn't write file %s", configFilePath) } err = LoadConfig(tmpDir, false) - assert.NoError(t, err) + require.NoError(t, err) +} + +// OK +// server_url: headscale.com, base: clients.headscale.com +// server_url: headscale.com, base: headscale.net +// +// NOT OK +// server_url: server.headscale.com, base: headscale.com. +func TestSafeServerURL(t *testing.T) { + tests := []struct { + serverURL, baseDomain, + wantErr string + }{ + { + serverURL: "https://example.com", + baseDomain: "example.org", + }, + { + serverURL: "https://headscale.com", + baseDomain: "headscale.com", + }, + { + serverURL: "https://headscale.com", + baseDomain: "clients.headscale.com", + }, + { + serverURL: "https://headscale.com", + baseDomain: "clients.subdomain.headscale.com", + }, + { + serverURL: "https://headscale.kristoffer.com", + baseDomain: "mybase", + }, + { + serverURL: "https://server.headscale.com", + baseDomain: "headscale.com", + wantErr: errServerURLSuffix.Error(), + }, + { + serverURL: "https://server.subdomain.headscale.com", + baseDomain: "headscale.com", + wantErr: errServerURLSuffix.Error(), + }, + { + serverURL: "http://foo\x00", + wantErr: `parse "http://foo\x00": net/url: invalid control character in URL`, + }, + } + + for _, tt := range tests { + testName := fmt.Sprintf("server=%s domain=%s", tt.serverURL, tt.baseDomain) + t.Run(testName, func(t *testing.T) { + err := isSafeServerURL(tt.serverURL, tt.baseDomain) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + + return + } + assert.NoError(t, err) + }) + } } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index e95958747f..661680c014 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -229,6 +229,16 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { return found } +func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool { + for _, node := range nodes { + if node.NodeKey == nodeKey { + return true + } + } + + return false +} + func (node *Node) Proto() *v1.Node { nodeProto := &v1.Node{ Id: uint64(node.ID), diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 7b8be423de..c23f34552f 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -27,7 +27,7 @@ type PreAuthKey struct { func (key *PreAuthKey) Proto() *v1.PreAuthKey { 
protoKey := v1.PreAuthKey{ - User: key.User.Name, + User: key.User.Username(), Id: strconv.FormatUint(key.ID, util.Base10), Key: key.Key, Ephemeral: key.Ephemeral, diff --git a/hscontrol/types/testdata/base-domain-in-server-url.yaml b/hscontrol/types/testdata/base-domain-in-server-url.yaml index 683e021837..2d6a4694a0 100644 --- a/hscontrol/types/testdata/base-domain-in-server-url.yaml +++ b/hscontrol/types/testdata/base-domain-in-server-url.yaml @@ -8,7 +8,7 @@ prefixes: database: type: sqlite3 -server_url: "https://derp.no" +server_url: "https://server.derp.no" dns: magic_dns: true diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index f983d7f52e..60fbbeda04 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -2,6 +2,8 @@ package types import ( "cmp" + "database/sql" + "net/mail" "strconv" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -19,10 +21,14 @@ type UserID uint64 // that contain our machines. type User struct { gorm.Model + // The index `idx_name_provider_identifier` is to enforce uniqueness + // between Name and ProviderIdentifier. This ensures that + // you can have multiple users with the same name in OIDC, + // but not if you only run with CLI users. // Username for the user, is used if email is empty // Should not be used, please use Username(). - Name string `gorm:"unique"` + Name string // Typically the full name of the user DisplayName string @@ -34,7 +40,7 @@ type User struct { // Unique identifier of the user from OIDC, // comes from `sub` claim in the OIDC token // and is used to lookup the user. - ProviderIdentifier string `gorm:"index"` + ProviderIdentifier sql.NullString // Provider is the origin of the user account, // same as RegistrationMethod, without authkey. @@ -51,7 +57,7 @@ type User struct { // should be used throughout headscale, in information returned to the // user and the Policy engine. func (u *User) Username() string { - return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) + return cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10)) } // DisplayNameOrUsername returns the DisplayName if it exists, otherwise @@ -107,7 +113,7 @@ func (u *User) Proto() *v1.User { CreatedAt: timestamppb.New(u.CreatedAt), DisplayName: u.DisplayName, Email: u.Email, - ProviderId: u.ProviderIdentifier, + ProviderId: u.ProviderIdentifier.String, Provider: u.Provider, ProfilePicUrl: u.ProfilePicURL, } @@ -116,6 +122,7 @@ func (u *User) Proto() *v1.User { type OIDCClaims struct { // Sub is the user's unique identifier at the provider. Sub string `json:"sub"` + Iss string `json:"iss"` // Name is the user's full name. Name string `json:"name,omitempty"` @@ -126,13 +133,27 @@ type OIDCClaims struct { Username string `json:"preferred_username,omitempty"` } +func (c *OIDCClaims) Identifier() string { + return c.Iss + "/" + c.Sub +} + // FromClaim overrides a User from OIDC claims. // All fields will be updated, except for the ID. 
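+// Exceptions: Name is only overwritten when the claimed preferred_username
+// passes FQDN validation, and Email only when the provider reports
+// email_verified and the address parses; otherwise the old values are kept.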
func (u *User) FromClaim(claims *OIDCClaims) { - u.ProviderIdentifier = claims.Sub + err := util.CheckForFQDNRules(claims.Username) + if err == nil { + u.Name = claims.Username + } + + if claims.EmailVerified { + _, err = mail.ParseAddress(claims.Email) + if err == nil { + u.Email = claims.Email + } + } + + u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true} u.DisplayName = claims.Name - u.Email = claims.Email - u.Name = claims.Username u.ProfilePicURL = claims.ProfilePictureURL u.Provider = util.RegisterMethodOIDC } diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index f57576f4aa..bf43eb507a 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -182,3 +182,33 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { return fqdns } + +// TODO(kradalby): Reintroduce when strip_email_domain is removed +// after #2170 is cleaned up +// DEPRECATED: DO NOT USE +// NormalizeToFQDNRules will replace forbidden chars in user +// it can also return an error if the user doesn't respect RFC 952 and 1123. +func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { + + name = strings.ToLower(name) + name = strings.ReplaceAll(name, "'", "") + atIdx := strings.Index(name, "@") + if stripEmailDomain && atIdx > 0 { + name = name[:atIdx] + } else { + name = strings.ReplaceAll(name, "@", ".") + } + name = invalidCharsInUserRegex.ReplaceAllString(name, "-") + + for _, elt := range strings.Split(name, ".") { + if len(elt) > LabelHostnameLength { + return "", fmt.Errorf( + "label %v is more than 63 chars: %w", + elt, + ErrInvalidUserName, + ) + } + } + + return name, nil +} diff --git a/hscontrol/util/string_test.go b/hscontrol/util/string_test.go index 87a8be1c0b..2c392ab464 100644 --- a/hscontrol/util/string_test.go +++ b/hscontrol/util/string_test.go @@ -4,12 +4,13 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGenerateRandomStringDNSSafe(t *testing.T) { for i := 0; i < 100000; i++ { str, err := GenerateRandomStringDNSSafe(8) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, str, 8) } } diff --git a/integration/README.md b/integration/README.md index e5676a44c6..56247c52f8 100644 --- a/integration/README.md +++ b/integration/README.md @@ -11,10 +11,10 @@ Tests are located in files ending with `_test.go` and the framework are located ## Running integration tests locally -The easiest way to run tests locally is to use `[act](INSERT LINK)`, a local GitHub Actions runner: +The easiest way to run tests locally is to use [act](https://github.com/nektos/act), a local GitHub Actions runner: ``` -act pull_request -W .github/workflows/test-integration-v2-TestPingAllByIP.yaml +act pull_request -W .github/workflows/test-integration.yaml ``` Alternatively, the `docker run` command in each GitHub workflow file can be used. 
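+For reference, a minimal sketch of such an invocation (assumptions: the tests
+drive Docker via the host socket through dockertest, and `TestPingAllByIP` is
+just an example test name; the workflow files remain the authoritative source
+for the exact flags):
+
+```
+docker run --rm -it \
+  -v "$PWD:$PWD" -w "$PWD/integration" \
+  -v /var/run/docker.sock:/var/run/docker.sock \
+  golang:1 \
+  go test ./... -v -timeout 120m -run '^TestPingAllByIP$'
+```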
diff --git a/integration/acl_test.go b/integration/acl_test.go index 1da8213d9f..6606a13220 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -12,6 +12,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var veryLargeDestination = []string{ @@ -54,7 +55,7 @@ func aclScenario( ) *Scenario { t.Helper() scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) + require.NoError(t, err) spec := map[string]int{ "user1": clientsPerUser, @@ -77,10 +78,10 @@ func aclScenario( hsic.WithACLPolicy(policy), hsic.WithTestName("acl"), ) - assertNoErr(t, err) + require.NoError(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + require.NoError(t, err) return scenario } @@ -267,7 +268,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { for name, testCase := range tests { t.Run(name, func(t *testing.T) { scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) + require.NoError(t, err) spec := testCase.users @@ -275,22 +276,22 @@ func TestACLHostsInNetMapTable(t *testing.T) { []tsic.Option{}, hsic.WithACLPolicy(&testCase.policy), ) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErr(t, err) + require.NoError(t, err) err = scenario.WaitForTailscaleSyncWithPeerCount(testCase.want["user1"]) - assertNoErrSync(t, err) + require.NoError(t, err) for _, client := range allClients { status, err := client.Status() - assertNoErr(t, err) + require.NoError(t, err) user := status.User[status.Self.UserID].LoginName - assert.Equal(t, (testCase.want[user]), len(status.Peer)) + assert.Len(t, status.Peer, (testCase.want[user])) } }) } @@ -319,23 +320,23 @@ func TestACLAllowUser80Dst(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErr(t, err) + require.NoError(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErr(t, err) + require.NoError(t, err) // Test that user1 can visit all user2 for _, client := range user1Clients { for _, peer := range user2Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Len(t, result, 13) - assertNoErr(t, err) + require.NoError(t, err) } } @@ -343,14 +344,14 @@ func TestACLAllowUser80Dst(t *testing.T) { for _, client := range user2Clients { for _, peer := range user1Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) } } } @@ -376,10 +377,10 @@ func TestACLDenyAllPort80(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErr(t, err) + require.NoError(t, err) allHostnames, err := scenario.ListTailscaleClientsFQDNs() - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { for _, hostname := range allHostnames { @@ -394,7 +395,7 @@ func TestACLDenyAllPort80(t *testing.T) { result, err := client.Curl(url) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) } } } @@ 
-420,23 +421,23 @@ func TestACLAllowUserDst(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErr(t, err) + require.NoError(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErr(t, err) + require.NoError(t, err) // Test that user1 can visit all user2 for _, client := range user1Clients { for _, peer := range user2Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Len(t, result, 13) - assertNoErr(t, err) + require.NoError(t, err) } } @@ -444,14 +445,14 @@ func TestACLAllowUserDst(t *testing.T) { for _, client := range user2Clients { for _, peer := range user1Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) } } } @@ -476,23 +477,23 @@ func TestACLAllowStarDst(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErr(t, err) + require.NoError(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErr(t, err) + require.NoError(t, err) // Test that user1 can visit all user2 for _, client := range user1Clients { for _, peer := range user2Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Len(t, result, 13) - assertNoErr(t, err) + require.NoError(t, err) } } @@ -500,14 +501,14 @@ func TestACLAllowStarDst(t *testing.T) { for _, client := range user2Clients { for _, peer := range user1Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) } } } @@ -537,23 +538,23 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErr(t, err) + require.NoError(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErr(t, err) + require.NoError(t, err) // Test that user1 can visit all user2 for _, client := range user1Clients { for _, peer := range user2Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Len(t, result, 13) - assertNoErr(t, err) + require.NoError(t, err) } } @@ -561,14 +562,14 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) { for _, client := range user2Clients { for _, peer := range user1Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Len(t, result, 13) - assertNoErr(t, err) + require.NoError(t, err) } } } @@ -679,10 +680,10 @@ func TestACLNamedHostsCanReach(t *testing.T) { test1ip4 := 
netip.MustParseAddr("100.64.0.1") test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1") test1, err := scenario.FindTailscaleClientByIP(test1ip6) - assertNoErr(t, err) + require.NoError(t, err) test1fqdn, err := test1.FQDN() - assertNoErr(t, err) + require.NoError(t, err) test1ip4URL := fmt.Sprintf("http://%s/etc/hostname", test1ip4.String()) test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String()) test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn) @@ -690,10 +691,10 @@ func TestACLNamedHostsCanReach(t *testing.T) { test2ip4 := netip.MustParseAddr("100.64.0.2") test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2") test2, err := scenario.FindTailscaleClientByIP(test2ip6) - assertNoErr(t, err) + require.NoError(t, err) test2fqdn, err := test2.FQDN() - assertNoErr(t, err) + require.NoError(t, err) test2ip4URL := fmt.Sprintf("http://%s/etc/hostname", test2ip4.String()) test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String()) test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn) @@ -701,10 +702,10 @@ func TestACLNamedHostsCanReach(t *testing.T) { test3ip4 := netip.MustParseAddr("100.64.0.3") test3ip6 := netip.MustParseAddr("fd7a:115c:a1e0::3") test3, err := scenario.FindTailscaleClientByIP(test3ip6) - assertNoErr(t, err) + require.NoError(t, err) test3fqdn, err := test3.FQDN() - assertNoErr(t, err) + require.NoError(t, err) test3ip4URL := fmt.Sprintf("http://%s/etc/hostname", test3ip4.String()) test3ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test3ip6.String()) test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn) @@ -719,7 +720,7 @@ func TestACLNamedHostsCanReach(t *testing.T) { test3ip4URL, result, ) - assertNoErr(t, err) + require.NoError(t, err) result, err = test1.Curl(test3ip6URL) assert.Lenf( @@ -730,7 +731,7 @@ func TestACLNamedHostsCanReach(t *testing.T) { test3ip6URL, result, ) - assertNoErr(t, err) + require.NoError(t, err) result, err = test1.Curl(test3fqdnURL) assert.Lenf( @@ -741,7 +742,7 @@ func TestACLNamedHostsCanReach(t *testing.T) { test3fqdnURL, result, ) - assertNoErr(t, err) + require.NoError(t, err) // test2 can query test3 result, err = test2.Curl(test3ip4URL) @@ -753,7 +754,7 @@ func TestACLNamedHostsCanReach(t *testing.T) { test3ip4URL, result, ) - assertNoErr(t, err) + require.NoError(t, err) result, err = test2.Curl(test3ip6URL) assert.Lenf( @@ -764,7 +765,7 @@ func TestACLNamedHostsCanReach(t *testing.T) { test3ip6URL, result, ) - assertNoErr(t, err) + require.NoError(t, err) result, err = test2.Curl(test3fqdnURL) assert.Lenf( @@ -775,33 +776,33 @@ func TestACLNamedHostsCanReach(t *testing.T) { test3fqdnURL, result, ) - assertNoErr(t, err) + require.NoError(t, err) // test3 cannot query test1 result, err = test3.Curl(test1ip4URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test3.Curl(test1ip6URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test3.Curl(test1fqdnURL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) // test3 cannot query test2 result, err = test3.Curl(test2ip4URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test3.Curl(test2ip6URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test3.Curl(test2fqdnURL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) // test1 can query test2 result, err = test1.Curl(test2ip4URL) @@ -814,7 +815,7 @@ func TestACLNamedHostsCanReach(t *testing.T) { 
result, ) - assertNoErr(t, err) + require.NoError(t, err) result, err = test1.Curl(test2ip6URL) assert.Lenf( t, @@ -824,7 +825,7 @@ func TestACLNamedHostsCanReach(t *testing.T) { test2ip6URL, result, ) - assertNoErr(t, err) + require.NoError(t, err) result, err = test1.Curl(test2fqdnURL) assert.Lenf( @@ -835,20 +836,20 @@ func TestACLNamedHostsCanReach(t *testing.T) { test2fqdnURL, result, ) - assertNoErr(t, err) + require.NoError(t, err) // test2 cannot query test1 result, err = test2.Curl(test1ip4URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test2.Curl(test1ip6URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test2.Curl(test1fqdnURL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) }) } } @@ -946,10 +947,10 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1") test1, err := scenario.FindTailscaleClientByIP(test1ip) assert.NotNil(t, test1) - assertNoErr(t, err) + require.NoError(t, err) test1fqdn, err := test1.FQDN() - assertNoErr(t, err) + require.NoError(t, err) test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String()) test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String()) test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn) @@ -958,10 +959,10 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2") test2, err := scenario.FindTailscaleClientByIP(test2ip) assert.NotNil(t, test2) - assertNoErr(t, err) + require.NoError(t, err) test2fqdn, err := test2.FQDN() - assertNoErr(t, err) + require.NoError(t, err) test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String()) test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String()) test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn) @@ -976,7 +977,7 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { test2ipURL, result, ) - assertNoErr(t, err) + require.NoError(t, err) result, err = test1.Curl(test2ip6URL) assert.Lenf( @@ -987,7 +988,7 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { test2ip6URL, result, ) - assertNoErr(t, err) + require.NoError(t, err) result, err = test1.Curl(test2fqdnURL) assert.Lenf( @@ -998,19 +999,19 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { test2fqdnURL, result, ) - assertNoErr(t, err) + require.NoError(t, err) result, err = test2.Curl(test1ipURL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test2.Curl(test1ip6URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test2.Curl(test1fqdnURL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) }) } } @@ -1020,7 +1021,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { t.Parallel() scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -1046,19 +1047,19 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { "HEADSCALE_POLICY_MODE": "database", }), ) - assertNoErr(t, err) + require.NoError(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + require.NoError(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + require.NoError(t, err) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErr(t, err) + require.NoError(t, err) user2Clients, err := 
scenario.ListTailscaleClients("user2") - assertNoErr(t, err) + require.NoError(t, err) all := append(user1Clients, user2Clients...) @@ -1070,19 +1071,19 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { } fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Len(t, result, 13) - assertNoErr(t, err) + require.NoError(t, err) } } headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) p := policy.ACLPolicy{ ACLs: []policy.ACL{ @@ -1100,7 +1101,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { policyFilePath := "/etc/headscale/policy.json" err = headscale.WriteFile(policyFilePath, pBytes) - assertNoErr(t, err) + require.NoError(t, err) // No policy is present at this time. // Add a new policy from a file. @@ -1113,7 +1114,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { policyFilePath, }, ) - assertNoErr(t, err) + require.NoError(t, err) // Get the current policy and check // if it is the same as the one we set. @@ -1129,7 +1130,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { }, &output, ) - assertNoErr(t, err) + require.NoError(t, err) assert.Len(t, output.ACLs, 1) @@ -1141,14 +1142,14 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { for _, client := range user1Clients { for _, peer := range user2Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Len(t, result, 13) - assertNoErr(t, err) + require.NoError(t, err) } } @@ -1156,14 +1157,14 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { for _, client := range user2Clients { for _, peer := range user1Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) result, err := client.Curl(url) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) } } } diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 6fbdd9e42e..e0a614012c 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -3,6 +3,7 @@ package integration import ( "context" "crypto/tls" + "encoding/json" "errors" "fmt" "io" @@ -10,14 +11,19 @@ import ( "net" "net/http" "net/netip" + "sort" "strconv" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" + "github.com/oauth2-proxy/mockoidc" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" "github.com/samber/lo" @@ -48,20 +54,34 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { scenario := AuthOIDCScenario{ Scenario: baseScenario, } - defer scenario.ShutdownAssertNoPanics(t) + // defer scenario.ShutdownAssertNoPanics(t) + // Logins to MockOIDC is served by a queue with a strict order, + // if we use more than one node per user, the order of the logins + // will not be deterministic and the test will fail. 
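+ // Hence: exactly one node per user in the spec below.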
spec := map[string]int{ - "user1": len(MustTestVersions), + "user1": 1, + "user2": 1, + } + + mockusers := []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", false), } - oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL) + oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() oidcMap := map[string]string{ "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "CREDENTIALS_DIRECTORY_TEST": "/tmp", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + // TODO(kradalby): Remove when strip_email_domain is removed + // after #2170 is cleaned up + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", } err = scenario.CreateHeadscaleEnv( @@ -91,6 +111,55 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + var listUsers []v1.User + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listUsers, + ) + assertNoErr(t, err) + + want := []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Email: "", // Unverified + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user2", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].Id < listUsers[j].Id + }) + + if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } } // This test is really flaky. 
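// A hedged sketch of how the expected fields in the `want` list above are
// derived (assuming the mockoidc semantics this test relies on): ProviderId
// is the issuer URL joined with the `sub` claim, and Email is only stored
// when the IdP marks it verified, which is why user2's Email is empty.
//
//	providerID := oidcConfig.Issuer + "/" + user.Subject // e.g. ".../user1"
//	email := ""
//	if user.EmailVerified { // oidcMockUser("user2", false) leaves this unset
//		email = user.Email
//	}
//
// Note also that the sort above compares Id lexicographically (Id is a string
// field); that is only safe while the expected IDs stay single-digit.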
@@ -111,11 +180,16 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ - "user1": 3, + "user1": 1, + "user2": 1, } - oidcConfig, err := scenario.runMockOIDC(shortAccessTTL) + oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", false), + }) assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() oidcMap := map[string]string{ "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, @@ -159,6 +233,297 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { assertTailscaleNodesLogout(t, allClients) } +// TODO(kradalby): +// - Test that creates a new user when one exists when migration is turned off +// - Test that takes over a user when one exists when migration is turned on +// - But email is not verified +// - stripped email domain on/off +func TestOIDC024UserCreation(t *testing.T) { + IntegrationSkip(t) + + tests := []struct { + name string + config map[string]string + emailVerified bool + cliUsers []string + oidcUsers []string + want func(iss string) []v1.User + }{ + { + name: "no-migration-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + }, + emailVerified: true, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "no-migration-not-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + }, + emailVerified: false, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-strip-domains-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1", + }, + emailVerified: true, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "2", + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-strip-domains-not-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1", + }, + emailVerified: false, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: 
"migration-no-strip-domains-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", + }, + emailVerified: true, + cliUsers: []string{"user1.headscale.net", "user2.headscale.net"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + // Hmm I think we will have to overwrite the initial name here + // createuser with "user1.headscale.net", but oidc with "user1" + { + Id: "1", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "2", + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-no-strip-domains-not-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", + }, + emailVerified: false, + cliUsers: []string{"user1.headscale.net", "user2.headscale.net"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1.headscale.net", + }, + { + Id: "2", + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2.headscale.net", + }, + { + Id: "4", + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + baseScenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + + scenario := AuthOIDCScenario{ + Scenario: baseScenario, + } + defer scenario.ShutdownAssertNoPanics(t) + + spec := map[string]int{} + for _, user := range tt.cliUsers { + spec[user] = 1 + } + + var mockusers []mockoidc.MockUser + for _, user := range tt.oidcUsers { + mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified)) + } + + oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) + assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, + "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + + for k, v := range tt.config { + oidcMap[k] = v + } + + err = scenario.CreateHeadscaleEnv( + spec, + hsic.WithTestName("oidcmigration"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithHostnameAsServerURL(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + ) + assertNoErrHeadscaleEnv(t, err) + + // Ensure that the nodes have logged in, this is what + // triggers user creation via OIDC. 
+ err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + want := tt.want(oidcConfig.Issuer) + + var listUsers []v1.User + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listUsers, + ) + assertNoErr(t, err) + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].Id < listUsers[j].Id + }) + + if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Errorf("unexpected users: %s", diff) + } + }) + } +} + func (s *AuthOIDCScenario) CreateHeadscaleEnv( users map[string]int, opts ...hsic.Option, @@ -174,6 +539,13 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( } for userName, clientCount := range users { + if clientCount != 1 { + // OIDC scenario only supports one client per user. + // This is because the MockOIDC server can only serve login + // requests based on a queue it has been given on startup. + // We currently only populate it with one login request per user. + return fmt.Errorf("client count must be 1 for OIDC scenario") + } log.Printf("creating user %s with %d clients", userName, clientCount) err = s.CreateUser(userName) if err != nil { @@ -194,7 +566,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( return nil } -func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) { +func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) { port, err := dockertestutil.RandomFreeHostPort() if err != nil { log.Fatalf("could not find an open port: %s", err) } @@ -205,6 +577,11 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf hostname := fmt.Sprintf("hs-oidcmock-%s", hash) + usersJSON, err := json.Marshal(users) + if err != nil { + return nil, err + } + mockOidcOptions := &dockertest.RunOptions{ Name: hostname, Cmd: []string{"headscale", "mockoidc"}, @@ -219,11 +596,12 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf "MOCKOIDC_CLIENT_ID=superclient", "MOCKOIDC_CLIENT_SECRET=supersecret", fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), + fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), }, } headscaleBuildOptions := &dockertest.BuildOptions{ - Dockerfile: "Dockerfile.debug", + Dockerfile: hsic.IntegrationTestDockerFileName, ContextDir: dockerContextPath, } @@ -310,45 +688,40 @@ func (s *AuthOIDCScenario) runTailscaleUp( log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) - if err := s.pool.Retry(func() error { - log.Printf("%s logging in with url", c.Hostname()) - httpClient := &http.Client{Transport: insecureTransport} - ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) - resp, err := httpClient.Do(req) - if err != nil { - log.Printf( - "%s failed to login using url %s: %s", - c.Hostname(), - loginURL, - err, - ) - - return err - } + log.Printf("%s logging in with url", c.Hostname()) + httpClient := &http.Client{Transport: insecureTransport} + ctx := context.Background() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) + resp, err := httpClient.Do(req) + if err != nil { + log.Printf( + "%s failed to login using url %s: %s", + c.Hostname(), + loginURL, + err, + ) - if resp.StatusCode != http.StatusOK { - log.Printf("%s response code of oidc login request was %s",
c.Hostname(), resp.Status) + return err + } - return errStatusCodeNotOK - } + if resp.StatusCode != http.StatusOK { + log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) + body, _ := io.ReadAll(resp.Body) + log.Printf("body: %s", body) - defer resp.Body.Close() + return errStatusCodeNotOK + } - _, err = io.ReadAll(resp.Body) - if err != nil { - log.Printf("%s failed to read response body: %s", c.Hostname(), err) + defer resp.Body.Close() - return err - } + _, err = io.ReadAll(resp.Body) + if err != nil { + log.Printf("%s failed to read response body: %s", c.Hostname(), err) - return nil - }); err != nil { return err } log.Printf("Finished request for %s to join tailnet", c.Hostname()) - return nil }) @@ -395,3 +768,12 @@ func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { assert.Equal(t, "NeedsLogin", status.BackendState) } } + +func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { + return mockoidc.MockUser{ + Subject: username, + PreferredUsername: username, + Email: fmt.Sprintf("%s@headscale.net", username), + EmailVerified: emailVerified, + } +} diff --git a/integration/cli_test.go b/integration/cli_test.go index dd2c8c2772..521b6e2b99 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -14,6 +14,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" ) func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { @@ -843,117 +844,85 @@ func TestNodeTagCommand(t *testing.T) { ) } -func TestNodeAdvertiseTagNoACLCommand(t *testing.T) { +func TestNodeAdvertiseTagCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - defer scenario.ShutdownAssertNoPanics(t) - - spec := map[string]int{ - "user1": 1, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:test"})}, hsic.WithTestName("cliadvtags")) - assertNoErr(t, err) - - headscale, err := scenario.Headscale() - assertNoErr(t, err) - - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, spec["user1"]) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--tags", - "--output", "json", + tests := []struct { + name string + policy *policy.ACLPolicy + wantTag bool + }{ + { + name: "no-policy", + wantTag: false, }, - &resultMachines, - ) - assert.Nil(t, err) - found := false - for _, node := range resultMachines { - if node.GetInvalidTags() != nil { - for _, tag := range node.GetInvalidTags() { - if tag == "tag:test" { - found = true - } - } - } - } - assert.Equal( - t, - true, - found, - "should not find a node with the tag 'tag:test' in the list of nodes", - ) -} - -func TestNodeAdvertiseTagWithACLCommand(t *testing.T) { - IntegrationSkip(t) - t.Parallel() - - scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - defer scenario.ShutdownAssertNoPanics(t) - - spec := map[string]int{ - "user1": 1, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:exists"})}, hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy( - &policy.ACLPolicy{ - ACLs: []policy.ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + { + name: "with-policy", + policy: &policy.ACLPolicy{ + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + 
Destinations: []string{"*:*"}, + }, + }, + TagOwners: map[string][]string{ + "tag:test": {"user1"}, + }, + }, - )) - assertNoErr(t, err) + wantTag: true, + }, + } - headscale, err := scenario.Headscale() - assertNoErr(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + defer scenario.ShutdownAssertNoPanics(t) - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, spec["user1"]) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--tags", - "--output", "json", - }, - &resultMachines, - ) - assert.Nil(t, err) - found := false - for _, node := range resultMachines { - if node.GetValidTags() != nil { - for _, tag := range node.GetValidTags() { - if tag == "tag:exists" { - found = true + spec := map[string]int{ + "user1": 1, + } + + err = scenario.CreateHeadscaleEnv(spec, + []tsic.Option{tsic.WithTags([]string{"tag:test"})}, + hsic.WithTestName("cliadvtags"), + hsic.WithACLPolicy(tt.policy), + ) + assertNoErr(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + // Test list all nodes after added seconds + resultMachines := make([]*v1.Node, spec["user1"]) + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--tags", + "--output", "json", + }, + &resultMachines, + ) + assert.Nil(t, err) + found := false + for _, node := range resultMachines { + if tags := node.GetValidTags(); tags != nil { + found = slices.Contains(tags, "tag:test") } } - } + assert.Equalf( + t, + tt.wantTag, + found, + "'tag:test' found(%t) in the list of nodes, expected %t", found, tt.wantTag, + ) + }) } - assert.Equal( - t, - true, - found, - "should not find a node with the tag 'tag:exists' in the list of nodes", - ) } func TestNodeCommand(t *testing.T) { @@ -1921,7 +1890,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath, }, ) - assert.ErrorContains(t, err, "verifying policy rules: invalid action") + assert.ErrorContains(t, err, "compiling filter rules: invalid action") // The new policy was invalid, the old one should still be in place, which // is none.
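// An aside on the merged TestNodeAdvertiseTagCommand above (illustrative):
// slices.Contains from golang.org/x/exp/slices is nil-safe, so the
// `tags != nil` guard is purely defensive and the check could be written as
//
//	found := slices.Contains(node.GetValidTags(), "tag:test")
//
// The no-policy case still reports found == false, because without a matching
// TagOwner the advertised tag ends up in InvalidTags rather than ValidTags.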
@@ -1936,3 +1905,4 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { ) assert.ErrorContains(t, err, "acl policy not found") } + diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go new file mode 100644 index 0000000000..adad5b6a49 --- /dev/null +++ b/integration/derp_verify_endpoint_test.go @@ -0,0 +1,96 @@ +package integration + +import ( + "encoding/json" + "fmt" + "net" + "strconv" + "strings" + "testing" + + "github.com/juanfont/headscale/hscontrol/util" + "github.com/juanfont/headscale/integration/dsic" + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/integrationutil" + "github.com/juanfont/headscale/integration/tsic" + "tailscale.com/tailcfg" +) + +func TestDERPVerifyEndpoint(t *testing.T) { + IntegrationSkip(t) + + // Generate random hostname for the headscale instance + hash, err := util.GenerateRandomStringDNSSafe(6) + assertNoErr(t, err) + testName := "derpverify" + hostname := fmt.Sprintf("hs-%s-%s", testName, hash) + + headscalePort := 8080 + + // Create cert for headscale + certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname) + assertNoErr(t, err) + + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + spec := map[string]int{ + "user1": len(MustTestVersions), + } + + derper, err := scenario.CreateDERPServer("head", + dsic.WithCACert(certHeadscale), + dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))), + ) + assertNoErr(t, err) + + derpMap := tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 900: { + RegionID: 900, + RegionCode: "test-derpverify", + RegionName: "TestDerpVerify", + Nodes: []*tailcfg.DERPNode{ + { + Name: "TestDerpVerify", + RegionID: 900, + HostName: derper.GetHostname(), + STUNPort: derper.GetSTUNPort(), + STUNOnly: false, + DERPPort: derper.GetDERPPort(), + }, + }, + }, + }, + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithCACert(derper.GetCert())}, + hsic.WithHostname(hostname), + hsic.WithPort(headscalePort), + hsic.WithCustomTLS(certHeadscale, keyHeadscale), + hsic.WithHostnameAsServerURL(), + hsic.WithDERPConfig(derpMap)) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + for _, client := range allClients { + report, err := client.DebugDERPRegion("test-derpverify") + assertNoErr(t, err) + successful := false + for _, line := range report.Info { + if strings.Contains(line, "Successfully established a DERP connection with node") { + successful = true + + break + } + } + if !successful { + stJSON, err := json.Marshal(report) + assertNoErr(t, err) + t.Errorf("Client %s could not establish a DERP connection: %s", client.Hostname(), string(stJSON)) + } + } +} diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index 1b41e32472..9e16f3660c 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -74,7 +74,7 @@ func ExecuteCommand( select { case res := <-resultChan: if res.err != nil { - return stdout.String(), stderr.String(), res.err + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), res.err) } if res.exitCode != 0 { @@ -83,12 +83,12 @@ func ExecuteCommand( // log.Println("stdout: ", stdout.String()) // log.Println("stderr: ", stderr.String()) - return stdout.String(), 
stderr.String(), ErrDockertestCommandFailed + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandFailed) } return stdout.String(), stderr.String(), nil case <-time.After(execConfig.timeout): - return stdout.String(), stderr.String(), ErrDockertestCommandTimeout + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandTimeout) } } diff --git a/integration/dsic/dsic.go b/integration/dsic/dsic.go new file mode 100644 index 0000000000..f8bb85a9a4 --- /dev/null +++ b/integration/dsic/dsic.go @@ -0,0 +1,321 @@ +package dsic + +import ( + "crypto/tls" + "errors" + "fmt" + "log" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/juanfont/headscale/hscontrol/util" + "github.com/juanfont/headscale/integration/dockertestutil" + "github.com/juanfont/headscale/integration/integrationutil" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +const ( + dsicHashLength = 6 + dockerContextPath = "../." + caCertRoot = "/usr/local/share/ca-certificates" + DERPerCertRoot = "/usr/local/share/derper-certs" + dockerExecuteTimeout = 60 * time.Second +) + +var errDERPerStatusCodeNotOk = errors.New("DERPer status code not OK") + +// DERPServerInContainer represents DERP Server in Container (DSIC). +type DERPServerInContainer struct { + version string + hostname string + + pool *dockertest.Pool + container *dockertest.Resource + network *dockertest.Network + + stunPort int + derpPort int + caCerts [][]byte + tlsCert []byte + tlsKey []byte + withExtraHosts []string + withVerifyClientURL string + workdir string +} + +// Option represents optional settings that can be given to a +// DERPer instance. +type Option = func(c *DERPServerInContainer) + +// WithCACert adds the given certificate to the trusted certificates of the DERPer container. +func WithCACert(cert []byte) Option { + return func(dsic *DERPServerInContainer) { + dsic.caCerts = append(dsic.caCerts, cert) + } +} + +// WithOrCreateNetwork sets the Docker container network to use with +// the DERPer instance; if the parameter is nil, a new network, +// isolating the DERPer, will be created. If a network is +// passed, the DERPer instance will join the given network. +func WithOrCreateNetwork(network *dockertest.Network) Option { + return func(tsic *DERPServerInContainer) { + if network != nil { + tsic.network = network + + return + } + + network, err := dockertestutil.GetFirstOrCreateNetwork( + tsic.pool, + tsic.hostname+"-network", + ) + if err != nil { + log.Fatalf("failed to create network: %s", err) + } + + tsic.network = network + } +} + +// WithDockerWorkdir allows the docker working directory to be set. +func WithDockerWorkdir(dir string) Option { + return func(tsic *DERPServerInContainer) { + tsic.workdir = dir + } +} + +// WithVerifyClientURL sets the URL to verify the client. +func WithVerifyClientURL(url string) Option { + return func(tsic *DERPServerInContainer) { + tsic.withVerifyClientURL = url + } +} + +// WithExtraHosts adds extra hosts to the container. +func WithExtraHosts(hosts []string) Option { + return func(tsic *DERPServerInContainer) { + tsic.withExtraHosts = hosts + } +} + +// New returns a new DERPServerInContainer instance.
+func New( + pool *dockertest.Pool, + version string, + network *dockertest.Network, + opts ...Option, +) (*DERPServerInContainer, error) { + hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength) + if err != nil { + return nil, err + } + + hostname := fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash) + tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname) + if err != nil { + return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err) + } + dsic := &DERPServerInContainer{ + version: version, + hostname: hostname, + pool: pool, + network: network, + tlsCert: tlsCert, + tlsKey: tlsKey, + stunPort: 3478, //nolint + derpPort: 443, //nolint + } + + for _, opt := range opts { + opt(dsic) + } + + var cmdArgs strings.Builder + fmt.Fprintf(&cmdArgs, "--hostname=%s", hostname) + fmt.Fprintf(&cmdArgs, " --certmode=manual") + fmt.Fprintf(&cmdArgs, " --certdir=%s", DERPerCertRoot) + fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort) + fmt.Fprintf(&cmdArgs, " --stun=true") + fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort) + if dsic.withVerifyClientURL != "" { + fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL) + } + + runOptions := &dockertest.RunOptions{ + Name: hostname, + Networks: []*dockertest.Network{dsic.network}, + ExtraHosts: dsic.withExtraHosts, + // We currently need to give ourselves some time to inject the certificate further down. + Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()}, + ExposedPorts: []string{ + "80/tcp", + fmt.Sprintf("%d/tcp", dsic.derpPort), + fmt.Sprintf("%d/udp", dsic.stunPort), + }, + } + + if dsic.workdir != "" { + runOptions.WorkingDir = dsic.workdir + } + + // dockertest isn't very good at handling containers that have already + // been created; this is an attempt to make sure this container isn't + // present.
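// Sketch of the intent here (hedged; this assumes dockertest's removal call
// succeeds as a no-op when no container matches): the call below acts roughly
// like `docker rm -f <hostname>`, clearing any stale container left behind by
// an aborted earlier run so it cannot collide with the fixed hostname above.
//
//	err = pool.RemoveContainerByName(hostname) // idempotent cleanup guard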
+ err = pool.RemoveContainerByName(hostname) + if err != nil { + return nil, err + } + + var container *dockertest.Resource + buildOptions := &dockertest.BuildOptions{ + Dockerfile: "Dockerfile.derper", + ContextDir: dockerContextPath, + BuildArgs: []docker.BuildArg{}, + } + switch version { + case "head": + buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{ + Name: "VERSION_BRANCH", + Value: "main", + }) + default: + buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{ + Name: "VERSION_BRANCH", + Value: "v" + version, + }) + } + container, err = pool.BuildAndRunWithBuildOptions( + buildOptions, + runOptions, + dockertestutil.DockerRestartPolicy, + dockertestutil.DockerAllowLocalIPv6, + dockertestutil.DockerAllowNetworkAdministration, + ) + if err != nil { + return nil, fmt.Errorf( + "%s could not start tailscale DERPer container (version: %s): %w", + hostname, + version, + err, + ) + } + log.Printf("Created %s container\n", hostname) + + dsic.container = container + + for i, cert := range dsic.caCerts { + err = dsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert) + if err != nil { + return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) + } + } + if len(dsic.tlsCert) != 0 { + err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert) + if err != nil { + return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) + } + } + if len(dsic.tlsKey) != 0 { + err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey) + if err != nil { + return nil, fmt.Errorf("failed to write TLS key to container: %w", err) + } + } + + return dsic, nil +} + +// Shutdown stops and cleans up the DERPer container. +func (t *DERPServerInContainer) Shutdown() error { + err := t.SaveLog("/tmp/control") + if err != nil { + log.Printf( + "Failed to save log from %s: %s", + t.hostname, + fmt.Errorf("failed to save log: %w", err), + ) + } + + return t.pool.Purge(t.container) +} + +// GetCert returns the TLS certificate of the DERPer instance. +func (t *DERPServerInContainer) GetCert() []byte { + return t.tlsCert +} + +// Hostname returns the hostname of the DERPer instance. +func (t *DERPServerInContainer) Hostname() string { + return t.hostname +} + +// Version returns the running DERPer version of the instance. +func (t *DERPServerInContainer) Version() string { + return t.version +} + +// ID returns the Docker container ID of the DERPServerInContainer +// instance. +func (t *DERPServerInContainer) ID() string { + return t.container.Container.ID +} + +func (t *DERPServerInContainer) GetHostname() string { + return t.hostname +} + +// GetSTUNPort returns the STUN port of the DERPer instance. +func (t *DERPServerInContainer) GetSTUNPort() int { + return t.stunPort +} + +// GetDERPPort returns the DERP port of the DERPer instance. +func (t *DERPServerInContainer) GetDERPPort() int { + return t.derpPort +} + +// WaitForRunning blocks until the DERPer instance is ready to be used. 
+func (t *DERPServerInContainer) WaitForRunning() error { + url := "https://" + net.JoinHostPort(t.GetHostname(), strconv.Itoa(t.GetDERPPort())) + "/" + log.Printf("waiting for DERPer to be ready at %s", url) + + insecureTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint + insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint + client := &http.Client{Transport: insecureTransport} + + return t.pool.Retry(func() error { + resp, err := client.Get(url) //nolint + if err != nil { + return fmt.Errorf("DERPer is not ready: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return errDERPerStatusCodeNotOk + } + + return nil + }) +} + +// ConnectToNetwork connects the DERPer instance to a network. +func (t *DERPServerInContainer) ConnectToNetwork(network *dockertest.Network) error { + return t.container.ConnectToNetwork(network) +} + +// WriteFile saves a file inside the container. +func (t *DERPServerInContainer) WriteFile(path string, data []byte) error { + return integrationutil.WriteFileToContainer(t.pool, t.container, path, data) +} + +// SaveLog saves the current stdout log of the container to a path +// on the host system. +func (t *DERPServerInContainer) SaveLog(path string) error { + _, _, err := dockertestutil.SaveLog(t.pool, t.container, path) + + return err +} diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index 40dd648f79..c187f2ba80 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -55,7 +55,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) { spec := map[string]ClientsSpec{ "user1": { Plain: 0, - WebsocketDERP: len(MustTestVersions), + WebsocketDERP: 2, }, } @@ -239,10 +239,13 @@ func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv( if clientCount.WebsocketDERP > 0 { // Containers that use DERP-over-WebSocket + // Note that these clients *must* be built + // from source, which is currently + // only done for HEAD.
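// For background (illustrative, tying this to the tsic changes later in this
// diff): DERP-over-WebSocket support in the client is gated behind a build
// tag, which is why these nodes must be HEAD builds compiled from source.
// On HEAD, the option below implies WithBuildTag("ts_debug_websockets"),
// roughly equivalent to passing `-tags=ts_debug_websockets` to the Go build:
//
//	tsic.WithWebsocketDERP(true) // rejected with errInvalidClientConfig on non-HEAD versions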
err = s.CreateTailscaleIsolatedNodesInUser( hash, userName, - "all", + tsic.VersionHead, clientCount.WebsocketDERP, tsic.WithWebsocketDERP(true), ) @@ -307,7 +310,7 @@ func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser( cert := hsServer.GetCert() opts = append(opts, - tsic.WithHeadscaleTLS(cert), + tsic.WithCACert(cert), ) user.createWaitGroup.Go(func() error { diff --git a/integration/general_test.go b/integration/general_test.go index 3196302464..c7c962d160 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -18,6 +18,7 @@ import ( "github.com/rs/zerolog/log" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "tailscale.com/client/tailscale/apitype" "tailscale.com/types/key" @@ -244,7 +245,11 @@ func TestEphemeral(t *testing.T) { } func TestEphemeralInAlternateTimezone(t *testing.T) { - testEphemeralWithOptions(t, hsic.WithTestName("ephemeral-tz"), hsic.WithTimezone("America/Los_Angeles")) + testEphemeralWithOptions( + t, + hsic.WithTestName("ephemeral-tz"), + hsic.WithTimezone("America/Los_Angeles"), + ) } func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { @@ -1164,10 +1169,10 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { }, &nodeList, ) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, nodeList, 2) - assert.True(t, nodeList[0].Online) - assert.True(t, nodeList[1].Online) + assert.True(t, nodeList[0].GetOnline()) + assert.True(t, nodeList[1].GetOnline()) // Delete the first node, which is online _, err = headscale.Execute( @@ -1177,13 +1182,13 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { "delete", "--identifier", // Delete the last added machine - fmt.Sprintf("%d", nodeList[0].Id), + fmt.Sprintf("%d", nodeList[0].GetId()), "--output", "json", "--force", }, ) - assert.Nil(t, err) + require.NoError(t, err) time.Sleep(2 * time.Second) @@ -1200,9 +1205,8 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { }, &nodeListAfter, ) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, nodeListAfter, 1) - assert.True(t, nodeListAfter[0].Online) - assert.Equal(t, nodeList[1].Id, nodeListAfter[0].Id) - + assert.True(t, nodeListAfter[0].GetOnline()) + assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId()) } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 6e57c61b0b..b616ec1e6a 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -1,19 +1,12 @@ package hsic import ( - "bytes" - "crypto/rand" - "crypto/rsa" "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" "encoding/json" - "encoding/pem" "errors" "fmt" "io" "log" - "math/big" "net" "net/http" "net/url" @@ -32,15 +25,19 @@ import ( "github.com/juanfont/headscale/integration/integrationutil" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" + "gopkg.in/yaml.v3" + "tailscale.com/tailcfg" ) const ( - hsicHashLength = 6 - dockerContextPath = "../." - aclPolicyPath = "/etc/headscale/acl.hujson" - tlsCertPath = "/etc/headscale/tls.cert" - tlsKeyPath = "/etc/headscale/tls.key" - headscaleDefaultPort = 8080 + hsicHashLength = 6 + dockerContextPath = "../." 
+ caCertRoot = "/usr/local/share/ca-certificates" + aclPolicyPath = "/etc/headscale/acl.hujson" + tlsCertPath = "/etc/headscale/tls.cert" + tlsKeyPath = "/etc/headscale/tls.key" + headscaleDefaultPort = 8080 + IntegrationTestDockerFileName = "Dockerfile.integration" ) var errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok") @@ -64,6 +61,7 @@ type HeadscaleInContainer struct { // optional config port int extraPorts []string + caCerts [][]byte hostPortBindings map[string][]string aclPolicy *policy.ACLPolicy env map[string]string @@ -81,6 +79,10 @@ type Option = func(c *HeadscaleInContainer) // HeadscaleInContainer instance. func WithACLPolicy(acl *policy.ACLPolicy) Option { return func(hsic *HeadscaleInContainer) { + if acl == nil { + return + } + // TODO(kradalby): Move somewhere appropriate hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath @@ -88,18 +90,29 @@ func WithACLPolicy(acl *policy.ACLPolicy) Option { } } +// WithCACert adds the given certificate to the trusted certificates of the container. +func WithCACert(cert []byte) Option { + return func(hsic *HeadscaleInContainer) { + hsic.caCerts = append(hsic.caCerts, cert) + } +} + // WithTLS creates certificates and enables HTTPS. func WithTLS() Option { return func(hsic *HeadscaleInContainer) { - cert, key, err := createCertificate(hsic.hostname) + cert, key, err := integrationutil.CreateCertificate(hsic.hostname) if err != nil { log.Fatalf("failed to create certificates for headscale test: %s", err) } - // TODO(kradalby): Move somewhere appropriate - hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath - hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath + hsic.tlsCert = cert + hsic.tlsKey = key + } +} +// WithCustomTLS uses the given certificates for the Headscale instance. +func WithCustomTLS(cert, key []byte) Option { + return func(hsic *HeadscaleInContainer) { hsic.tlsCert = cert hsic.tlsKey = key } @@ -146,6 +159,13 @@ func WithTestName(testName string) Option { } } +// WithHostname sets the hostname of the Headscale instance. +func WithHostname(hostname string) Option { + return func(hsic *HeadscaleInContainer) { + hsic.hostname = hostname + } +} + // WithHostnameAsServerURL sets the Headscale ServerURL based on // the Hostname. func WithHostnameAsServerURL() Option { @@ -203,6 +223,34 @@ func WithEmbeddedDERPServerOnly() Option { } } +// WithDERPConfig configures Headscale to use a custom +// DERP server only. +func WithDERPConfig(derpMap tailcfg.DERPMap) Option { + return func(hsic *HeadscaleInContainer) { + contents, err := yaml.Marshal(derpMap) + if err != nil { + log.Fatalf("failed to marshal DERP map: %s", err) + + return + } + + hsic.env["HEADSCALE_DERP_PATHS"] = "/etc/headscale/derp.yml" + hsic.filesInContainer = append(hsic.filesInContainer, + fileInContainer{ + path: "/etc/headscale/derp.yml", + contents: contents, + }) + + // Disable global DERP server and embedded DERP server + hsic.env["HEADSCALE_DERP_URLS"] = "" + hsic.env["HEADSCALE_DERP_SERVER_ENABLED"] = "false" + + // Envknob for enabling DERP debug logs + hsic.env["DERP_DEBUG_LOGS"] = "true" + hsic.env["DERP_PROBER_DEBUG_LOGS"] = "true" + } +} + // WithTuning allows changing the tuning settings easily.
func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option { return func(hsic *HeadscaleInContainer) { @@ -267,7 +315,7 @@ func New( } headscaleBuildOptions := &dockertest.BuildOptions{ - Dockerfile: "Dockerfile.debug", + Dockerfile: IntegrationTestDockerFileName, ContextDir: dockerContextPath, } @@ -307,6 +355,10 @@ func New( "HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS=1", "HEADSCALE_DEBUG_DUMP_CONFIG=1", } + if hsic.hasTLS() { + hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath + hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath + } for key, value := range hsic.env { env = append(env, fmt.Sprintf("%s=%s", key, value)) } @@ -320,7 +372,7 @@ func New( // Cmd: []string{"headscale", "serve"}, // TODO(kradalby): Get rid of this hack, we currently need to give us some // to inject the headscale configuration further down. - Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; headscale serve ; /bin/sleep 30"}, + Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; update-ca-certificates ; headscale serve ; /bin/sleep 30"}, Env: env, } @@ -358,6 +410,14 @@ func New( hsic.container = container + // Write the CA certificates to the container + for i, cert := range hsic.caCerts { + err = hsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert) + if err != nil { + return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) + } + } + err = hsic.WriteFile("/etc/headscale/config.yaml", []byte(MinimumConfigYAML())) if err != nil { return nil, fmt.Errorf("failed to write headscale config to container: %w", err) @@ -761,86 +821,3 @@ func (t *HeadscaleInContainer) SendInterrupt() error { return nil } - -// nolint -func createCertificate(hostname string) ([]byte, []byte, error) { - // From: - // https://shaneutt.com/blog/golang-ca-and-signed-cert-go/ - - ca := &x509.Certificate{ - SerialNumber: big.NewInt(2019), - Subject: pkix.Name{ - Organization: []string{"Headscale testing INC"}, - Country: []string{"NL"}, - Locality: []string{"Leiden"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(60 * time.Hour), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageClientAuth, - x509.ExtKeyUsageServerAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - - caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, err - } - - cert := &x509.Certificate{ - SerialNumber: big.NewInt(1658), - Subject: pkix.Name{ - CommonName: hostname, - Organization: []string{"Headscale testing INC"}, - Country: []string{"NL"}, - Locality: []string{"Leiden"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(60 * time.Minute), - SubjectKeyId: []byte{1, 2, 3, 4, 6}, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, - DNSNames: []string{hostname}, - } - - certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, err - } - - certBytes, err := x509.CreateCertificate( - rand.Reader, - cert, - ca, - &certPrivKey.PublicKey, - caPrivKey, - ) - if err != nil { - return nil, nil, err - } - - certPEM := new(bytes.Buffer) - - err = pem.Encode(certPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }) - if err != nil { - return nil, nil, err - } - - certPrivKeyPEM := new(bytes.Buffer) - - err = pem.Encode(certPrivKeyPEM, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), - }) - if err != nil { - return nil, nil, err - 
} - - return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil -} diff --git a/integration/integrationutil/util.go b/integration/integrationutil/util.go index 59eeeb17b4..7b9b63b593 100644 --- a/integration/integrationutil/util.go +++ b/integration/integrationutil/util.go @@ -3,9 +3,16 @@ package integrationutil import ( "archive/tar" "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "fmt" "io" + "math/big" "path/filepath" + "time" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/ory/dockertest/v3" @@ -93,3 +100,86 @@ func FetchPathFromContainer( return buf.Bytes(), nil } + +// nolint +func CreateCertificate(hostname string) ([]byte, []byte, error) { + // From: + // https://shaneutt.com/blog/golang-ca-and-signed-cert-go/ + + ca := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{ + Organization: []string{"Headscale testing INC"}, + Country: []string{"NL"}, + Locality: []string{"Leiden"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(60 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, err + } + + cert := &x509.Certificate{ + SerialNumber: big.NewInt(1658), + Subject: pkix.Name{ + CommonName: hostname, + Organization: []string{"Headscale testing INC"}, + Country: []string{"NL"}, + Locality: []string{"Leiden"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(60 * time.Minute), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + DNSNames: []string{hostname}, + } + + certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, err + } + + certBytes, err := x509.CreateCertificate( + rand.Reader, + cert, + ca, + &certPrivKey.PublicKey, + caPrivKey, + ) + if err != nil { + return nil, nil, err + } + + certPEM := new(bytes.Buffer) + + err = pem.Encode(certPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + if err != nil { + return nil, nil, err + } + + certPrivKeyPEM := new(bytes.Buffer) + + err = pem.Encode(certPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), + }) + if err != nil { + return nil, nil, err + } + + return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil +} diff --git a/integration/route_test.go b/integration/route_test.go index f163fa14dd..644cc992fa 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -92,9 +92,9 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, routes, 3) for _, route := range routes { - assert.Equal(t, true, route.GetAdvertised()) - assert.Equal(t, false, route.GetEnabled()) - assert.Equal(t, false, route.GetIsPrimary()) + assert.True(t, route.GetAdvertised()) + assert.False(t, route.GetEnabled()) + assert.False(t, route.GetIsPrimary()) } // Verify that no routes has been sent to the client, @@ -139,9 +139,9 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, enablingRoutes, 3) for _, route := range enablingRoutes { - assert.Equal(t, true, route.GetAdvertised()) - assert.Equal(t, true, route.GetEnabled()) - assert.Equal(t, true, route.GetIsPrimary()) + assert.True(t, route.GetAdvertised()) + assert.True(t, route.GetEnabled()) + assert.True(t, 
route.GetIsPrimary()) } time.Sleep(5 * time.Second) @@ -212,18 +212,18 @@ func TestEnablingRoutes(t *testing.T) { assertNoErr(t, err) for _, route := range disablingRoutes { - assert.Equal(t, true, route.GetAdvertised()) + assert.True(t, route.GetAdvertised()) if route.GetId() == routeToBeDisabled.GetId() { - assert.Equal(t, false, route.GetEnabled()) + assert.False(t, route.GetEnabled()) // since this is the only route of this cidr, // it will not failover, and remain Primary // until something can replace it. - assert.Equal(t, true, route.GetIsPrimary()) + assert.True(t, route.GetIsPrimary()) } else { - assert.Equal(t, true, route.GetEnabled()) - assert.Equal(t, true, route.GetIsPrimary()) + assert.True(t, route.GetEnabled()) + assert.True(t, route.GetIsPrimary()) } } @@ -342,9 +342,9 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf("initial routes %#v", routes) for _, route := range routes { - assert.Equal(t, true, route.GetAdvertised()) - assert.Equal(t, false, route.GetEnabled()) - assert.Equal(t, false, route.GetIsPrimary()) + assert.True(t, route.GetAdvertised()) + assert.False(t, route.GetEnabled()) + assert.False(t, route.GetIsPrimary()) } // Verify that no routes has been sent to the client, @@ -391,14 +391,14 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.Len(t, enablingRoutes, 2) // Node 1 is primary - assert.Equal(t, true, enablingRoutes[0].GetAdvertised()) - assert.Equal(t, true, enablingRoutes[0].GetEnabled()) - assert.Equal(t, true, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary") + assert.True(t, enablingRoutes[0].GetAdvertised()) + assert.True(t, enablingRoutes[0].GetEnabled()) + assert.True(t, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary") // Node 2 is not primary - assert.Equal(t, true, enablingRoutes[1].GetAdvertised()) - assert.Equal(t, true, enablingRoutes[1].GetEnabled()) - assert.Equal(t, false, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary") + assert.True(t, enablingRoutes[1].GetAdvertised()) + assert.True(t, enablingRoutes[1].GetEnabled()) + assert.False(t, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary") // Verify that the client has routes from the primary machine srs1, err := subRouter1.Status() @@ -446,14 +446,14 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.Len(t, routesAfterMove, 2) // Node 1 is not primary - assert.Equal(t, true, routesAfterMove[0].GetAdvertised()) - assert.Equal(t, true, routesAfterMove[0].GetEnabled()) - assert.Equal(t, false, routesAfterMove[0].GetIsPrimary(), "r1 is down, expected r2 to be primary") + assert.True(t, routesAfterMove[0].GetAdvertised()) + assert.True(t, routesAfterMove[0].GetEnabled()) + assert.False(t, routesAfterMove[0].GetIsPrimary(), "r1 is down, expected r2 to be primary") // Node 2 is primary - assert.Equal(t, true, routesAfterMove[1].GetAdvertised()) - assert.Equal(t, true, routesAfterMove[1].GetEnabled()) - assert.Equal(t, true, routesAfterMove[1].GetIsPrimary(), "r1 is down, expected r2 to be primary") + assert.True(t, routesAfterMove[1].GetAdvertised()) + assert.True(t, routesAfterMove[1].GetEnabled()) + assert.True(t, routesAfterMove[1].GetIsPrimary(), "r1 is down, expected r2 to be primary") srs2, err = subRouter2.Status() @@ -501,16 +501,16 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.Len(t, routesAfterBothDown, 2) // Node 1 is not primary - assert.Equal(t, true, 
routesAfterBothDown[0].GetAdvertised()) - assert.Equal(t, true, routesAfterBothDown[0].GetEnabled()) - assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary") + assert.True(t, routesAfterBothDown[0].GetAdvertised()) + assert.True(t, routesAfterBothDown[0].GetEnabled()) + assert.False(t, routesAfterBothDown[0].GetIsPrimary(), "r1 and r2 are down, expected r2 to _still_ be primary") // Node 2 is primary // if the node goes down, but no other suitable route is // available, keep the last known good route. - assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised()) - assert.Equal(t, true, routesAfterBothDown[1].GetEnabled()) - assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary") + assert.True(t, routesAfterBothDown[1].GetAdvertised()) + assert.True(t, routesAfterBothDown[1].GetEnabled()) + assert.True(t, routesAfterBothDown[1].GetIsPrimary(), "r1 and r2 are down, expected r2 to _still_ be primary") // TODO(kradalby): Check client status // Both are expected to be down @@ -560,14 +560,14 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.Len(t, routesAfter1Up, 2) // Node 1 is primary - assert.Equal(t, true, routesAfter1Up[0].GetAdvertised()) - assert.Equal(t, true, routesAfter1Up[0].GetEnabled()) - assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary(), "r1 is back up, expected r1 to become be primary") + assert.True(t, routesAfter1Up[0].GetAdvertised()) + assert.True(t, routesAfter1Up[0].GetEnabled()) + assert.True(t, routesAfter1Up[0].GetIsPrimary(), "r1 is back up, expected r1 to become primary") // Node 2 is not primary - assert.Equal(t, true, routesAfter1Up[1].GetAdvertised()) - assert.Equal(t, true, routesAfter1Up[1].GetEnabled()) - assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary(), "r1 is back up, expected r1 to become be primary") + assert.True(t, routesAfter1Up[1].GetAdvertised()) + assert.True(t, routesAfter1Up[1].GetEnabled()) + assert.False(t, routesAfter1Up[1].GetIsPrimary(), "r1 is back up, expected r1 to become primary") // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -614,14 +614,14 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.Len(t, routesAfter2Up, 2) // Node 1 is not primary - assert.Equal(t, true, routesAfter2Up[0].GetAdvertised()) - assert.Equal(t, true, routesAfter2Up[0].GetEnabled()) - assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary") + assert.True(t, routesAfter2Up[0].GetAdvertised()) + assert.True(t, routesAfter2Up[0].GetEnabled()) + assert.True(t, routesAfter2Up[0].GetIsPrimary(), "r1 and r2 are back up, expected r1 to _still_ be primary") // Node 2 is primary - assert.Equal(t, true, routesAfter2Up[1].GetAdvertised()) - assert.Equal(t, true, routesAfter2Up[1].GetEnabled()) - assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary") + assert.True(t, routesAfter2Up[1].GetAdvertised()) + assert.True(t, routesAfter2Up[1].GetEnabled()) + assert.False(t, routesAfter2Up[1].GetIsPrimary(), "r1 and r2 are back up, expected r1 to _still_ be primary") // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -677,14 +677,14 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf("routes after disabling r1 %#v", routesAfterDisabling1) // Node 1 is not primary - assert.Equal(t, true,
routesAfterDisabling1[0].GetAdvertised()) - assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled()) - assert.Equal(t, false, routesAfterDisabling1[0].GetIsPrimary()) + assert.True(t, routesAfterDisabling1[0].GetAdvertised()) + assert.False(t, routesAfterDisabling1[0].GetEnabled()) + assert.False(t, routesAfterDisabling1[0].GetIsPrimary()) // Node 2 is primary - assert.Equal(t, true, routesAfterDisabling1[1].GetAdvertised()) - assert.Equal(t, true, routesAfterDisabling1[1].GetEnabled()) - assert.Equal(t, true, routesAfterDisabling1[1].GetIsPrimary()) + assert.True(t, routesAfterDisabling1[1].GetAdvertised()) + assert.True(t, routesAfterDisabling1[1].GetEnabled()) + assert.True(t, routesAfterDisabling1[1].GetIsPrimary()) // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -735,14 +735,14 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.Len(t, routesAfterEnabling1, 2) // Node 1 is not primary - assert.Equal(t, true, routesAfterEnabling1[0].GetAdvertised()) - assert.Equal(t, true, routesAfterEnabling1[0].GetEnabled()) - assert.Equal(t, false, routesAfterEnabling1[0].GetIsPrimary()) + assert.True(t, routesAfterEnabling1[0].GetAdvertised()) + assert.True(t, routesAfterEnabling1[0].GetEnabled()) + assert.False(t, routesAfterEnabling1[0].GetIsPrimary()) // Node 2 is primary - assert.Equal(t, true, routesAfterEnabling1[1].GetAdvertised()) - assert.Equal(t, true, routesAfterEnabling1[1].GetEnabled()) - assert.Equal(t, true, routesAfterEnabling1[1].GetIsPrimary()) + assert.True(t, routesAfterEnabling1[1].GetAdvertised()) + assert.True(t, routesAfterEnabling1[1].GetEnabled()) + assert.True(t, routesAfterEnabling1[1].GetIsPrimary()) // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -795,9 +795,9 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf("routes after deleting r2 %#v", routesAfterDeleting2) // Node 1 is primary - assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised()) - assert.Equal(t, true, routesAfterDeleting2[0].GetEnabled()) - assert.Equal(t, true, routesAfterDeleting2[0].GetIsPrimary()) + assert.True(t, routesAfterDeleting2[0].GetAdvertised()) + assert.True(t, routesAfterDeleting2[0].GetEnabled()) + assert.True(t, routesAfterDeleting2[0].GetIsPrimary()) // Verify that the route is announced from subnet router 1 clientStatus, err = client.Status() @@ -893,9 +893,9 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) { assert.Len(t, routes, 1) // All routes should be auto approved and enabled - assert.Equal(t, true, routes[0].GetAdvertised()) - assert.Equal(t, true, routes[0].GetEnabled()) - assert.Equal(t, true, routes[0].GetIsPrimary()) + assert.True(t, routes[0].GetAdvertised()) + assert.True(t, routes[0].GetEnabled()) + assert.True(t, routes[0].GetIsPrimary()) // Stop advertising route command = []string{ @@ -924,9 +924,9 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) { assert.Len(t, notAdvertisedRoutes, 1) // Route is no longer advertised - assert.Equal(t, false, notAdvertisedRoutes[0].GetAdvertised()) - assert.Equal(t, false, notAdvertisedRoutes[0].GetEnabled()) - assert.Equal(t, true, notAdvertisedRoutes[0].GetIsPrimary()) + assert.False(t, notAdvertisedRoutes[0].GetAdvertised()) + assert.False(t, notAdvertisedRoutes[0].GetEnabled()) + assert.True(t, notAdvertisedRoutes[0].GetIsPrimary()) // Advertise route again command = []string{ @@ -955,9 +955,9 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) { assert.Len(t, reAdvertisedRoutes, 1) // All 
routes should be auto approved and enabled - assert.Equal(t, true, reAdvertisedRoutes[0].GetAdvertised()) - assert.Equal(t, true, reAdvertisedRoutes[0].GetEnabled()) - assert.Equal(t, true, reAdvertisedRoutes[0].GetIsPrimary()) + assert.True(t, reAdvertisedRoutes[0].GetAdvertised()) + assert.True(t, reAdvertisedRoutes[0].GetEnabled()) + assert.True(t, reAdvertisedRoutes[0].GetIsPrimary()) } func TestAutoApprovedSubRoute2068(t *testing.T) { @@ -1163,9 +1163,9 @@ func TestSubnetRouteACL(t *testing.T) { assert.Len(t, routes, 1) for _, route := range routes { - assert.Equal(t, true, route.GetAdvertised()) - assert.Equal(t, false, route.GetEnabled()) - assert.Equal(t, false, route.GetIsPrimary()) + assert.True(t, route.GetAdvertised()) + assert.False(t, route.GetEnabled()) + assert.False(t, route.GetIsPrimary()) } // Verify that no routes has been sent to the client, @@ -1212,9 +1212,9 @@ func TestSubnetRouteACL(t *testing.T) { assert.Len(t, enablingRoutes, 1) // Node 1 has active route - assert.Equal(t, true, enablingRoutes[0].GetAdvertised()) - assert.Equal(t, true, enablingRoutes[0].GetEnabled()) - assert.Equal(t, true, enablingRoutes[0].GetIsPrimary()) + assert.True(t, enablingRoutes[0].GetAdvertised()) + assert.True(t, enablingRoutes[0].GetEnabled()) + assert.True(t, enablingRoutes[0].GetIsPrimary()) // Verify that the client has routes from the primary machine srs1, _ := subRouter1.Status() diff --git a/integration/scenario.go b/integration/scenario.go index 65b36e504f..c40fd3665e 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -14,12 +14,14 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" + "github.com/juanfont/headscale/integration/dsic" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/ory/dockertest/v3" "github.com/puzpuzpuz/xsync/v3" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "tailscale.com/envknob" ) @@ -140,6 +142,7 @@ type Scenario struct { // TODO(kradalby): support multiple headcales for later, currently only // use one. controlServers *xsync.MapOf[string, ControlServer] + derpServers []*dsic.DERPServerInContainer users map[string]*User @@ -203,11 +206,11 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { if t != nil { stdout, err := os.ReadFile(stdoutPath) - assert.NoError(t, err) + require.NoError(t, err) assert.NotContains(t, string(stdout), "panic") stderr, err := os.ReadFile(stderrPath) - assert.NoError(t, err) + require.NoError(t, err) assert.NotContains(t, string(stderr), "panic") } @@ -224,6 +227,13 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { } } + for _, derp := range s.derpServers { + err := derp.Shutdown() + if err != nil { + log.Printf("failed to tear down derp server: %s", err) + } + } + if err := s.pool.RemoveNetwork(s.network); err != nil { log.Printf("failed to remove network: %s", err) } @@ -353,7 +363,7 @@ func (s *Scenario) CreateTailscaleNodesInUser( hostname := headscale.GetHostname() opts = append(opts, - tsic.WithHeadscaleTLS(cert), + tsic.WithCACert(cert), tsic.WithHeadscaleName(hostname), ) @@ -652,3 +662,20 @@ func (s *Scenario) WaitForTailscaleLogout() error { return nil } + +// CreateDERPServer creates a new DERP server in a container. 
diff --git a/integration/tailscale.go b/integration/tailscale.go
index b93cb2feb8..7f5f177e61 100644
--- a/integration/tailscale.go
+++ b/integration/tailscale.go
@@ -30,6 +30,7 @@ type TailscaleClient interface {
     FQDN() (string, error)
     Status(...bool) (*ipnstate.Status, error)
     Netmap() (*netmap.NetworkMap, error)
+    DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error)
     Netcheck() (*netcheck.Report, error)
     WaitForNeedsLogin() error
     WaitForNeedsApprove() error
diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go
index b82b44f430..b37dcb449e 100644
--- a/integration/tsic/tsic.go
+++ b/integration/tsic/tsic.go
@@ -12,6 +12,7 @@ import (
     "net/netip"
     "net/url"
     "os"
+    "reflect"
     "strconv"
     "strings"
     "time"
@@ -32,7 +33,7 @@ const (
     defaultPingTimeout   = 300 * time.Millisecond
     defaultPingCount     = 10
     dockerContextPath    = "../."
-    headscaleCertPath    = "/usr/local/share/ca-certificates/headscale.crt"
+    caCertRoot           = "/usr/local/share/ca-certificates"
     dockerExecuteTimeout = 60 * time.Second
 )
@@ -44,6 +45,11 @@ var (
     errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey")
     errTailscaleNotConnected           = errors.New("tailscale not connected")
     errTailscaledNotReadyForLogin      = errors.New("tailscaled not ready for login")
+    errInvalidClientConfig             = errors.New("verifiably invalid client config requested")
+)
+
+const (
+    VersionHead = "head"
 )
 
 func errTailscaleStatus(hostname string, err error) error {
@@ -65,7 +71,7 @@ type TailscaleInContainer struct {
     fqdn string
 
     // optional config
-    headscaleCert     []byte
+    caCerts           [][]byte
     headscaleHostname string
     withWebsocketDERP bool
     withSSH           bool
@@ -74,17 +80,23 @@ type TailscaleInContainer struct {
     withExtraHosts []string
     workdir        string
     netfilter      string
+
+    // build options, solely for HEAD
+    buildConfig TailscaleInContainerBuildConfig
+}
+
+type TailscaleInContainerBuildConfig struct {
+    tags []string
 }
 
 // Option represent optional settings that can be given to a
 // Tailscale instance.
 type Option = func(c *TailscaleInContainer)
 
-// WithHeadscaleTLS takes the certificate of the Headscale instance
-// and adds it to the trusted surtificate of the Tailscale container.
-func WithHeadscaleTLS(cert []byte) Option {
+// WithCACert adds the given CA certificate to the trusted
+// certificates of the Tailscale container.
+func WithCACert(cert []byte) Option {
     return func(tsic *TailscaleInContainer) {
-        tsic.headscaleCert = cert
+        tsic.caCerts = append(tsic.caCerts, cert)
     }
 }
@@ -113,7 +125,7 @@ func WithOrCreateNetwork(network *dockertest.Network) Option {
     }
 }
 
 // WithHeadscaleName set the name of the headscale instance,
-// mostly useful in combination with TLS and WithHeadscaleTLS.
+// mostly useful in combination with TLS and WithCACert.
 func WithHeadscaleName(hsName string) Option {
     return func(tsic *TailscaleInContainer) {
         tsic.headscaleHostname = hsName
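The `WithHeadscaleTLS` to `WithCACert` rename above is more than cosmetic: the option now appends to a slice, so one client container can trust several CAs at once, and `New` writes each certificate out as `user-<i>.crt` under `/usr/local/share/ca-certificates` (see the hunk further down). A hedged caller sketch, assuming `New` takes `(pool, version, network, opts...)` in the same order as the `dsic.New` call earlier in this diff; the certificate bytes come from elsewhere:

```go
package example

import (
	"github.com/ory/dockertest/v3"

	"github.com/juanfont/headscale/integration/tsic"
)

// newClientTrustingCAs builds a Tailscale container that trusts every given
// CA; each WithCACert call appends to the container's caCerts slice rather
// than replacing the previous certificate.
func newClientTrustingCAs(pool *dockertest.Pool, network *dockertest.Network, certs [][]byte) (*tsic.TailscaleInContainer, error) {
	opts := make([]tsic.Option, 0, len(certs))
	for _, cert := range certs {
		opts = append(opts, tsic.WithCACert(cert))
	}

	// "head" is a placeholder version; the parameter order is an assumption.
	return tsic.New(pool, "head", network, opts...)
}
```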
@@ -175,6 +187,22 @@ func WithNetfilter(state string) Option {
     }
 }
 
+// WithBuildTag adds an additional value to the `-tags=` parameter
+// of the Go compiler, allowing callers to customize the Tailscale client build.
+// This option is only meaningful when invoked on **HEAD** versions of the client.
+// Attempting to use it with any other version is a bug in the calling code.
+func WithBuildTag(tag string) Option {
+    return func(tsic *TailscaleInContainer) {
+        if tsic.version != VersionHead {
+            panic(errInvalidClientConfig)
+        }
+
+        tsic.buildConfig.tags = append(
+            tsic.buildConfig.tags, tag,
+        )
+    }
+}
+
 // New returns a new TailscaleInContainer instance.
 func New(
     pool *dockertest.Pool,
@@ -219,18 +247,20 @@
     }
 
     if tsic.withWebsocketDERP {
+        if version != VersionHead {
+            return tsic, errInvalidClientConfig
+        }
+
+        WithBuildTag("ts_debug_websockets")(tsic)
+
         tailscaleOptions.Env = append(
             tailscaleOptions.Env,
             fmt.Sprintf("TS_DEBUG_DERP_WS_CLIENT=%t", tsic.withWebsocketDERP),
         )
     }
 
-    if tsic.headscaleHostname != "" {
-        tailscaleOptions.ExtraHosts = []string{
-            "host.docker.internal:host-gateway",
-            fmt.Sprintf("%s:host-gateway", tsic.headscaleHostname),
-        }
-    }
+    tailscaleOptions.ExtraHosts = append(tailscaleOptions.ExtraHosts,
+        "host.docker.internal:host-gateway")
 
     if tsic.workdir != "" {
         tailscaleOptions.WorkingDir = tsic.workdir
@@ -245,14 +275,36 @@
     }
 
     var container *dockertest.Resource
+
+    if version != VersionHead {
+        // build options are not meaningful with pre-existing images,
+        // let's not lead anyone astray by pretending otherwise.
+        defaultBuildConfig := TailscaleInContainerBuildConfig{}
+        hasBuildConfig := !reflect.DeepEqual(defaultBuildConfig, tsic.buildConfig)
+        if hasBuildConfig {
+            return tsic, errInvalidClientConfig
+        }
+    }
+
     switch version {
-    case "head":
+    case VersionHead:
         buildOptions := &dockertest.BuildOptions{
             Dockerfile: "Dockerfile.tailscale-HEAD",
             ContextDir: dockerContextPath,
             BuildArgs:  []docker.BuildArg{},
         }
 
+        buildTags := strings.Join(tsic.buildConfig.tags, ",")
+        if len(buildTags) > 0 {
+            buildOptions.BuildArgs = append(
+                buildOptions.BuildArgs,
+                docker.BuildArg{
+                    Name:  "BUILD_TAGS",
+                    Value: buildTags,
+                },
+            )
+        }
+
         container, err = pool.BuildAndRunWithBuildOptions(
             buildOptions,
             tailscaleOptions,
@@ -294,8 +346,8 @@
 
     tsic.container = container
 
-    if tsic.hasTLS() {
-        err = tsic.WriteFile(headscaleCertPath, tsic.headscaleCert)
+    for i, cert := range tsic.caCerts {
+        err = tsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
         if err != nil {
             return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
         }
@@ -304,10 +356,6 @@
     return tsic, nil
 }
 
-func (t *TailscaleInContainer) hasTLS() bool {
-    return len(t.headscaleCert) != 0
-}
-
 // Shutdown stops and cleans up the Tailscale container.
 func (t *TailscaleInContainer) Shutdown() error {
     err := t.SaveLog("/tmp/control")
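The build-tag plumbing added to `New` above reduces to joining the accumulated tags and handing them to Docker as one build argument. A standalone sketch of that transformation (the helper name is hypothetical; `strings.Join`, the `BUILD_TAGS` argument name, and the `docker.BuildArg` shape are taken from the diff):

```go
package example

import (
	"strings"

	"github.com/ory/dockertest/v3/docker"
)

// buildTagArgs mirrors the logic in New: an empty tag list adds no build
// args; otherwise the tags become a single comma-separated BUILD_TAGS value,
// which Dockerfile.tailscale-HEAD is expected to forward to `go build -tags`.
func buildTagArgs(tags []string) []docker.BuildArg {
	joined := strings.Join(tags, ",")
	if joined == "" {
		return nil
	}

	return []docker.BuildArg{{Name: "BUILD_TAGS", Value: joined}}
}
```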
@@ -682,6 +730,34 @@ func (t *TailscaleInContainer) watchIPN(ctx context.Context) (*ipn.Notify, error) {
     }
 }
 
+func (t *TailscaleInContainer) DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error) {
+    if !util.TailscaleVersionNewerOrEqual("1.34", t.version) {
+        panic("tsic.DebugDERPRegion() called with unsupported version: " + t.version)
+    }
+
+    command := []string{
+        "tailscale",
+        "debug",
+        "derp",
+        region,
+    }
+
+    result, stderr, err := t.Execute(command)
+    if err != nil {
+        fmt.Printf("stderr: %s\n", stderr) // nolint
+
+        return nil, fmt.Errorf("failed to execute tailscale debug derp command: %w", err)
+    }
+
+    var report ipnstate.DebugDERPRegionReport
+    err = json.Unmarshal([]byte(result), &report)
+    if err != nil {
+        return nil, fmt.Errorf("failed to unmarshal tailscale derp region report: %w", err)
+    }
+
+    return &report, nil
+}
+
 // Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
 func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
     command := []string{
diff --git a/mkdocs.yml b/mkdocs.yml
index 5ef54018bd..086c6c12e5 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -1,5 +1,6 @@
+---
 site_name: Headscale
-site_url: https://juanfont.github.io/headscale
+site_url: https://juanfont.github.io/headscale/
 edit_uri: blob/main/docs/ # Change the master branch to main as we are using main as a main branch
 site_author: Headscale authors
 site_description: >-
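The new `DebugDERPRegion` wrapper above shells out to `tailscale debug derp <region>` and decodes its JSON output into an `ipnstate.DebugDERPRegionReport`. A hedged sketch of how an integration test might consume it (the helper and region value are illustrative; the report's `Errors` field is assumed from the upstream `ipnstate` package):

```go
package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/juanfont/headscale/integration"
)

// assertDERPHealthy queries the client's view of a DERP region and expects
// a clean report. Note that the tsic implementation panics for client
// versions older than 1.34, which lack `tailscale debug derp`.
func assertDERPHealthy(t *testing.T, client integration.TailscaleClient, region string) {
	t.Helper()

	report, err := client.DebugDERPRegion(region)
	require.NoError(t, err)

	// Errors is assumed to list per-region diagnostic failures.
	assert.Empty(t, report.Errors)
}
```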