diff --git a/api_common.go b/api_common.go index 7905c29dd1..d6190d08d4 100644 --- a/api_common.go +++ b/api_common.go @@ -15,7 +15,7 @@ func (h *Headscale) generateMapResponse( Str("func", "generateMapResponse"). Str("machine", mapRequest.Hostinfo.Hostname). Msg("Creating Map response") - node, err := h.toNode(*machine, h.cfg.BaseDomain, h.cfg.DNSConfig) + node, err := h.toNode(*machine, h.cfg.BaseDomain, h.cfg.RemoveUserFromTaggedDNS, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). @@ -39,7 +39,7 @@ func (h *Headscale) generateMapResponse( profiles := h.getMapResponseUserProfiles(*machine, peers) - nodePeers, err := h.toNodes(peers, h.cfg.BaseDomain, h.cfg.DNSConfig) + nodePeers, err := h.toNodes(peers, h.cfg.BaseDomain, h.cfg.RemoveUserFromTaggedDNS, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index c7b332aac0..569398d414 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -140,12 +140,13 @@ func (*Suite) TestDNSConfigLoading(c *check.C) { err = headscale.LoadConfig(tmpDir, false) c.Assert(err, check.IsNil) - dnsConfig, baseDomain := headscale.GetDNSConfig() + dnsConfig, baseDomain, removeUserFromTaggedDNS := headscale.GetDNSConfig() c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1") c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1") c.Assert(dnsConfig.Proxied, check.Equals, true) c.Assert(baseDomain, check.Equals, "example.com") + c.Assert(removeUserFromTaggedDNS, check.Equals, false) } func writeConfig(c *check.C, tmpDir string, configYaml []byte) { diff --git a/config-example.yaml b/config-example.yaml index 93aa797ac1..c34fcc59eb 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -254,6 +254,11 @@ dns_config: # `hostname.user.base_domain` (e.g., _myhost.myuser.example.com_). base_domain: example.com + # Defines if a tagged node should ditch the user in MagicDNS FQDN. + # Would result in `.` instead of + # `..` + remove_user_from_tagged_dns: false + # Unix socket used for the CLI to connect without authentication # Note: for production you will want to set this to something like: unix_socket: /var/run/headscale/headscale.sock diff --git a/config.go b/config.go index c0dd1c9864..de921a90b0 100644 --- a/config.go +++ b/config.go @@ -48,6 +48,7 @@ type Config struct { PrivateKeyPath string NoisePrivateKeyPath string BaseDomain string + RemoveUserFromTaggedDNS bool Log LogConfig DisableUpdateCheck bool @@ -373,7 +374,7 @@ func GetLogConfig() LogConfig { } } -func GetDNSConfig() (*tailcfg.DNSConfig, string) { +func GetDNSConfig() (*tailcfg.DNSConfig, string, bool) { if viper.IsSet("dns_config") { dnsConfig := &tailcfg.DNSConfig{} @@ -484,10 +485,17 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) { baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled } - return dnsConfig, baseDomain + var removeUserFromTaggedDNS bool + if viper.IsSet("dns_config.remove_user_from_tagged_dns") { + removeUserFromTaggedDNS = viper.GetBool("dns_config.remove_user_from_tagged_dns") + } else { + removeUserFromTaggedDNS = false // does not really matter when MagicDNS is not enabled + } + + return dnsConfig, baseDomain, removeUserFromTaggedDNS } - return nil, "" + return nil, "", false } func GetHeadscaleConfig() (*Config, error) { @@ -502,7 +510,7 @@ func GetHeadscaleConfig() (*Config, error) { }, nil } - dnsConfig, baseDomain := GetDNSConfig() + dnsConfig, baseDomain, removeUserFromTaggedDNS := GetDNSConfig() derpConfig := GetDERPConfig() logConfig := GetLogTailConfig() randomizeClientPort := viper.GetBool("randomize_client_port") @@ -567,7 +575,8 @@ func GetHeadscaleConfig() (*Config, error) { NoisePrivateKeyPath: AbsolutePathFromConfigPath( viper.GetString("noise.private_key_path"), ), - BaseDomain: baseDomain, + BaseDomain: baseDomain, + RemoveUserFromTaggedDNS: removeUserFromTaggedDNS, DERP: derpConfig, diff --git a/machine.go b/machine.go index 1b70b1e207..14e5ef8bb2 100644 --- a/machine.go +++ b/machine.go @@ -673,12 +673,13 @@ func (machines MachinesP) String() string { func (h *Headscale) toNodes( machines Machines, baseDomain string, + removeUserFromTaggedDNS bool, dnsConfig *tailcfg.DNSConfig, ) ([]*tailcfg.Node, error) { nodes := make([]*tailcfg.Node, len(machines)) for index, machine := range machines { - node, err := h.toNode(machine, baseDomain, dnsConfig) + node, err := h.toNode(machine, baseDomain, removeUserFromTaggedDNS, dnsConfig) if err != nil { return nil, err } @@ -694,6 +695,7 @@ func (h *Headscale) toNodes( func (h *Headscale) toNode( machine Machine, baseDomain string, + removeUserFromTaggedDNS bool, dnsConfig *tailcfg.DNSConfig, ) (*tailcfg.Node, error) { var nodeKey key.NodePublic @@ -770,14 +772,26 @@ func (h *Headscale) toNode( keyExpiry = time.Time{} } + tags, _ := getTags(h.aclPolicy, machine, h.cfg.OIDC.StripEmaildomain) + tags = lo.Uniq(append(tags, machine.ForcedTags...)) + var hostname string if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS - hostname = fmt.Sprintf( - "%s.%s.%s", - machine.GivenName, - machine.User.Name, - baseDomain, - ) + + if len(tags) > 0 && removeUserFromTaggedDNS { + hostname = fmt.Sprintf( + "%s.%s", + machine.GivenName, + baseDomain, + ) + } else { + hostname = fmt.Sprintf( + "%s.%s.%s", + machine.GivenName, + machine.User.Name, + baseDomain, + ) + } if len(hostname) > maxHostnameLength { return nil, fmt.Errorf( "hostname %q is too long it cannot except 255 ASCII chars: %w", @@ -793,9 +807,6 @@ func (h *Headscale) toNode( online := machine.isOnline() - tags, _ := getTags(h.aclPolicy, machine, h.cfg.OIDC.StripEmaildomain) - tags = lo.Uniq(append(tags, machine.ForcedTags...)) - node := tailcfg.Node{ ID: tailcfg.NodeID(machine.ID), // this is the actual ID StableID: tailcfg.StableNodeID( diff --git a/routes_test.go b/routes_test.go index b67b3ee937..313f54a432 100644 --- a/routes_test.go +++ b/routes_test.go @@ -439,7 +439,7 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 3) - peer, err := app.toNode(machine1, "headscale.net", nil) + peer, err := app.toNode(machine1, "headscale.net", false, nil) c.Assert(err, check.IsNil) c.Assert(len(peer.AllowedIPs), check.Equals, 3)