diff --git a/.lefthook.toml b/.lefthook.toml index 16e8ad0f..71d75854 100644 --- a/.lefthook.toml +++ b/.lefthook.toml @@ -11,7 +11,7 @@ tags = "go,lint" [pre-commit.commands.test] glob = "*.go" -run = "go test ./..." +run = "go test ./... -race" tags = "go,test" [prepare-commit-msg.commands.gitmoji] diff --git a/README.md b/README.md index 3b100f18..6bef9b20 100644 --- a/README.md +++ b/README.md @@ -29,17 +29,17 @@ dc:org: #dn: dc=org dc:example: #dn: dc=example,dc=org ou:group: #dn: ou=group,dc=example,dc=org cn:owner: &test #dn: cn=admin,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1000 description: Organization owners memberUid: [alice] cn:dev: #dn: cn=dev,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1001 description: Organization developers memberUid: [bob, charlie] cn:qa: #dn: cn=qa,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1002 memberUid: [charlie, eve] cn:ok: #dn: cn=ok,ou=group,dc=example,dc=org @@ -51,7 +51,7 @@ dc:org: #dn: dc=org c:global: #dn: c=global,dc=example,dc=org ou:people: #dn: ou=people,c=global,dc=example,dc=org cn:alice: #dn: cn=alice,ou=people,c=global,dc=example,dc=org - objectclass: [posixAccount, UserMail] + objectClass: [posixAccount, UserMail] .#acl: - !!ldap/acl:allow-on dc=org # allow alice to request everything @@ -65,7 +65,7 @@ dc:org: #dn: dc=org usermail: alice@example.org cn:bob: #dn: cn=bob,ou=people,c=global,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount .#acl: - !!ldap/acl:allow-on ou=group,dc=example,dc=org # allow bob request only for user groups @@ -78,7 +78,7 @@ dc:org: #dn: dc=org c:fr: #dn: c=fr,dc=example,dc=org ou:people: #dn: ou=people,c=fr,dc=example,dc=org cn:charlie: #dn: cn=charlie,ou=people,c=fr,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount .#acl: - !!ldap/acl:allow-on ou=group,dc=example,dc=org # allow charlie request for all groups... - !!ldap/acl:deny-on cn=admin,ou=group,dc=example,dc=org # ...but to owner group @@ -92,7 +92,7 @@ dc:org: #dn: dc=org c:uk: #dn: c=uk,dc=example,dc=org ou:people: #dn: ou=people,c=fr,dc=example,dc=org cn:eve: #dn: cn=eve,ou=people,c=uk,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount #NOTE: eve can't make any LDAP request (no !!ldap/bind:password field) uid: eve homeDirectory: /home/eve diff --git a/go.mod b/go.mod index 24b9b324..8f25438b 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,9 @@ module github.com/xunleii/yaldap go 1.21 +// TODO: remove this once the following issue is resolved: https://github.com/jimlambrt/gldap/pull/58 +replace github.com/jimlambrt/gldap => github.com/xunleii/gldap v0.1.10 + require ( github.com/alecthomas/kong v0.8.1 github.com/go-asn1-ber/asn1-ber v1.5.5 @@ -13,6 +16,7 @@ require ( github.com/puzpuzpuz/xsync/v3 v3.0.2 github.com/stretchr/testify v1.8.4 golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611 + golang.org/x/sync v0.1.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index d0ce3e4f..22d34d20 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,8 @@ github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1F github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/xunleii/gldap v0.1.10 h1:/z/BGhttRhPgsNnYdxW/zaxsA1Rz71e8vtyXa1XZiG4= +github.com/xunleii/gldap v0.1.10/go.mod h1:wQXacI2If7+C8z/IaTIf6Sbb+tqgFoqzujN2AaGzyck= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -75,6 +77,7 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/ldap/auth/auth.go b/internal/ldap/auth/auth.go index 7d5f63f1..d68b0acd 100644 --- a/internal/ldap/auth/auth.go +++ b/internal/ldap/auth/auth.go @@ -1,7 +1,9 @@ package auth import ( + "context" "fmt" + "sync" "time" xsync "github.com/puzpuzpuz/xsync/v3" @@ -13,15 +15,17 @@ type ( // allowed to perform operations. Sessions struct { reg *xsync.MapOf[int, *Session] + ttl time.Duration } // Session represents a single LDAP authenticated connection. Session struct { refreshable bool - ttl time.Duration expireAt time.Time obj ldap.Object + + sync sync.RWMutex } // SessionOption customize a authnConn before its registration on authConns. @@ -31,23 +35,39 @@ type ( Error struct{ error } ) -const defaultTTL = 5 * time.Minute - // NewSessions returns a new AuthnConns instance. -func NewSessions() *Sessions { - return &Sessions{ +func NewSessions(ctx context.Context, ttl time.Duration) *Sessions { + sessions := &Sessions{ reg: xsync.NewMapOf[int, *Session](), + ttl: ttl, } + + // Run the GC every TTL/2 + go func() { + ticker := time.NewTicker(ttl / 2) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sessions.GC() + } + } + }() + + return sessions } // NewSession adds the given LDAP object the list of authenticated connections. func (sessions *Sessions) NewSession(id int, obj ldap.Object, opts ...SessionOption) error { - session := &Session{ttl: defaultTTL, obj: obj} + session := &Session{obj: obj} for _, opt := range opts { opt(session) } - session.expireAt = time.Now().Add(session.ttl) + session.expireAt = time.Now().Add(sessions.ttl) eobj := sessions.Session(id) if eobj != nil { @@ -58,6 +78,10 @@ func (sessions *Sessions) NewSession(id int, obj ldap.Object, opts ...SessionOpt return nil } +// Delete removes the given connection ID from the list of authenticated +// connections. +func (sessions *Sessions) Delete(id int) { sessions.reg.Delete(id) } + // Session returns the LDAP object if it is authenticated. Otherwise, if the // connection ID doesn't exist or as expired, it returns nil. // Furthermore, if the connection is expired, it is automatically removed @@ -72,24 +96,33 @@ func (sessions *Sessions) Session(id int) *Session { sessions.reg.Delete(id) return nil case session.refreshable: - session.expireAt = time.Now().Add(session.ttl) + session.sync.Lock() + session.expireAt = time.Now().Add(sessions.ttl) + session.sync.Unlock() } return session } +// GC removes all expired connections from the list of authenticated. +func (sessions Sessions) GC() { + sessions.reg.Range(func(key int, value *Session) bool { + value.sync.RLock() + defer value.sync.RUnlock() + if value.expireAt.Before(time.Now()) { + sessions.reg.Delete(key) + } + return true + }) +} + // Object returns the LDAP object associated with the given session. -func (session Session) Object() ldap.Object { +func (session *Session) Object() ldap.Object { return session.obj } -// AuthnRefreshable allows the given conn to have its expiration date increased +// WithRefreshable allows the given conn to have its expiration date increased // after each operation. -func AuthnRefreshable() SessionOption { +func WithRefreshable() SessionOption { return func(session *Session) { session.refreshable = true } } - -// AuthnTTL customizes the given conn TTL. -func AuthnTTL(ttl time.Duration) SessionOption { - return func(session *Session) { session.ttl = ttl } -} diff --git a/internal/ldap/auth/auth_test.go b/internal/ldap/auth/auth_test.go index dbc7c707..0d77915a 100644 --- a/internal/ldap/auth/auth_test.go +++ b/internal/ldap/auth/auth_test.go @@ -1,6 +1,7 @@ package auth import ( + "context" "sync" "testing" "time" @@ -20,7 +21,7 @@ func (o mockLDAPObject) Bind(string) bool { ret func (o mockLDAPObject) CanSearchOn(string) bool { return true } func TestSessions_NewSession(t *testing.T) { - sessions := NewSessions() + sessions := NewSessions(context.Background(), time.Millisecond) t.Run("AddSession", func(t *testing.T) { obj := &mockLDAPObject{} @@ -37,12 +38,6 @@ func TestSessions_NewSession(t *testing.T) { err := sessions.NewSession(0, &mockLDAPObject{}) assert.Error(t, err) }) - - _ = sessions.NewSession(1, &mockLDAPObject{}, AuthnTTL(0)) - t.Run("AddAlreadyExistingButExpiredSession", func(t *testing.T) { - err := sessions.NewSession(1, &mockLDAPObject{}) - assert.NoError(t, err) - }) } func TestSessions_NewSession_race(t *testing.T) { @@ -70,7 +65,7 @@ func TestSessions_NewSession_race(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sessions := NewSessions() + sessions := NewSessions(context.Background(), time.Millisecond) wg := sync.WaitGroup{} for i, objs := range tt.objs { @@ -91,8 +86,58 @@ func TestSessions_NewSession_race(t *testing.T) { } } +func TestSessions_Delete(t *testing.T) { + sessions := NewSessions(context.Background(), time.Millisecond) + _ = sessions.NewSession(0, &mockLDAPObject{}) + + t.Run("DeleteExistingSession", func(t *testing.T) { + sessions.Delete(0) + + _, exists := sessions.reg.Load(0) + assert.False(t, exists) + }) + + t.Run("DeleteNonExistingSession", func(t *testing.T) { + sessions.Delete(1) + + _, exists := sessions.reg.Load(1) + assert.False(t, exists) + }) +} + +func TestSessions_GC(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sessions := NewSessions(ctx, time.Millisecond) + + _ = sessions.NewSession(0, &mockLDAPObject{}) + t.Run("GCNonExpiredSession", func(t *testing.T) { + sessions.GC() + + _, exists := sessions.reg.Load(0) + assert.True(t, exists) + }) + + t.Run("GCExpiredSession", func(t *testing.T) { + assert.Eventually(t, func() bool { + sessions.GC() + + _, exists := sessions.reg.Load(0) + return !exists + }, time.Millisecond*5, time.Millisecond) + }) + + _ = sessions.NewSession(1, &mockLDAPObject{}) + t.Run("GCGoroutine", func(t *testing.T) { + assert.Eventually(t, func() bool { + _, exists := sessions.reg.Load(1) + return !exists + }, time.Millisecond*5, time.Millisecond) + }) +} + func TestSession_Session(t *testing.T) { - sessions := NewSessions() + sessions := NewSessions(context.Background(), time.Millisecond) obj := &mockLDAPObject{} _ = sessions.NewSession(0, obj) @@ -108,34 +153,37 @@ func TestSession_Session(t *testing.T) { assert.Nil(t, session) }) - _ = sessions.NewSession(1, &mockLDAPObject{}, AuthnTTL(0)) + _ = sessions.NewSession(1, &mockLDAPObject{}) t.Run("GetExpiredSession", func(t *testing.T) { - session := sessions.Session(1) - assert.Nil(t, session) + assert.Eventually(t, func() bool { + session := sessions.Session(1) + return assert.Nil(t, session) + }, time.Millisecond*5, time.Millisecond) }) - _ = sessions.NewSession(2, &mockLDAPObject{}, AuthnRefreshable()) + _ = sessions.NewSession(2, &mockLDAPObject{}) + _ = sessions.NewSession(3, &mockLDAPObject{}, WithRefreshable()) t.Run("GetRefreshableSession", func(t *testing.T) { - session, exists := sessions.reg.Load(0) + session, exists := sessions.reg.Load(2) require.True(t, exists) before := session.expireAt - sessions.Session(0) + sessions.Session(2) after := session.expireAt assert.False(t, after.After(before)) - session, exists = sessions.reg.Load(2) + session, exists = sessions.reg.Load(3) require.True(t, exists) before = session.expireAt - sessions.Session(2) + sessions.Session(3) after = session.expireAt assert.True(t, after.After(before)) }) } func TestSessions_Session_race(t *testing.T) { - sessions := NewSessions() + sessions := NewSessions(context.Background(), time.Millisecond) for i := 0; i < 5; i++ { _ = sessions.NewSession(i, &mockLDAPObject{}) } @@ -175,7 +223,7 @@ func TestSessions_Session_race(t *testing.T) { } func TestSessionRefreshable(t *testing.T) { - sessions := NewSessions() + sessions := NewSessions(context.Background(), time.Millisecond) _ = sessions.NewSession(0, &mockLDAPObject{}) session, exists := sessions.reg.Load(0) @@ -183,25 +231,9 @@ func TestSessionRefreshable(t *testing.T) { require.True(t, exists) require.False(t, session.refreshable) - _ = sessions.NewSession(1, &mockLDAPObject{}, AuthnRefreshable()) + _ = sessions.NewSession(1, &mockLDAPObject{}, WithRefreshable()) session, exists = sessions.reg.Load(1) require.True(t, exists) require.True(t, session.refreshable) } - -func TestSessionTTL(t *testing.T) { - sessions := NewSessions() - - _ = sessions.NewSession(0, &mockLDAPObject{}) - session, exists := sessions.reg.Load(0) - - require.True(t, exists) - require.Equal(t, defaultTTL, session.ttl) - - _ = sessions.NewSession(1, &mockLDAPObject{}, AuthnTTL(60*time.Second)) - session, exists = sessions.reg.Load(1) - - require.True(t, exists) - require.Equal(t, 60*time.Second, session.ttl) -} diff --git a/pkg/cmd/common.go b/pkg/cmd/common.go index f03f69d6..c373b13d 100644 --- a/pkg/cmd/common.go +++ b/pkg/cmd/common.go @@ -7,6 +7,7 @@ import ( "os" "github.com/alecthomas/kong" + "github.com/xunleii/yaldap/pkg/utils" ) type ( @@ -47,6 +48,8 @@ func (l *LogLevel) Decode(ctx *kong.DecodeContext) error { } switch level { + case "trace": + *l = LogLevel(utils.LevelTrace) case "debug": *l = LogLevel(slog.LevelDebug) case "info": @@ -56,7 +59,7 @@ func (l *LogLevel) Decode(ctx *kong.DecodeContext) error { case "error": *l = LogLevel(slog.LevelError) default: - return fmt.Errorf("invalid level '%s': only debug, info, warn and error are allowed", level) + return fmt.Errorf("invalid level '%s': only trace, debug, info, warn and error are allowed", level) } return nil } diff --git a/pkg/cmd/common_test.go b/pkg/cmd/common_test.go index 589af089..994019ad 100644 --- a/pkg/cmd/common_test.go +++ b/pkg/cmd/common_test.go @@ -8,6 +8,7 @@ import ( "github.com/alecthomas/kong" "github.com/stretchr/testify/assert" "github.com/xunleii/yaldap/pkg/cmd" + "github.com/xunleii/yaldap/pkg/utils" ) func TestLogger_Format(t *testing.T) { @@ -26,6 +27,7 @@ func TestLogger_Format(t *testing.T) { func TestLogger_Level(t *testing.T) { levels := map[string]slog.Level{ + "trace": utils.LevelTrace, "debug": slog.LevelDebug, "info": slog.LevelInfo, "warn": slog.LevelWarn, diff --git a/pkg/cmd/server.go b/pkg/cmd/server.go index 791b7c52..d7605414 100644 --- a/pkg/cmd/server.go +++ b/pkg/cmd/server.go @@ -1,23 +1,29 @@ package cmd import ( + "context" "crypto/tls" "crypto/x509" "encoding/pem" "fmt" + "os/signal" + "syscall" + "time" "github.com/alecthomas/kong" "github.com/jimlambrt/gldap" + "github.com/xunleii/yaldap/internal/ldap/auth" "github.com/xunleii/yaldap/pkg/ldap" "github.com/xunleii/yaldap/pkg/ldap/directory" yamldir "github.com/xunleii/yaldap/pkg/ldap/directory/yaml" "github.com/xunleii/yaldap/pkg/utils" + "golang.org/x/sync/errgroup" ) type Server struct { Base `embed:""` - AddrListen string `name:"addr-listen" help:"Address to listen on" default:":389"` + ListenAddr string `name:"listen-address" help:"Address to listen on" default:":389"` Backend struct { Name string `name:"name" help:"Backend which stores the data" enum:"yaml" required:"" placeholder:"BACKEND"` @@ -32,9 +38,11 @@ type Server struct { KeyFile string `name:"tls.key" help:"Path to the key file" optional:"" type:"filecontent" placeholder:"PATH"` } `embed:""` - Version bool `name:"version" help:"Print version information and exit"` + Version bool `name:"version" help:"Print version information and exit"` + SessionTTL time.Duration `name:"session-ttl" help:"Duration of a BIND session before it expires" default:"168h"` } +// Run starts the yaLDAP server using the configuration passed to the command. func (s Server) Run(_ *kong.Context) error { logger := s.Logger() @@ -55,12 +63,24 @@ func (s Server) Run(_ *kong.Context) error { return err } - err = server.Router(ldap.NewMux(logger, directory)) + ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + + err = server.Router(ldap.NewMux(logger, directory, auth.NewSessions(ctx, s.SessionTTL))) if err != nil { return err } - return server.Run(s.AddrListen, gldap.WithTLSConfig(tlsConfig)) + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + return server.Run(s.ListenAddr, gldap.WithTLSConfig(tlsConfig)) + }) + + // Graceful shutdown. + <-ctx.Done() + if err := server.Stop(); err != nil { + return err + } + return g.Wait() } func (s Server) NewDirectory() (directory.Directory, error) { diff --git a/pkg/cmd/server_test.go b/pkg/cmd/server_test.go index 95bdf575..4b27e73e 100644 --- a/pkg/cmd/server_test.go +++ b/pkg/cmd/server_test.go @@ -18,11 +18,12 @@ import ( func TestServer_Defaults(t *testing.T) { var actual, expected Server - expected.AddrListen = ":389" + expected.ListenAddr = ":389" expected.Base.Log.Format = "json" expected.Base.Log.Level = LogLevel(slog.LevelInfo) expected.Backend.Name = "yaml" expected.Backend.URL = "file://../ldap/directory/yaml/fixtures/basic.yaml" //nolint:goconst + expected.SessionTTL = 168 * time.Hour expected.TLS.Enable = false expected.TLS.MutualTLS = false @@ -32,10 +33,11 @@ func TestServer_Defaults(t *testing.T) { } func TestServer_YAML_Simple(t *testing.T) { - server := Server{AddrListen: fmt.Sprintf("localhost:%d", freePort(t))} + server := Server{ListenAddr: fmt.Sprintf("localhost:%d", freePort(t))} server.Base.Log.Format = "test" server.Backend.Name = "yaml" server.Backend.URL = "file://../ldap/directory/yaml/fixtures/basic.yaml" + server.SessionTTL = time.Hour go func() { assert.NoError(t, server.Run(nil)) }() @@ -43,7 +45,7 @@ func TestServer_YAML_Simple(t *testing.T) { require.Eventually(t, func() bool { var err error - client, err = ldap.DialURL(fmt.Sprintf("ldap://%s", server.AddrListen)) + client, err = ldap.DialURL(fmt.Sprintf("ldap://%s", server.ListenAddr)) return assert.NoError(t, err) }, 500*time.Millisecond, @@ -51,7 +53,7 @@ func TestServer_YAML_Simple(t *testing.T) { ) defer client.Close() - err := client.Bind("cn=alice,ou=people,c=global,dc=example,dc=org", "alice") + err := client.Bind("cn=alice,ou=people,c=fr,dc=example,dc=org", "alice") require.NoError(t, err) } @@ -60,10 +62,11 @@ func TestServer_YAML_WithTLS(t *testing.T) { cert, err := ca.NewKeyPair("localhost") require.NoError(t, err) - server := Server{AddrListen: fmt.Sprintf("localhost:%d", freePort(t))} + server := Server{ListenAddr: fmt.Sprintf("localhost:%d", freePort(t))} server.Base.Log.Format = "test" server.Backend.Name = "yaml" server.Backend.URL = "file://../ldap/directory/yaml/fixtures/basic.yaml" + server.SessionTTL = time.Hour server.TLS.Enable = true server.TLS.CAFile = string(ca.PublicKey()) server.TLS.CertFile = string(cert.PublicKey()) @@ -75,7 +78,7 @@ func TestServer_YAML_WithTLS(t *testing.T) { require.Eventually(t, func() bool { client, err = ldap.DialURL( - fmt.Sprintf("ldaps://%s", server.AddrListen), + fmt.Sprintf("ldaps://%s", server.ListenAddr), ldap.DialWithTLSConfig(&tls.Config{RootCAs: ca.CertPool()}), ) return assert.NoError(t, err) @@ -85,7 +88,7 @@ func TestServer_YAML_WithTLS(t *testing.T) { ) defer client.Close() - err = client.Bind("cn=alice,ou=people,c=global,dc=example,dc=org", "alice") + err = client.Bind("cn=alice,ou=people,c=fr,dc=example,dc=org", "alice") require.NoError(t, err) } @@ -94,10 +97,11 @@ func TestServer_YAML_WithMutualTLS(t *testing.T) { keypair, err := ca.NewKeyPair("localhost") require.NoError(t, err) - server := Server{AddrListen: fmt.Sprintf("localhost:%d", freePort(t))} + server := Server{ListenAddr: fmt.Sprintf("localhost:%d", freePort(t))} server.Base.Log.Format = "test" server.Backend.Name = "yaml" server.Backend.URL = "file://../ldap/directory/yaml/fixtures/basic.yaml" + server.SessionTTL = time.Hour server.TLS.Enable = true server.TLS.MutualTLS = true server.TLS.CAFile = string(ca.PublicKey()) @@ -117,7 +121,7 @@ func TestServer_YAML_WithMutualTLS(t *testing.T) { require.Eventually(t, func() bool { client, err = ldap.DialURL( - fmt.Sprintf("ldaps://%s", server.AddrListen), + fmt.Sprintf("ldaps://%s", server.ListenAddr), ldap.DialWithTLSConfig(&tls.Config{ RootCAs: ca.CertPool(), Certificates: []tls.Certificate{cert}, @@ -130,7 +134,7 @@ func TestServer_YAML_WithMutualTLS(t *testing.T) { ) defer client.Close() - err = client.Bind("cn=alice,ou=people,c=global,dc=example,dc=org", "alice") + err = client.Bind("cn=alice,ou=people,c=fr,dc=example,dc=org", "alice") require.NoError(t, err) } diff --git a/pkg/ldap/directory/common/types.go b/pkg/ldap/directory/common/types.go index 7ccf9821..b413d611 100644 --- a/pkg/ldap/directory/common/types.go +++ b/pkg/ldap/directory/common/types.go @@ -97,7 +97,7 @@ func (obj Object) CanSearchOn(dn string) bool { func (obj Object) search(scope gldap.Scope, filter *ber.Packet) (objects []ldap.Object, err error) { if match, err := filters.Match(&obj, filter); err != nil { return nil, err - } else if match { + } else if match && scope != gldap.SingleLevel { objects = append(objects, &obj) } diff --git a/pkg/ldap/directory/types.go b/pkg/ldap/directory/types.go index 3c18a219..83b87240 100644 --- a/pkg/ldap/directory/types.go +++ b/pkg/ldap/directory/types.go @@ -7,7 +7,9 @@ import ( type ( // Directory contains all current LDAP object tree, accessible using a base DN. Directory interface { - // BaseDN returns the LDAP object represented by the given DN. If no object found, it returns nil. + // BaseDN returns the LDAP object represented by the given DN. If no object found, + // it returns nil. + // If the given DN is empty, it returns the root object. BaseDN(dn string) Object } diff --git a/pkg/ldap/directory/yaml/README.md b/pkg/ldap/directory/yaml/README.md index 48d64810..0716823b 100644 --- a/pkg/ldap/directory/yaml/README.md +++ b/pkg/ldap/directory/yaml/README.md @@ -56,17 +56,17 @@ dc:org: #dn: dc=org dc:example: #dn: dc=example,dc=org ou:group: #dn: ou=group,dc=example,dc=org cn:owner: &test #dn: cn=admin,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1000 description: Organization owners memberUid: [alice] cn:dev: #dn: cn=dev,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1001 description: Organization developers memberUid: [bob, charlie] cn:qa: #dn: cn=qa,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1002 memberUid: [charlie, eve] cn:ok: #dn: cn=ok,ou=group,dc=example,dc=org @@ -78,7 +78,7 @@ dc:org: #dn: dc=org c:global: #dn: c=global,dc=example,dc=org ou:people: #dn: ou=people,c=global,dc=example,dc=org cn:alice: #dn: cn=alice,ou=people,c=global,dc=example,dc=org - objectclass: [posixAccount, UserMail] + objectClass: [posixAccount, UserMail] .#acl: - !!ldap/acl:allow-on dc=org # allow alice to request everything @@ -92,7 +92,7 @@ dc:org: #dn: dc=org usermail: alice@example.org cn:bob: #dn: cn=bob,ou=people,c=global,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount .#acl: - !!ldap/acl:allow-on ou=group,dc=example,dc=org # allow bob request only for user groups @@ -105,7 +105,7 @@ dc:org: #dn: dc=org c:fr: #dn: c=fr,dc=example,dc=org ou:people: #dn: ou=people,c=fr,dc=example,dc=org cn:charlie: #dn: cn=charlie,ou=people,c=fr,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount .#acl: - !!ldap/acl:allow-on ou=group,dc=example,dc=org # allow charlie request for all groups... - !!ldap/acl:deny-on cn=admin,ou=group,dc=example,dc=org # ...but to owner group @@ -119,7 +119,7 @@ dc:org: #dn: dc=org c:uk: #dn: c=uk,dc=example,dc=org ou:people: #dn: ou=people,c=fr,dc=example,dc=org cn:eve: #dn: cn=eve,ou=people,c=uk,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount #NOTE: eve can't make any LDAP request (no !!ldap/bind:password field) uid: eve homeDirectory: /home/eve diff --git a/pkg/ldap/directory/yaml/directory.go b/pkg/ldap/directory/yaml/directory.go index 8e445772..a7f84de2 100644 --- a/pkg/ldap/directory/yaml/directory.go +++ b/pkg/ldap/directory/yaml/directory.go @@ -33,6 +33,7 @@ func NewDirectoryFromYAML(raw []byte) (ldap.Directory, error) { directory := &directory{ entries: &common.Object{ ImplObject: common.ImplObject{ + Attributes: ldap.Attributes{"objectClass": {"top", "yaLDAPRootDSE"}}, SubObjects: map[string]*common.Object{}, }, }, @@ -70,9 +71,9 @@ func NewDirectoryFromYAML(raw []byte) (ldap.Directory, error) { switch value.Kind { case yaml.MappingNode: - err = parseLDAPObject(directory.entries, key.Value, value) + err = parseLDAPObject(directory.entries, key, value) case yaml.SequenceNode, yaml.ScalarNode: - err = parseLDAPAttribute(directory.entries, key.Value, value) + err = parseLDAPAttribute(directory.entries, key, value) } if err != nil { @@ -97,6 +98,10 @@ func indexDirectory(obj *common.Object, index map[string]*common.Object) { } func (d directory) BaseDN(dn string) ldap.Object { + if dn == "" { + return d.entries + } + obj, found := d.index[dn] if !found { return nil diff --git a/pkg/ldap/directory/yaml/directory_fixtures_test.go b/pkg/ldap/directory/yaml/directory_fixtures_test.go index 87282da7..a706b945 100644 --- a/pkg/ldap/directory/yaml/directory_fixtures_test.go +++ b/pkg/ldap/directory/yaml/directory_fixtures_test.go @@ -18,7 +18,8 @@ func TestFixture_Basic(t *testing.T) { require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ - "dc": {"org"}, + "dc": {"org"}, + "objectClass": {"top", "domain"}, }, obj.Attributes(), ) @@ -29,7 +30,8 @@ func TestFixture_Basic(t *testing.T) { require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ - "dc": {"example"}, + "dc": {"example"}, + "objectClass": {"domain"}, }, obj.Attributes(), ) @@ -40,7 +42,8 @@ func TestFixture_Basic(t *testing.T) { require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ - "ou": {"group"}, + "ou": {"group"}, + "objectClass": {"top"}, }, obj.Attributes(), ) @@ -52,7 +55,7 @@ func TestFixture_Basic(t *testing.T) { assert.Equal(t, ldap.Attributes{ "cn": {"owner"}, - "objectclass": {"posixGroup"}, + "objectClass": {"posixGroup"}, "gidNumber": {"1000"}, "description": {"Organization owners"}, "memberUid": {"alice"}, @@ -67,7 +70,7 @@ func TestFixture_Basic(t *testing.T) { assert.Equal(t, ldap.Attributes{ "cn": {"dev"}, - "objectclass": {"posixGroup"}, + "objectClass": {"posixGroup"}, "gidNumber": {"1001"}, "description": {"Organization developers"}, "memberUid": {"bob", "charlie"}, @@ -82,7 +85,7 @@ func TestFixture_Basic(t *testing.T) { assert.Equal(t, ldap.Attributes{ "cn": {"qa"}, - "objectclass": {"posixGroup"}, + "objectClass": {"posixGroup"}, "gidNumber": {"1002"}, "memberUid": {"charlie", "eve"}, }, @@ -96,7 +99,7 @@ func TestFixture_Basic(t *testing.T) { assert.Equal(t, ldap.Attributes{ "cn": {"ok"}, - "objectclass": {"posixGroup"}, + "objectClass": {"posixGroup"}, "gidNumber": {"1003"}, "description": {"Dummy group"}, "memberUid": {"alice"}, @@ -106,34 +109,36 @@ func TestFixture_Basic(t *testing.T) { }) t.Run("cn=admin,ou=group,dc=example,dc=org", func(t *testing.T) { - obj := directory.BaseDN("c=global,dc=example,dc=org") + obj := directory.BaseDN("c=fr,dc=example,dc=org") require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ - "c": {"global"}, + "c": {"fr"}, + "objectClass": {"top", "country"}, }, obj.Attributes(), ) }) - t.Run("ou=people,c=global,dc=example,dc=org", func(t *testing.T) { - obj := directory.BaseDN("ou=people,c=global,dc=example,dc=org") + t.Run("ou=people,c=fr,dc=example,dc=org", func(t *testing.T) { + obj := directory.BaseDN("ou=people,c=fr,dc=example,dc=org") require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ - "ou": {"people"}, + "ou": {"people"}, + "objectClass": {"top"}, }, obj.Attributes(), ) }) - t.Run("cn=alice,ou=people,c=global,dc=example,dc=org", func(t *testing.T) { - obj := directory.BaseDN("cn=alice,ou=people,c=global,dc=example,dc=org") + t.Run("cn=alice,ou=people,c=fr,dc=example,dc=org", func(t *testing.T) { + obj := directory.BaseDN("cn=alice,ou=people,c=fr,dc=example,dc=org") require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ "cn": {"alice"}, - "objectclass": {"posixAccount", "UserMail"}, + "objectClass": {"posixAccount"}, "description": {"Main organization admin"}, "uid": {"alice"}, "uidNumber": {"1000"}, @@ -149,13 +154,13 @@ func TestFixture_Basic(t *testing.T) { assert.True(t, obj.CanSearchOn("dc=org")) }) - t.Run("cn=bob,ou=people,c=global,dc=example,dc=org", func(t *testing.T) { - obj := directory.BaseDN("cn=bob,ou=people,c=global,dc=example,dc=org") + t.Run("cn=bob,ou=people,c=fr,dc=example,dc=org", func(t *testing.T) { + obj := directory.BaseDN("cn=bob,ou=people,c=fr,dc=example,dc=org") require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ "cn": {"bob"}, - "objectclass": {"posixAccount"}, + "objectClass": {"posixAccount"}, "uid": {"bob"}, "homeDirectory": {"/home/bob"}, "uidNumber": {"1001"}, @@ -169,12 +174,13 @@ func TestFixture_Basic(t *testing.T) { assert.True(t, obj.CanSearchOn("ou=group,dc=example,dc=org")) }) - t.Run("cn=charlie,ou=people,c=global,dc=example,dc=org", func(t *testing.T) { + t.Run("cn=charlie,ou=people,c=fr,dc=example,dc=org", func(t *testing.T) { obj := directory.BaseDN("c=fr,dc=example,dc=org") require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ - "c": {"fr"}, + "c": {"fr"}, + "objectClass": {"top", "country"}, }, obj.Attributes(), ) @@ -185,19 +191,20 @@ func TestFixture_Basic(t *testing.T) { require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ - "ou": {"people"}, + "ou": {"people"}, + "objectClass": {"top"}, }, obj.Attributes(), ) }) - t.Run("cn=charlie,ou=people,c=fr,dc=example,dc=org", func(t *testing.T) { - obj := directory.BaseDN("cn=charlie,ou=people,c=fr,dc=example,dc=org") + t.Run("cn=charlie,ou=people,c=de,dc=example,dc=org", func(t *testing.T) { + obj := directory.BaseDN("cn=charlie,ou=people,c=de,dc=example,dc=org") require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ "cn": {"charlie"}, - "objectclass": {"posixAccount"}, + "objectClass": {"posixAccount"}, "uid": {"charlie"}, "homeDirectory": {"/home/charlie"}, "uidNumber": {"1100"}, @@ -209,9 +216,7 @@ func TestFixture_Basic(t *testing.T) { assert.True(t, obj.Bind("charlie")) assert.False(t, obj.CanSearchOn("dc=org")) assert.True(t, obj.CanSearchOn("ou=group,dc=example,dc=org")) - x := obj.CanSearchOn("cn=admin,ou=group,dc=example,dc=org") - _ = x - assert.False(t, obj.CanSearchOn("cn=admin,ou=group,dc=example,dc=org")) + assert.False(t, obj.CanSearchOn("cn=owner,ou=group,dc=example,dc=org")) }) t.Run("c=uk,dc=example,dc=org", func(t *testing.T) { @@ -219,7 +224,8 @@ func TestFixture_Basic(t *testing.T) { require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ - "c": {"uk"}, + "c": {"uk"}, + "objectClass": {"top", "country"}, }, obj.Attributes(), ) @@ -230,7 +236,8 @@ func TestFixture_Basic(t *testing.T) { require.NotNil(t, obj) assert.Equal(t, ldap.Attributes{ - "ou": {"people"}, + "ou": {"people"}, + "objectClass": {"top"}, }, obj.Attributes(), ) @@ -242,7 +249,7 @@ func TestFixture_Basic(t *testing.T) { assert.Equal(t, ldap.Attributes{ "cn": {"eve"}, - "objectclass": {"posixAccount"}, + "objectClass": {"posixAccount"}, "uid": {"eve"}, "homeDirectory": {"/home/eve"}, "uidNumber": {"1003"}, diff --git a/pkg/ldap/directory/yaml/directory_test.go b/pkg/ldap/directory/yaml/directory_test.go index 12a61d6a..06e73496 100644 --- a/pkg/ldap/directory/yaml/directory_test.go +++ b/pkg/ldap/directory/yaml/directory_test.go @@ -25,19 +25,20 @@ ou:people: expected := &directory{ entries: &common.Object{ ImplObject: common.ImplObject{ + Attributes: ldap.Attributes{"objectClass": {"top", "yaLDAPRootDSE"}}, SubObjects: map[string]*common.Object{ "ou:people": { ImplObject: common.ImplObject{ DN: "ou=people", - Attributes: ldap.Attributes{"ou": []string{"people"}}, + Attributes: ldap.Attributes{"ou": {"people"}}, SubObjects: map[string]*common.Object{ "uid:alice": { ImplObject: common.ImplObject{ DN: "uid=alice,ou=people", Attributes: ldap.Attributes{ - "uid": []string{"alice"}, - "memberOf": []string{"admin", "user", "h4ck3r"}, - "givenname": []string{"alice"}, + "uid": {"alice"}, + "memberOf": {"admin", "user", "h4ck3r"}, + "givenname": {"alice"}, }, SubObjects: map[string]*common.Object{}, }, @@ -52,15 +53,15 @@ ou:people: "ou=people": { ImplObject: common.ImplObject{ DN: "ou=people", - Attributes: ldap.Attributes{"ou": []string{"people"}}, + Attributes: ldap.Attributes{"ou": {"people"}}, SubObjects: map[string]*common.Object{ "uid:alice": { ImplObject: common.ImplObject{ DN: "uid=alice,ou=people", Attributes: ldap.Attributes{ - "uid": []string{"alice"}, - "memberOf": []string{"admin", "user", "h4ck3r"}, - "givenname": []string{"alice"}, + "uid": {"alice"}, + "memberOf": {"admin", "user", "h4ck3r"}, + "givenname": {"alice"}, }, SubObjects: map[string]*common.Object{}, }, @@ -72,9 +73,9 @@ ou:people: ImplObject: common.ImplObject{ DN: "uid=alice,ou=people", Attributes: ldap.Attributes{ - "uid": []string{"alice"}, - "memberOf": []string{"admin", "user", "h4ck3r"}, - "givenname": []string{"alice"}, + "uid": {"alice"}, + "memberOf": {"admin", "user", "h4ck3r"}, + "givenname": {"alice"}, }, SubObjects: map[string]*common.Object{}, }, @@ -87,3 +88,81 @@ ou:people: assert.NoError(t, err) assert.Equal(t, expected, directory) } + +func TestDirectory_BaseDN(t *testing.T) { + raw := []byte(` +ou:people: + uid:alice: {} +`) + directory, err := NewDirectoryFromYAML(raw) + assert.NoError(t, err) + + t.Run("ou=people", func(t *testing.T) { + actual := directory.BaseDN("ou=people") + expected := &common.Object{ + ImplObject: common.ImplObject{ + DN: "ou=people", + Attributes: ldap.Attributes{"ou": {"people"}}, + SubObjects: map[string]*common.Object{ + "uid:alice": { + ImplObject: common.ImplObject{ + DN: "uid=alice,ou=people", + Attributes: ldap.Attributes{"uid": {"alice"}}, + SubObjects: map[string]*common.Object{}, + }, + }, + }, + }, + } + + assert.Equal(t, expected, actual) + }) + + t.Run("uid=alice,ou=people", func(t *testing.T) { + actual := directory.BaseDN("uid=alice,ou=people") + expected := &common.Object{ + ImplObject: common.ImplObject{ + DN: "uid=alice,ou=people", + Attributes: ldap.Attributes{"uid": {"alice"}}, + SubObjects: map[string]*common.Object{}, + }, + } + + assert.Equal(t, expected, actual) + }) + + t.Run("empty DN", func(t *testing.T) { + actual := directory.BaseDN("") + expected := &common.Object{ + ImplObject: common.ImplObject{ + DN: "", + Attributes: ldap.Attributes{"objectClass": {"top", "yaLDAPRootDSE"}}, + SubObjects: map[string]*common.Object{ + "ou:people": { + ImplObject: common.ImplObject{ + DN: "ou=people", + Attributes: ldap.Attributes{"ou": {"people"}}, + SubObjects: map[string]*common.Object{ + "uid:alice": { + ImplObject: common.ImplObject{ + DN: "uid=alice,ou=people", + Attributes: ldap.Attributes{"uid": {"alice"}}, + SubObjects: map[string]*common.Object{}, + }, + }, + }, + }, + }, + }, + }, + } + + assert.Equal(t, expected, actual) + }) + + t.Run("DN not found", func(t *testing.T) { + actual := directory.BaseDN("cn=does-not-exist") + + assert.Nil(t, actual) + }) +} diff --git a/pkg/ldap/directory/yaml/fixtures/basic.yaml b/pkg/ldap/directory/yaml/fixtures/basic.yaml index cd283d24..698182f2 100644 --- a/pkg/ldap/directory/yaml/fixtures/basic.yaml +++ b/pkg/ldap/directory/yaml/fixtures/basic.yaml @@ -1,18 +1,24 @@ dc:org: #dn: dc=org + objectClass: [top, domain] + dc:example: #dn: dc=example,dc=org + objectClass: [domain] + ou:group: #dn: ou=group,dc=example,dc=org - cn:owner: &test #dn: cn=admin,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: [top] + + cn:owner: &test #dn: cn=owner,ou=group,dc=example,dc=org + objectClass: posixGroup gidNumber: 1000 description: Organization owners memberUid: [alice] cn:dev: #dn: cn=dev,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1001 description: Organization developers memberUid: [bob, charlie] cn:qa: #dn: cn=qa,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1002 memberUid: [charlie, eve] cn:ok: #dn: cn=ok,ou=group,dc=example,dc=org @@ -21,10 +27,14 @@ dc:org: #dn: dc=org description: Dummy group # memberUid: [alice] - c:global: #dn: c=global,dc=example,dc=org - ou:people: #dn: ou=people,c=global,dc=example,dc=org - cn:alice: #dn: cn=alice,ou=people,c=global,dc=example,dc=org - objectclass: [posixAccount, UserMail] + c:fr: #dn: c=fr,dc=example,dc=org + objectClass: [top, country] + + ou:people: #dn: ou=people,c=fr,dc=example,dc=org + objectClass: [top] + + cn:alice: #dn: cn=alice,ou=people,c=fr,dc=example,dc=org + objectClass: posixAccount .#acl: - !!ldap/acl:allow-on dc=org # allow alice to request everything @@ -37,8 +47,8 @@ dc:org: #dn: dc=org userPassword: !!ldap/bind:password alice usermail: alice@example.org - cn:bob: #dn: cn=bob,ou=people,c=global,dc=example,dc=org - objectclass: posixAccount + cn:bob: #dn: cn=bob,ou=people,c=fr,dc=example,dc=org + objectClass: posixAccount .#acl: - !!ldap/acl:allow-on ou=group,dc=example,dc=org # allow bob request only for user groups @@ -48,13 +58,17 @@ dc:org: #dn: dc=org gidNumber: 1001 userPassword: !!ldap/bind:password bob - c:fr: #dn: c=fr,dc=example,dc=org - ou:people: #dn: ou=people,c=fr,dc=example,dc=org - cn:charlie: #dn: cn=charlie,ou=people,c=fr,dc=example,dc=org - objectclass: posixAccount + c:de: #dn: c=de,dc=example,dc=org + objectClass: [top, country] + + ou:people: #dn: ou=people,c=de,dc=example,dc=org + objectClass: [top] + + cn:charlie: #dn: cn=charlie,ou=people,c=de,dc=example,dc=org + objectClass: posixAccount .#acl: - !!ldap/acl:allow-on ou=group,dc=example,dc=org # allow charlie request for all groups... - - !!ldap/acl:deny-on cn=admin,ou=group,dc=example,dc=org # ...but to owner group + - !!ldap/acl:deny-on cn=owner,ou=group,dc=example,dc=org # ...but to owner group uid: charlie homeDirectory: /home/charlie @@ -63,12 +77,16 @@ dc:org: #dn: dc=org userPassword: !!ldap/bind:password charlie c:uk: #dn: c=uk,dc=example,dc=org + objectClass: [top, country] + ou:people: #dn: ou=people,c=fr,dc=example,dc=org + objectClass: [top] + cn:eve: #dn: cn=eve,ou=people,c=uk,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount #NOTE: eve can't make any LDAP request (no !!ldap/bind:password field) uid: eve homeDirectory: /home/eve uidNumber: 1003 gidNumber: 1002 - userPassword: eve \ No newline at end of file + userPassword: eve diff --git a/pkg/ldap/directory/yaml/fixtures/rfc.yaml b/pkg/ldap/directory/yaml/fixtures/rfc.yaml index fdacb32f..9ae2ce35 100644 --- a/pkg/ldap/directory/yaml/fixtures/rfc.yaml +++ b/pkg/ldap/directory/yaml/fixtures/rfc.yaml @@ -2,24 +2,24 @@ dc:org: #dn: dc=org dc:example: #dn: dc=example,dc=org ou:group: #dn: ou=group,dc=example,dc=org cn:owner: #dn: cn=admin,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1000 description: Organization owners memberUid: {{{ search '(gidNumber=1000)' | attribute uidNumber }}} #RFC(02): not yet implemented cn:dev: #dn: cn=dev,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1001 description: Organization developers memberUid: {{{ search '(gidNumber=1001)' | attribute uidNumber }}} #RFC(02): not yet implemented cn:qa: #dn: cn=qa,ou=group,dc=example,dc=org - objectclass: posixGroup + objectClass: posixGroup gidNumber: 1002 memberUid: {{{ search '(gidNumber=1002)' | attribute uidNumber }}} #RFC(02): not yet implemented c:global: #dn: c=global,dc=example,dc=org ou:people: #dn: ou=people,c=global,dc=example,dc=org cn:alice: #dn: cn=alice,ou=people,c=global,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount .#allowDN: !!ldap/acl:allow-on [dc=org] # allow alice to request everything uid: alice @@ -30,7 +30,7 @@ dc:org: #dn: dc=org loginShell: /bin/bash userPassword: !!ldap/bind:password {{{ hscVaultSecret '/secret/ldap/users/global/alice#password' | passwordType 'base64' }}} #RFC(02): not yet implemented cn:bob: #dn: cn=bob,ou=people,c=global,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount .#allowDN: !!ldap/acl:allow-on [ou=group,dc=example,dc=org] # allow bob request only for user groups uid: bob @@ -42,7 +42,7 @@ dc:org: #dn: dc=org c:fr: #dn: c=fr,dc=example,dc=org ou:people: #dn: ou=people,c=fr,dc=example,dc=org cn:charlie: #dn: cn=charlie,ou=people,c=fr,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount .#allowDN: !!ldap/acl:allow-on [ou=group,dc=example,dc=org, c=fr,dc=example,dc=org] # allow charlie request only for user groups & fr peoples ... .#denyDN: !!ldap/acl:deny-on cn=admin,ou=group,dc=example,dc=org # ...but deny access to owner group @@ -53,7 +53,7 @@ dc:org: #dn: dc=org userPassword: !!ldap/bind:password {{{ hscVaultSecret '/secret/ldap/users/global/bob#password' | passwordType 'base64' }}} #RFC(02): not yet implemented {{- range $id, $name := (hscVaultList 'secret/ldap/users/fr' | eval | without 'charlie') }} #RFC(02): not yet implemented cn:{{ $name }}: #dn: cn={{ $name }},ou=people,c=fr,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount uid: charlie homeDirectory: /home/{{ $name }} @@ -65,7 +65,7 @@ dc:org: #dn: dc=org c:uk: #dn: c=uk,dc=example,dc=org ou:people: #dn: ou=people,c=fr,dc=example,dc=org cn:eve: #dn: cn=eve,ou=people,c=uk,dc=example,dc=org - objectclass: posixAccount + objectClass: posixAccount #NOTE: eve can't make any LDAP request (no !!ldap/bind:password field) uid: eve @@ -77,6 +77,6 @@ dc:org: #dn: dc=org --- !!ldap/schemas #RFC(03): not yet implemented -objectclasses: +objectClasses: - posixGroup:1.3.6.1.1.1.2.2:structural:Abstraction of a group of accounts - posixAccount:1.3.6.1.1.1.2.0:auxiliary:Abstraction of an account with POSIX attributes \ No newline at end of file diff --git a/pkg/ldap/directory/yaml/parse.go b/pkg/ldap/directory/yaml/parse.go index 673e4bbe..9ad2fdd5 100644 --- a/pkg/ldap/directory/yaml/parse.go +++ b/pkg/ldap/directory/yaml/parse.go @@ -11,19 +11,19 @@ import ( ) // parseLDAPObject parses a YAML mapping node into a LDAP object. -func parseLDAPObject(parent *common.Object, key string, value *yaml.Node) error { - if strings.Count(key, ":") != 1 { +func parseLDAPObject(parent *common.Object, key, value *yaml.Node) error { + if strings.Count(key.Value, ":") != 1 { return &ParseError{ err: fmt.Errorf( "invalid key: '%s' must be in the form ':' (e.g. 'ou:users')", - key, + key.Value, ), source: value, } } // extract the DN from the key - dn := strings.Replace(key, ":", "=", 1) + dn := strings.Replace(key.Value, ":", "=", 1) if parent := parent.DN(); parent != "" { dn = dn + "," + parent } @@ -33,7 +33,7 @@ func parseLDAPObject(parent *common.Object, key string, value *yaml.Node) error DN: dn, SubObjects: map[string]*common.Object{}, Attributes: ldap.Attributes{ - strings.SplitN(key, ":", 2)[0]: []string{strings.SplitN(key, ":", 2)[1]}, + strings.SplitN(key.Value, ":", 2)[0]: []string{strings.SplitN(key.Value, ":", 2)[1]}, }, }, } @@ -42,40 +42,40 @@ func parseLDAPObject(parent *common.Object, key string, value *yaml.Node) error seen := map[string]bool{} subnodes := slices.Clone(value.Content) for i := 0; i < len(subnodes); i += 2 { - key, value := subnodes[i], subnodes[i+1] + skey, svalue := subnodes[i], subnodes[i+1] // skip already seen keys (can happen with merge tags) - if seen[key.Value] { + if seen[skey.Value] { continue } // resolve aliases nodes - for value.Kind == yaml.AliasNode { - value = value.Alias + for svalue.Kind == yaml.AliasNode { + svalue = svalue.Alias } // if the sub-node is a 'merge' node, merge the content of the // referenced node into the current node (priority merge) - if key.Kind == yaml.ScalarNode && key.Value == "<<" && - (key.Tag == "" || key.Tag == "!" || key.Tag == "!!merge") { - if value.Kind != yaml.MappingNode { + if skey.Kind == yaml.ScalarNode && skey.Value == "<<" && + (skey.Tag == "" || skey.Tag == "!" || skey.Tag == "!!merge") { + if svalue.Kind != yaml.MappingNode { return &ParseError{ - err: fmt.Errorf("only mapping nodes can be merged, got a %s", YamlKindVerbose(value.Kind)), - source: key, + err: fmt.Errorf("only mapping nodes can be merged, got a %s", YamlKindVerbose(svalue.Kind)), + source: skey, } } - subnodes = append(subnodes, value.Content...) + subnodes = append(subnodes, svalue.Content...) continue } - switch value.Kind { + switch svalue.Kind { case yaml.MappingNode: - if err := parseLDAPObject(obj, key.Value, value); err != nil { + if err := parseLDAPObject(obj, skey, svalue); err != nil { return err } case yaml.SequenceNode, yaml.ScalarNode: - if err := parseLDAPAttribute(obj, key.Value, value); err != nil { + if err := parseLDAPAttribute(obj, skey, svalue); err != nil { return err } default: @@ -83,15 +83,15 @@ func parseLDAPObject(parent *common.Object, key string, value *yaml.Node) error // that can reach this point is a document node continue } - seen[key.Value] = true + seen[skey.Value] = true } - parent.SubObjects[key] = obj + parent.SubObjects[key.Value] = obj return nil } // parseLDAPAttribute parses a YAML sequence or scalar node into a LDAP attribute. -func parseLDAPAttribute(parent *common.Object, key string, value *yaml.Node) error { +func parseLDAPAttribute(parent *common.Object, key, value *yaml.Node) error { // ignore all null values if value.Tag == "!!null" { return nil @@ -114,9 +114,22 @@ func parseLDAPAttribute(parent *common.Object, key string, value *yaml.Node) err return nil } + for name := range parent.Attributes() { + if strings.EqualFold(name, key.Value) && name != key.Value { + return &ParseError{ + err: fmt.Errorf( + "invalid attribute: '%s' is already defined (case-insensitive match with '%s')", + key.Value, + name, + ), + source: key, + } + } + } + switch value.Kind { case yaml.ScalarNode: - parent.AddAttribute(key, value.Value) + parent.AddAttribute(key.Value, value.Value) case yaml.SequenceNode: for _, node := range value.Content { err := parseLDAPAttribute(parent, key, node) diff --git a/pkg/ldap/directory/yaml/parse_test.go b/pkg/ldap/directory/yaml/parse_test.go index ddc57594..00996576 100644 --- a/pkg/ldap/directory/yaml/parse_test.go +++ b/pkg/ldap/directory/yaml/parse_test.go @@ -27,7 +27,7 @@ func TestParseLDAPObject_Basic(t *testing.T) { require.NoError(t, err) obj := &common.Object{ImplObject: common.ImplObject{SubObjects: map[string]*common.Object{}}} - err = parseLDAPObject(obj, "go:test", node.Content[0]) + err = parseLDAPObject(obj, &yaml.Node{Value: "go:test"}, node.Content[0]) assert.NoError(t, err) assert.Equal(t, expect, obj.SubObjects["go:test"].SubObjects) } @@ -67,7 +67,7 @@ ou:people: require.NoError(t, err) obj := &common.Object{ImplObject: common.ImplObject{SubObjects: map[string]*common.Object{}}} - err = parseLDAPObject(obj, "go:test", node.Content[0]) + err = parseLDAPObject(obj, &yaml.Node{Value: "go:test"}, node.Content[0]) assert.NoError(t, err) assert.Equal(t, expect, obj.SubObjects["go:test"].SubObjects) } @@ -118,7 +118,7 @@ ou:people: require.NoError(t, err) obj := &common.Object{ImplObject: common.ImplObject{SubObjects: map[string]*common.Object{}}} - err = parseLDAPObject(obj, "go:test", node.Content[0]) + err = parseLDAPObject(obj, &yaml.Node{Value: "go:test"}, node.Content[0]) assert.NoError(t, err) assert.Equal(t, expect, obj.SubObjects["go:test"].SubObjects) } @@ -127,12 +127,12 @@ func TestParseLDAPObject_WithMergeField(t *testing.T) { raw := ` ou:people: uid:alice: &alice - objectclass: [posixAccount] + objectClass: [posixAccount] memberOf: [admin, user, h4ck3r] givenname: alice uid:bob: <<: *alice - objectclass: [UserMail] + objectClass: [UserMail] givenname: bob ` expect := map[string]*common.Object{ @@ -146,7 +146,7 @@ ou:people: DN: "uid=alice,ou=people,go=test", Attributes: ldap.Attributes{ "uid": []string{"alice"}, - "objectclass": []string{"posixAccount"}, + "objectClass": []string{"posixAccount"}, "memberOf": []string{"admin", "user", "h4ck3r"}, "givenname": []string{"alice"}, }, @@ -158,7 +158,7 @@ ou:people: DN: "uid=bob,ou=people,go=test", Attributes: ldap.Attributes{ "uid": []string{"bob"}, - "objectclass": []string{"UserMail"}, + "objectClass": []string{"UserMail"}, "memberOf": []string{"admin", "user", "h4ck3r"}, "givenname": []string{"bob"}, }, @@ -175,7 +175,7 @@ ou:people: require.NoError(t, err) obj := &common.Object{ImplObject: common.ImplObject{SubObjects: map[string]*common.Object{}}} - err = parseLDAPObject(obj, "go:test", node.Content[0]) + err = parseLDAPObject(obj, &yaml.Node{Value: "go:test"}, node.Content[0]) assert.NoError(t, err) assert.Equal(t, expect, obj.SubObjects["go:test"].SubObjects) } @@ -189,7 +189,7 @@ func TestParseLDAPObject_WithInvalidKey(t *testing.T) { require.NoError(t, err) obj := &common.Object{ImplObject: common.ImplObject{SubObjects: map[string]*common.Object{}}} - err = parseLDAPObject(obj, "go:test", node.Content[0]) + err = parseLDAPObject(obj, &yaml.Node{Value: "go:test"}, node.Content[0]) assert.EqualError(t, err, expectErr) } @@ -207,7 +207,7 @@ ou:people: require.NoError(t, err) obj := &common.Object{ImplObject: common.ImplObject{SubObjects: map[string]*common.Object{}}} - err = parseLDAPObject(obj, "go:test", node.Content[0]) + err = parseLDAPObject(obj, &yaml.Node{Value: "go:test"}, node.Content[0]) assert.EqualError(t, err, expectErr) } @@ -225,7 +225,7 @@ func TestParseLDAPAttribute_Basic(t *testing.T) { require.NoError(t, err) obj := &common.Object{} - err = parseLDAPAttribute(obj, node.Content[0].Content[0].Value, node.Content[0].Content[1]) + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) assert.NoError(t, err) assert.Equal(t, expect, obj) } @@ -243,7 +243,7 @@ func TestParseLDAPAttribute_WithNullValue(t *testing.T) { require.NoError(t, err) obj := &common.Object{} - err = parseLDAPAttribute(obj, node.Content[0].Content[0].Value, node.Content[0].Content[1]) + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) assert.NoError(t, err) assert.Equal(t, expect, obj) } @@ -265,7 +265,7 @@ func TestParseLDAPAttribute_WithBindPassword(t *testing.T) { require.NoError(t, err) obj := &common.Object{} - err = parseLDAPAttribute(obj, node.Content[0].Content[0].Value, node.Content[0].Content[1]) + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) assert.NoError(t, err) assert.Equal(t, expect, obj) } @@ -286,7 +286,7 @@ func TestParseLDAPAttribute_WithAllowedOn(t *testing.T) { require.NoError(t, err) obj := &common.Object{} - err = parseLDAPAttribute(obj, node.Content[0].Content[0].Value, node.Content[0].Content[1]) + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) assert.NoError(t, err) assert.Equal(t, expect, obj) } @@ -308,7 +308,7 @@ func TestParseLDAPAttribute_WithMultipleAllowedOn(t *testing.T) { require.NoError(t, err) obj := &common.Object{} - err = parseLDAPAttribute(obj, node.Content[0].Content[0].Value, node.Content[0].Content[1]) + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) assert.NoError(t, err) assert.Equal(t, expect, obj) } @@ -329,7 +329,7 @@ func TestParseLDAPAttribute_WithDeniedOn(t *testing.T) { require.NoError(t, err) obj := &common.Object{} - err = parseLDAPAttribute(obj, node.Content[0].Content[0].Value, node.Content[0].Content[1]) + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) assert.NoError(t, err) assert.Equal(t, expect, obj) } @@ -351,7 +351,7 @@ func TestParseLDAPAttribute_WithMultipleDeniedOn(t *testing.T) { require.NoError(t, err) obj := &common.Object{} - err = parseLDAPAttribute(obj, node.Content[0].Content[0].Value, node.Content[0].Content[1]) + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) assert.NoError(t, err) assert.Equal(t, expect, obj) } @@ -383,7 +383,7 @@ authz: require.NoError(t, err) obj := &common.Object{} - err = parseLDAPAttribute(obj, node.Content[0].Content[0].Value, node.Content[0].Content[1]) + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) assert.NoError(t, err) assert.Equal(t, expect, obj) } @@ -401,6 +401,27 @@ invalid: require.NoError(t, err) obj := &common.Object{} - err = parseLDAPAttribute(obj, node.Content[0].Content[0].Value, node.Content[0].Content[1]) + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) assert.EqualError(t, err, expectErr) } + +func TestParseLDAPAttribute_WithDupplicateAttribute(t *testing.T) { + rawAttr1 := `valid: true` + rawAttr2 := `Valid: number` + + expectErr := "invalid LDAP YAML document at line 1, column 1: invalid attribute: 'Valid' is already defined (case-insensitive match with 'valid')" + + var node yaml.Node + err := yaml.Unmarshal([]byte(rawAttr1), &node) + require.NoError(t, err) + + obj := &common.Object{} + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) + require.NoError(t, err) + + err = yaml.Unmarshal([]byte(rawAttr2), &node) + require.NoError(t, err) + + err = parseLDAPAttribute(obj, node.Content[0].Content[0], node.Content[0].Content[1]) + require.EqualError(t, err, expectErr) +} diff --git a/pkg/ldap/filters/comparator_resolvers.go b/pkg/ldap/filters/comparator_resolvers.go index 24e27705..b03266ce 100644 --- a/pkg/ldap/filters/comparator_resolvers.go +++ b/pkg/ldap/filters/comparator_resolvers.go @@ -3,6 +3,7 @@ package filters import ( "fmt" "strconv" + "strings" ber "github.com/go-asn1-ber/asn1-ber" goldap "github.com/go-ldap/ldap/v3" @@ -76,8 +77,11 @@ func compareResolver(fnc compareFnc, object ldap.Object, filter *ber.Packet) (bo return false, fmt.Errorf("invalid condition: must be a valid string") } - if attr, exists := object.Attributes()[attr]; exists { - return fnc(condition, attr), nil + for key, values := range object.Attributes() { + // NOTE: we need to compare the attribute name in a case-insensitive way. + if strings.EqualFold(key, attr) { + return fnc(condition, values), nil + } } return false, nil } diff --git a/pkg/ldap/filters/comparator_resolvers_test.go b/pkg/ldap/filters/comparator_resolvers_test.go index 695a2632..3d34d1bd 100644 --- a/pkg/ldap/filters/comparator_resolvers_test.go +++ b/pkg/ldap/filters/comparator_resolvers_test.go @@ -47,6 +47,11 @@ func TestGreaterOrEqualResolver(t *testing.T) { filter: "(memberOf>=z)", expected: assert.False, }, + + { // NOTE: case-insensitive attribute + filter: "(uidnumber>=-1)", + expected: assert.True, + }, } for _, tt := range tests { @@ -123,6 +128,11 @@ func TestLessOrEqualResolver(t *testing.T) { filter: "(memberOf<= )", expected: assert.False, }, + + { // NOTE: case-insensitive attribute + filter: "(uidnumber<=1000)", + expected: assert.True, + }, } for _, tt := range tests { diff --git a/pkg/ldap/filters/match_resolver_test.go b/pkg/ldap/filters/match_resolver_test.go index c976f567..fb97b634 100644 --- a/pkg/ldap/filters/match_resolver_test.go +++ b/pkg/ldap/filters/match_resolver_test.go @@ -43,6 +43,11 @@ func TestApproxResolver(t *testing.T) { filter: "(memberOf~=398)", expected: assert.True, }, + + { // NOTE: case-insensitive attribute + filter: "(uidnumber~=1000)", + expected: assert.True, + }, } for _, tt := range tests { @@ -115,6 +120,11 @@ func TestEqualResolver(t *testing.T) { filter: "(memberOf=398)", expected: assert.False, }, + + { // NOTE: case-insensitive attribute + filter: "(uidnumber=1000)", + expected: assert.True, + }, } for _, tt := range tests { diff --git a/pkg/ldap/filters/present_resolver.go b/pkg/ldap/filters/present_resolver.go index 93651722..9871a051 100644 --- a/pkg/ldap/filters/present_resolver.go +++ b/pkg/ldap/filters/present_resolver.go @@ -2,6 +2,7 @@ package filters import ( "fmt" + "strings" ber "github.com/go-asn1-ber/asn1-ber" goldap "github.com/go-ldap/ldap/v3" @@ -20,6 +21,11 @@ func PresentResolver(object ldap.Object, filter *ber.Packet) (bool, error) { return false, &Error{goldap.FilterPresent, fmt.Errorf("invalid attribute: must be a valid non-empty string")} } - _, exists := object.Attributes()[attr] - return exists, nil + for key, values := range object.Attributes() { + // NOTE: case-insensitive attribute + if strings.EqualFold(key, attr) && len(values) > 0 { + return true, nil + } + } + return false, nil } diff --git a/pkg/ldap/filters/present_resolver_test.go b/pkg/ldap/filters/present_resolver_test.go index 99fd11a5..7fbf2cc7 100644 --- a/pkg/ldap/filters/present_resolver_test.go +++ b/pkg/ldap/filters/present_resolver_test.go @@ -23,6 +23,11 @@ func TestPresentResolver(t *testing.T) { filter: "(password=*)", expected: assert.False, }, + + { // NOTE: case-insensitive attribute + filter: "(memberof=*)", + expected: assert.True, + }, } for _, tt := range tests { diff --git a/pkg/ldap/filters/substring_resolver.go b/pkg/ldap/filters/substring_resolver.go index 253e04c2..2e9c89c3 100644 --- a/pkg/ldap/filters/substring_resolver.go +++ b/pkg/ldap/filters/substring_resolver.go @@ -52,8 +52,11 @@ func SubstringResolver(object ldap.Object, filter *ber.Packet) (bool, error) { return false, &Error{goldap.FilterSubstrings, fmt.Errorf("internal error: %w", err)} } - if attr, exists := object.Attributes()[attr]; exists { - return slices.IndexFunc(attr, func(s string) bool { return rx.MatchString(s) }) > -1, nil + for key, values := range object.Attributes() { + // NOTE: case-insensitive attribute + if strings.EqualFold(key, attr) && len(values) > 0 { + return slices.IndexFunc(values, func(s string) bool { return rx.MatchString(s) }) > -1, nil + } } return false, nil } diff --git a/pkg/ldap/filters/substring_resolver_test.go b/pkg/ldap/filters/substring_resolver_test.go index 0ce606f1..c2edc75a 100644 --- a/pkg/ldap/filters/substring_resolver_test.go +++ b/pkg/ldap/filters/substring_resolver_test.go @@ -47,6 +47,11 @@ func TestSubstringResolver(t *testing.T) { filter: "(memberOf=u*k*n)", expected: assert.False, }, + + { // NOTE: case-insensitive attribute + filter: "(memberof=ad*)", + expected: assert.True, + }, } for _, tt := range tests { diff --git a/pkg/ldap/mux.go b/pkg/ldap/mux.go index 83a60e0e..96252209 100644 --- a/pkg/ldap/mux.go +++ b/pkg/ldap/mux.go @@ -1,11 +1,15 @@ package ldap import ( + "fmt" "log/slog" + "strings" "github.com/jimlambrt/gldap" "github.com/xunleii/yaldap/internal/ldap/auth" "github.com/xunleii/yaldap/pkg/ldap/directory" + "github.com/xunleii/yaldap/pkg/utils" + "golang.org/x/exp/slices" ) // server is a ldap server that uses a Directory to accept and perform search. @@ -17,15 +21,16 @@ type server struct { } // NewMux creates a new LDAP server. -func NewMux(logger *slog.Logger, directory directory.Directory) *gldap.Mux { +func NewMux(logger *slog.Logger, directory directory.Directory, sessions *auth.Sessions) *gldap.Mux { server := &server{ logger: logger, - sessions: auth.NewSessions(), + sessions: sessions, directory: directory, } mux, _ := gldap.NewMux() _ = mux.Bind(server.bind) + _ = mux.Unbind(server.unbind) _ = mux.Search(server.search) _ = mux.Add(server.add) _ = mux.Modify(server.modify) @@ -71,14 +76,7 @@ func (s *server) bind(w *gldap.ResponseWriter, req *gldap.Request) { // in order to avoid any bruteforce attack. return } - log = s.logger.With( - slog.String("method", "bind"), - slog.Group("session", - slog.Int("id", req.ConnectionID()), - slog.Int("request_id", req.ID), - slog.String("bind_dn", msg.UserName), - ), - ) + log = log.With(slog.String("bind_dn", obj.DN())) err = s.sessions.NewSession(req.ConnectionID(), obj) if err != nil { @@ -92,6 +90,25 @@ func (s *server) bind(w *gldap.ResponseWriter, req *gldap.Request) { resp.SetResultCode(gldap.ResultSuccess) } +func (s *server) unbind(_ *gldap.ResponseWriter, req *gldap.Request) { + log := s.logger.With( + slog.String("method", "unbind"), + slog.Group("session", + slog.Int("id", req.ConnectionID()), + slog.Int("request_id", req.ID), + ), + ) + + session := s.sessions.Session(req.ConnectionID()) + if session == nil { + return + } + log = log.With(slog.String("bind_dn", session.Object().DN())) + + s.sessions.Delete(req.ConnectionID()) + log.Info("unbind successful") +} + // Search implements the LDAP search mechanism. func (s *server) search(w *gldap.ResponseWriter, req *gldap.Request) { log := s.logger.With( @@ -112,23 +129,21 @@ func (s *server) search(w *gldap.ResponseWriter, req *gldap.Request) { resp.SetDiagnosticMessage(err.Error()) return } + log = log.With(slog.Group("request", + slog.String("base_dn", msg.BaseDN), + slog.String("filter", msg.Filter), + slog.String("scope", utils.LDAPScopes[msg.Scope]), + slog.Any("attributes", msg.Attributes), + )) session := s.sessions.Session(req.ConnectionID()) if session == nil { - log.Error("no session found") + log.Error("session not found or expired") resp.SetResultCode(gldap.ResultAuthorizationDenied) return } obj := session.Object() - log = s.logger.With( - slog.String("method", "bind"), - slog.Group("session", - slog.Int("id", req.ConnectionID()), - slog.Int("request_id", req.ID), - slog.String("bind_dn", obj.DN()), - ), - slog.String("base_dn", msg.BaseDN), - ) + log = log.With(slog.String("bind_dn", obj.DN())) baseDn := s.directory.BaseDN(msg.BaseDN) if baseDn == nil { @@ -137,12 +152,6 @@ func (s *server) search(w *gldap.ResponseWriter, req *gldap.Request) { return } - log.Debug( - "searching", - slog.Int64("scope", int64(msg.Scope)), - slog.String("filter", msg.Filter), - slog.Any("attributes", msg.Attributes), - ) entries, err := baseDn.Search(msg.Scope, msg.Filter) if err != nil { log.Error("unable to search", slog.String("error", err.Error())) @@ -151,16 +160,27 @@ func (s *server) search(w *gldap.ResponseWriter, req *gldap.Request) { return } + var count int for _, entry := range entries { if obj.CanSearchOn(entry.DN()) { - entry := req.NewSearchResponseEntry( - entry.DN(), - gldap.WithAttributes(entry.Attributes()), - ) - _ = w.Write(entry) + resp := req.NewSearchResponseEntry(entry.DN()) + attrs := entry.Attributes() + + // Filter attributes if needed + for attr, values := range attrs { + if len(msg.Attributes) == 0 || slices.ContainsFunc( + msg.Attributes, + func(s string) bool { return strings.EqualFold(s, attr) }, + ) { + resp.AddAttribute(attr, values) + } + } + + _ = w.Write(resp) + count++ } } - log.Info("search successful") + log.Info(fmt.Sprintf("found %d entries", count)) resp.SetResultCode(gldap.ResultSuccess) } diff --git a/pkg/ldap/mux_test.go b/pkg/ldap/mux_test.go index 3e6f8318..4cc4e7eb 100644 --- a/pkg/ldap/mux_test.go +++ b/pkg/ldap/mux_test.go @@ -1,6 +1,7 @@ package ldap_test import ( + "context" "io" "log/slog" "testing" @@ -11,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/xunleii/yaldap/internal/ldap/auth" "github.com/xunleii/yaldap/pkg/ldap" yamldir "github.com/xunleii/yaldap/pkg/ldap/directory/yaml" ) @@ -32,13 +34,14 @@ type ( func (suite *LDAPTestSuite) SetupSuite() { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + sessions := auth.NewSessions(context.Background(), time.Minute) directory, err := yamldir.NewDirectoryFromYAML([]byte(` dc:org: - objectclass: organization + objectClass: organization dc:example: - objectclass: organization + objectClass: organization ou:people: cn:alice: @@ -46,25 +49,25 @@ dc:org: - !!ldap/acl:allow-on dc=example,dc=org - !!ldap/acl:deny-on cn=bob,ou=people,dc=example,dc=org - objectclass: person + objectClass: person userpassword: !!ldap/bind:password alice cn:bob: - objectclass: person + objectClass: person userpassword: !!ldap/bind:password "" cn:charlie: - objectclass: person + objectClass: person dc:example2: - objectclass: organization + objectClass: organization `)) suite.Require().NoError(err) suite.Server, err = gldap.NewServer() suite.Require().NoError(err) - err = suite.Server.Router(ldap.NewMux(logger, directory)) + err = suite.Server.Router(ldap.NewMux(logger, directory, sessions)) suite.Require().NoError(err) go func() { @@ -134,6 +137,27 @@ func (suite *LDAPTestSuite) TestMux_Bind() { }) } +func (suite *LDAPTestSuite) TestMux_Unbind() { + conn, err := suite.DialLDAP() + suite.Require().NoError(err) + defer conn.Close() + + suite.T().Run("SuccessfulUnbind", func(t *testing.T) { + err = conn.Bind("cn=alice,ou=people,dc=example,dc=org", "alice") + suite.Require().NoError(err) + + err = conn.Unbind() // gldap.Server automatically closes the connection after unbind + assert.NoError(t, err) + + _, err = conn.Search(&goldap.SearchRequest{ + BaseDN: "dc=org", + Scope: goldap.ScopeWholeSubtree, + Filter: "(cn=alice)", + }) + assert.EqualError(t, err, "LDAP Result Code 200 \"Network Error\": ldap: connection closed") + }) +} + func (suite *LDAPTestSuite) TestMux_Search() { conn, err := suite.DialLDAP() suite.Require().NoError(err) @@ -166,7 +190,7 @@ func (suite *LDAPTestSuite) TestMux_Search() { DN: "cn=alice,ou=people,dc=example,dc=org", Attributes: map[string][]string{ "cn": {"alice"}, - "objectclass": {"person"}, + "objectClass": {"person"}, "userpassword": {"alice"}, }, }, @@ -176,7 +200,7 @@ func (suite *LDAPTestSuite) TestMux_Search() { }) suite.T().Run("FindAllObjectclass", func(t *testing.T) { - req := goldap.NewSearchRequest("dc=org", goldap.ScopeWholeSubtree, 0, 0, 0, false, "(objectclass=*)", nil, nil) + req := goldap.NewSearchRequest("dc=org", goldap.ScopeWholeSubtree, 0, 0, 0, false, "(objectClass=*)", nil, nil) res, err := conn.Search(req) require.NoError(t, err) @@ -186,14 +210,14 @@ func (suite *LDAPTestSuite) TestMux_Search() { DN: "dc=example,dc=org", Attributes: map[string][]string{ "dc": {"example"}, - "objectclass": {"organization"}, + "objectClass": {"organization"}, }, }, { DN: "cn=alice,ou=people,dc=example,dc=org", Attributes: map[string][]string{ "cn": {"alice"}, - "objectclass": {"person"}, + "objectClass": {"person"}, "userpassword": {"alice"}, }, }, @@ -201,7 +225,7 @@ func (suite *LDAPTestSuite) TestMux_Search() { DN: "cn=charlie,ou=people,dc=example,dc=org", Attributes: map[string][]string{ "cn": {"charlie"}, - "objectclass": {"person"}, + "objectClass": {"person"}, }, }, }, @@ -226,7 +250,7 @@ func (suite *LDAPTestSuite) TestMux_Search() { DN: "cn=alice,ou=people,dc=example,dc=org", Attributes: map[string][]string{ "cn": {"alice"}, - "objectclass": {"person"}, + "objectClass": {"person"}, "userpassword": {"alice"}, }, }, @@ -234,7 +258,7 @@ func (suite *LDAPTestSuite) TestMux_Search() { DN: "cn=charlie,ou=people,dc=example,dc=org", Attributes: map[string][]string{ "cn": {"charlie"}, - "objectclass": {"person"}, + "objectClass": {"person"}, }, }, }, diff --git a/pkg/utils/gldap.go b/pkg/utils/gldap.go new file mode 100644 index 00000000..8a5bc4f5 --- /dev/null +++ b/pkg/utils/gldap.go @@ -0,0 +1,11 @@ +package utils + +import "github.com/jimlambrt/gldap" + +// LDAPScopes maps gldap.Scope to string for logging +// purposes. +var LDAPScopes = map[gldap.Scope]string{ + gldap.BaseObject: "base", + gldap.SingleLevel: "one", + gldap.WholeSubtree: "sub", +} diff --git a/pkg/utils/hclog_wrapper.go b/pkg/utils/hclog_wrapper.go index f5607349..59048122 100644 --- a/pkg/utils/hclog_wrapper.go +++ b/pkg/utils/hclog_wrapper.go @@ -1,11 +1,14 @@ package utils import ( + "bytes" "context" "io" "log" "log/slog" "runtime" + "strings" + "sync" "time" "github.com/hashicorp/go-hclog" @@ -13,15 +16,28 @@ import ( type ( HashicorpLoggerWrapper struct { - *slog.Logger + Logger *slog.Logger args []interface{} } + HashicorpLoggerWriter struct { + logger *slog.Logger + level slog.Level + buffer *bytes.Buffer + last time.Time + + rwm sync.RWMutex + } +) + +const ( + // TraceLevel designates finer-grained informational events than the Debug. + LevelTrace = slog.LevelDebug * 2 ) var ( hclogLevels = map[hclog.Level]slog.Level{ - hclog.Trace: slog.LevelDebug, + hclog.Trace: LevelTrace, hclog.Debug: slog.LevelDebug, hclog.Info: slog.LevelInfo, hclog.NoLevel: slog.LevelInfo, @@ -29,6 +45,7 @@ var ( hclog.Error: slog.LevelError, } slogLevels = map[slog.Level]hclog.Level{ + LevelTrace: hclog.Trace, slog.LevelDebug: hclog.Debug, slog.LevelInfo: hclog.Info, slog.LevelWarn: hclog.Warn, @@ -49,7 +66,7 @@ func (logger HashicorpLoggerWrapper) Log(level hclog.Level, msg string, args ... // Emit a message and key/value pairs at the TRACE level. func (logger HashicorpLoggerWrapper) Trace(msg string, args ...interface{}) { - logger.log(slog.LevelDebug, msg, args...) + logger.log(LevelTrace, msg, args...) } // Emit a message and key/value pairs at the DEBUG level. @@ -75,7 +92,7 @@ func (logger HashicorpLoggerWrapper) Error(msg string, args ...interface{}) { // Indicate if TRACE logs would be emitted. This and the other Is* guards // are used to elide expensive logging code based on the current level. func (logger HashicorpLoggerWrapper) IsTrace() bool { - return logger.Logger.Enabled(context.TODO(), slog.LevelDebug) + return logger.Logger.Enabled(context.TODO(), LevelTrace) } // Indicate if DEBUG logs would be emitted. This and the other Is* guards. @@ -85,7 +102,7 @@ func (logger HashicorpLoggerWrapper) IsDebug() bool { // Indicate if INFO logs would be emitted. This and the other Is* guards. func (logger HashicorpLoggerWrapper) IsInfo() bool { - return logger.Enabled(context.Background(), slog.LevelInfo) + return logger.Logger.Enabled(context.Background(), slog.LevelInfo) } // Indicate if WARN logs would be emitted. This and the other Is* guards. @@ -140,7 +157,7 @@ func (logger HashicorpLoggerWrapper) GetLevel() hclog.Level { } func (logger HashicorpLoggerWrapper) getLevel() slog.Level { - for _, level := range []slog.Level{slog.LevelDebug, slog.LevelInfo, slog.LevelWarn, slog.LevelError} { + for _, level := range []slog.Level{LevelTrace, slog.LevelDebug, slog.LevelInfo, slog.LevelWarn, slog.LevelError} { if logger.Logger.Enabled(context.Background(), level) { return level } @@ -149,13 +166,27 @@ func (logger HashicorpLoggerWrapper) getLevel() slog.Level { } // Return a value that conforms to the stdlib log.Logger interface. -func (logger HashicorpLoggerWrapper) StandardLogger(*hclog.StandardLoggerOptions) *log.Logger { - panic("not implemented") +func (logger HashicorpLoggerWrapper) StandardLogger(opts *hclog.StandardLoggerOptions) *log.Logger { + level := logger.getLevel() + if opts != nil { + level = hclogLevels[opts.ForceLevel] + } + + return slog.NewLogLogger(logger.Logger.Handler(), level) } // Return a value that conforms to io.Writer, which can be passed into log.SetOutput(). +// NOTE: this is only used by gldap package to pretty print LDAP packets. For this +// purpose, we will log messages only if the current level is lower than Trace. func (logger HashicorpLoggerWrapper) StandardWriter(*hclog.StandardLoggerOptions) io.Writer { - panic("not implemented") + writer := &HashicorpLoggerWriter{ + logger: logger.Logger, + level: LevelTrace, + buffer: bytes.NewBuffer(nil), + last: time.Now(), + } + go writer.flushPeriodically() + return writer } func (logger HashicorpLoggerWrapper) log(level slog.Level, msg string, args ...interface{}) { @@ -169,5 +200,42 @@ func (logger HashicorpLoggerWrapper) log(level slog.Level, msg string, args ...i record := slog.NewRecord(time.Now(), level, msg, pcs[0]) record.Add(args...) - _ = logger.Handler().Handle(context.Background(), record) + _ = logger.Logger.Handler().Handle(context.Background(), record) +} + +// Write implements io.Writer. +func (writer *HashicorpLoggerWriter) Write(p []byte) (n int, err error) { + writer.rwm.Lock() + defer writer.rwm.Unlock() + + if !writer.logger.Enabled(context.Background(), writer.level) { + return 0, nil + } + + writer.last = time.Now() + return writer.buffer.Write(p) +} + +func (writer *HashicorpLoggerWriter) flushPeriodically() { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for range ticker.C { + writer.rwm.RLock() + + if time.Since(writer.last) > 10*time.Millisecond && writer.buffer.Len() > 0 { + var pcs [1]uintptr + runtime.Callers(0, pcs[:]) + + record := slog.NewRecord( + time.Now(), + writer.level, + strings.TrimSuffix(writer.buffer.String(), "\n"), + pcs[0], + ) + _ = writer.logger.Handler().Handle(context.Background(), record) + writer.buffer.Reset() + } + writer.rwm.RUnlock() + } } diff --git a/pkg/utils/hclog_wrapper_test.go b/pkg/utils/hclog_wrapper_test.go index 87566dc6..1bc56b90 100644 --- a/pkg/utils/hclog_wrapper_test.go +++ b/pkg/utils/hclog_wrapper_test.go @@ -21,7 +21,7 @@ func (suite *HashicorpLoggerWrapperTestSuite) SetupTest() { suite.buffer = bytes.NewBuffer(nil) suite.logger = slog.New( slog.NewTextHandler(suite.buffer, &slog.HandlerOptions{ - Level: slog.LevelDebug, + Level: LevelTrace, ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { switch attr.Key { case "time": @@ -47,7 +47,7 @@ func (suite *HashicorpLoggerWrapperTestSuite) TestLog() { level: hclog.Trace, msg: "test", args: []interface{}{"key", "value"}, - expected: "time=0001-01-01T00:00:00.000Z level=DEBUG msg=test key=value\n", + expected: "time=0001-01-01T00:00:00.000Z level=DEBUG-4 msg=test key=value\n", }, { level: hclog.Debug, @@ -124,7 +124,7 @@ func (suite *HashicorpLoggerWrapperTestSuite) TestX() { { name: "(*HashicorpLoggerWrapper).Trace", logFn: logger.Trace, - expected: "time=0001-01-01T00:00:00.000Z level=DEBUG msg=test key=value\n", + expected: "time=0001-01-01T00:00:00.000Z level=DEBUG-4 msg=test key=value\n", }, { name: "(*HashicorpLoggerWrapper).Debug", @@ -160,7 +160,7 @@ func (suite *HashicorpLoggerWrapperTestSuite) TestX() { func (suite *HashicorpLoggerWrapperTestSuite) TestIsX() { var logger hclog.Logger = &HashicorpLoggerWrapper{ Logger: slog.New(slog.NewTextHandler(nil, &slog.HandlerOptions{ - Level: slog.LevelDebug, + Level: LevelTrace, })), } @@ -284,19 +284,35 @@ func (suite *HashicorpLoggerWrapperTestSuite) TestGetLevel() { func (suite *HashicorpLoggerWrapperTestSuite) TestStandardLogger() { var logger hclog.Logger = &HashicorpLoggerWrapper{Logger: suite.logger} - suite.PanicsWithValue( - "not implemented", - func() { logger.StandardLogger(nil) }, - ) + logger.StandardLogger(nil).Println("test") + suite.Equal("time=0001-01-01T00:00:00.000Z level=DEBUG-4 msg=test\n", suite.buffer.String()) + suite.buffer.Reset() + + logger.StandardLogger(&hclog.StandardLoggerOptions{ForceLevel: hclog.Warn}).Println("test") + suite.Equal("time=0001-01-01T00:00:00.000Z level=WARN msg=test\n", suite.buffer.String()) } func (suite *HashicorpLoggerWrapperTestSuite) TestStandardWriter() { - var logger hclog.Logger = &HashicorpLoggerWrapper{Logger: suite.logger} + var logger hclog.Logger = &HashicorpLoggerWrapper{ + Logger: slog.New(slog.NewTextHandler(suite.buffer, &slog.HandlerOptions{Level: LevelTrace})), + } - suite.PanicsWithValue( - "not implemented", - func() { logger.StandardWriter(nil) }, - ) + _, err := logger.StandardWriter(nil).Write([]byte("test")) + suite.NoError(err) + + suite.Eventually(func() bool { + return suite.NotEmpty(suite.buffer.String()) + }, time.Second, 100*time.Millisecond) + suite.buffer.Reset() + + logger = &HashicorpLoggerWrapper{ + Logger: slog.New(slog.NewTextHandler(suite.buffer, &slog.HandlerOptions{Level: slog.LevelWarn})), + } + _, err = logger.StandardWriter(&hclog.StandardLoggerOptions{ForceLevel: hclog.Trace}).Write([]byte("test")) + suite.NoError(err) + suite.Eventually(func() bool { + return suite.Empty(suite.buffer.String()) + }, time.Second, 100*time.Millisecond) } func TestHashicorpLoggerWrapper(t *testing.T) {