From 2788be480009e4592add6cab3fa9a927541d2eb3 Mon Sep 17 00:00:00 2001 From: Sejong Kim Date: Mon, 22 Jan 2024 14:01:36 +0900 Subject: [PATCH] Roll back Users collection sharding --- api/types/project.go | 6 +- server/backend/database/database.go | 15 ++-- server/backend/database/memory/database.go | 42 ++++++---- .../backend/database/memory/database_test.go | 8 +- server/backend/database/mongo/client.go | 35 ++++++--- server/backend/database/mongo/client_test.go | 8 +- server/backend/database/project_info.go | 4 +- server/backend/database/project_info_test.go | 6 +- .../backend/database/testcases/testcases.go | 76 ++++++++++++------- server/projects/projects.go | 8 +- server/rpc/admin_server.go | 20 ++--- server/rpc/interceptors/admin_auth.go | 4 +- server/users/users.go | 21 ++++- test/integration/housekeeping_test.go | 9 ++- 14 files changed, 170 insertions(+), 92 deletions(-) diff --git a/api/types/project.go b/api/types/project.go index f290731ad..3ed8dd909 100644 --- a/api/types/project.go +++ b/api/types/project.go @@ -17,7 +17,9 @@ package types -import "time" +import ( + "time" +) // Project is a project that consists of multiple documents and clients. type Project struct { @@ -28,7 +30,7 @@ type Project struct { Name string `json:"name"` // Owner is the owner of this project. - Owner string `json:"owner"` + Owner ID `json:"owner"` // AuthWebhookURL is the url of the authorization webhook. AuthWebhookURL string `json:"auth_webhook_url"` diff --git a/server/backend/database/database.go b/server/backend/database/database.go index dc5bd9633..e6e8be5db 100644 --- a/server/backend/database/database.go +++ b/server/backend/database/database.go @@ -77,7 +77,7 @@ type Database interface { // FindProjectInfoByName returns a project by the given name. FindProjectInfoByName( ctx context.Context, - owner string, + owner types.ID, name string, ) (*ProjectInfo, error) @@ -99,17 +99,17 @@ type Database interface { CreateProjectInfo( ctx context.Context, name string, - owner string, + owner types.ID, clientDeactivateThreshold string, ) (*ProjectInfo, error) // ListProjectInfos returns all project infos owned by owner. - ListProjectInfos(ctx context.Context, owner string) ([]*ProjectInfo, error) + ListProjectInfos(ctx context.Context, owner types.ID) ([]*ProjectInfo, error) // UpdateProjectInfo updates the project. UpdateProjectInfo( ctx context.Context, - owner string, + owner types.ID, id types.ID, fields *types.UpdatableProjectFields, ) (*ProjectInfo, error) @@ -121,8 +121,11 @@ type Database interface { hashedPassword string, ) (*UserInfo, error) - // FindUserInfo returns a user by the given username. - FindUserInfo(ctx context.Context, username string) (*UserInfo, error) + // FindUserInfoByName returns a user by the given ID. + FindUserInfoByID(ctx context.Context, id types.ID) (*UserInfo, error) + + // FindUserInfoByName returns a user by the given username. + FindUserInfoByName(ctx context.Context, username string) (*UserInfo, error) // ListUserInfos returns all users. ListUserInfos(ctx context.Context) ([]*UserInfo, error) diff --git a/server/backend/database/memory/database.go b/server/backend/database/memory/database.go index 50c0f98bd..3a57b3b1d 100644 --- a/server/backend/database/memory/database.go +++ b/server/backend/database/memory/database.go @@ -97,13 +97,13 @@ func (d *DB) FindProjectInfoBySecretKey( // FindProjectInfoByName returns a project by the given name. func (d *DB) FindProjectInfoByName( _ context.Context, - owner string, + owner types.ID, name string, ) (*database.ProjectInfo, error) { txn := d.db.Txn(false) defer txn.Abort() - raw, err := txn.First(tblProjects, "owner_name", owner, name) + raw, err := txn.First(tblProjects, "owner_name", owner.String(), name) if err != nil { return nil, fmt.Errorf("find project by owner and name: %w", err) } @@ -143,7 +143,7 @@ func (d *DB) EnsureDefaultUserAndProject( return nil, nil, err } - project, err := d.ensureDefaultProjectInfo(ctx, username, clientDeactivateThreshold) + project, err := d.ensureDefaultProjectInfo(ctx, user.ID, clientDeactivateThreshold) if err != nil { return nil, nil, err } @@ -187,7 +187,7 @@ func (d *DB) ensureDefaultUserInfo( // ensureDefaultProjectInfo creates the default project if it does not exist. func (d *DB) ensureDefaultProjectInfo( _ context.Context, - defaultUsername string, + defaultUserID types.ID, defaultClientDeactivateThreshold string, ) (*database.ProjectInfo, error) { txn := d.db.Txn(true) @@ -200,7 +200,7 @@ func (d *DB) ensureDefaultProjectInfo( var info *database.ProjectInfo if raw == nil { - info = database.NewProjectInfo(database.DefaultProjectName, defaultUsername, defaultClientDeactivateThreshold) + info = database.NewProjectInfo(database.DefaultProjectName, defaultUserID, defaultClientDeactivateThreshold) info.ID = database.DefaultProjectID if err := txn.Insert(tblProjects, info); err != nil { return nil, fmt.Errorf("insert project: %w", err) @@ -217,7 +217,7 @@ func (d *DB) ensureDefaultProjectInfo( func (d *DB) CreateProjectInfo( _ context.Context, name string, - owner string, + owner types.ID, clientDeactivateThreshold string, ) (*database.ProjectInfo, error) { txn := d.db.Txn(true) @@ -225,7 +225,7 @@ func (d *DB) CreateProjectInfo( // NOTE(hackerwins): Check if the project already exists. // https://github.com/hashicorp/go-memdb/issues/7#issuecomment-270427642 - existing, err := txn.First(tblProjects, "owner_name", owner, name) + existing, err := txn.First(tblProjects, "owner_name", owner.String(), name) if err != nil { return nil, fmt.Errorf("find project by owner and name: %w", err) } @@ -304,7 +304,7 @@ func (d *DB) FindNextNCyclingProjectInfos( // ListProjectInfos returns all project infos owned by owner. func (d *DB) ListProjectInfos( _ context.Context, - owner string, + owner types.ID, ) ([]*database.ProjectInfo, error) { txn := d.db.Txn(false) defer txn.Abort() @@ -312,7 +312,7 @@ func (d *DB) ListProjectInfos( iter, err := txn.LowerBound( tblProjects, "owner_name", - owner, + owner.String(), "", ) if err != nil { @@ -335,7 +335,7 @@ func (d *DB) ListProjectInfos( // UpdateProjectInfo updates the given project. func (d *DB) UpdateProjectInfo( _ context.Context, - owner string, + owner types.ID, id types.ID, fields *types.UpdatableProjectFields, ) (*database.ProjectInfo, error) { @@ -356,7 +356,7 @@ func (d *DB) UpdateProjectInfo( } if fields.Name != nil { - existing, err := txn.First(tblProjects, "owner_name", owner, *fields.Name) + existing, err := txn.First(tblProjects, "owner_name", owner.String(), *fields.Name) if err != nil { return nil, fmt.Errorf("find project by owner and name: %w", err) } @@ -402,8 +402,24 @@ func (d *DB) CreateUserInfo( return info, nil } -// FindUserInfo finds a user by the given username. -func (d *DB) FindUserInfo(_ context.Context, username string) (*database.UserInfo, error) { +// FindUserInfoByID finds a user by the given ID. +func (d *DB) FindUserInfoByID(_ context.Context, clientID types.ID) (*database.UserInfo, error) { + txn := d.db.Txn(false) + defer txn.Abort() + + raw, err := txn.First(tblUsers, "id", clientID.String()) + if err != nil { + return nil, fmt.Errorf("find user by id: %w", err) + } + if raw == nil { + return nil, fmt.Errorf("%s: %w", clientID, database.ErrUserNotFound) + } + + return raw.(*database.UserInfo).DeepCopy(), nil +} + +// FindUserInfoByName finds a user by the given username. +func (d *DB) FindUserInfoByName(_ context.Context, username string) (*database.UserInfo, error) { txn := d.db.Txn(false) defer txn.Abort() diff --git a/server/backend/database/memory/database_test.go b/server/backend/database/memory/database_test.go index dcd6f8b84..cc36ce8f7 100644 --- a/server/backend/database/memory/database_test.go +++ b/server/backend/database/memory/database_test.go @@ -64,8 +64,12 @@ func TestDB(t *testing.T) { testcases.RunListUserInfosTest(t, db) }) - t.Run("FindUserInfo test", func(t *testing.T) { - testcases.RunFindUserInfoTest(t, db) + t.Run("FindUserInfoByID test", func(t *testing.T) { + testcases.RunFindUserInfoByIDTest(t, db) + }) + + t.Run("FindUserInfoByName test", func(t *testing.T) { + testcases.RunFindUserInfoByNameTest(t, db) }) t.Run("FindProjectInfoBySecretKey test", func(t *testing.T) { diff --git a/server/backend/database/mongo/client.go b/server/backend/database/mongo/client.go index 0bfb7fa7a..702f2f348 100644 --- a/server/backend/database/mongo/client.go +++ b/server/backend/database/mongo/client.go @@ -107,7 +107,7 @@ func (c *Client) EnsureDefaultUserAndProject( return nil, nil, err } - projectInfo, err := c.ensureDefaultProjectInfo(ctx, userInfo.Username, clientDeactivateThreshold) + projectInfo, err := c.ensureDefaultProjectInfo(ctx, userInfo.ID, clientDeactivateThreshold) if err != nil { return nil, nil, err } @@ -162,10 +162,10 @@ func (c *Client) ensureDefaultUserInfo( // ensureDefaultProjectInfo creates the default project info if it does not exist. func (c *Client) ensureDefaultProjectInfo( ctx context.Context, - defaultUsername string, + defaultUserID types.ID, defaultClientDeactivateThreshold string, ) (*database.ProjectInfo, error) { - candidate := database.NewProjectInfo(database.DefaultProjectName, defaultUsername, defaultClientDeactivateThreshold) + candidate := database.NewProjectInfo(database.DefaultProjectName, defaultUserID, defaultClientDeactivateThreshold) candidate.ID = database.DefaultProjectID _, err := c.collection(ColProjects).UpdateOne(ctx, bson.M{ @@ -203,7 +203,7 @@ func (c *Client) ensureDefaultProjectInfo( func (c *Client) CreateProjectInfo( ctx context.Context, name string, - owner string, + owner types.ID, clientDeactivateThreshold string, ) (*database.ProjectInfo, error) { info := database.NewProjectInfo(name, owner, clientDeactivateThreshold) @@ -275,7 +275,7 @@ func (c *Client) FindNextNCyclingProjectInfos( // ListProjectInfos returns all project infos owned by owner. func (c *Client) ListProjectInfos( ctx context.Context, - owner string, + owner types.ID, ) ([]*database.ProjectInfo, error) { cursor, err := c.collection(ColProjects).Find(ctx, bson.M{ "owner": owner, @@ -329,7 +329,7 @@ func (c *Client) FindProjectInfoBySecretKey(ctx context.Context, secretKey strin // FindProjectInfoByName returns a project by name. func (c *Client) FindProjectInfoByName( ctx context.Context, - owner string, + owner types.ID, name string, ) (*database.ProjectInfo, error) { result := c.collection(ColProjects).FindOne(ctx, bson.M{ @@ -368,7 +368,7 @@ func (c *Client) FindProjectInfoByID(ctx context.Context, id types.ID) (*databas // UpdateProjectInfo updates the project info. func (c *Client) UpdateProjectInfo( ctx context.Context, - owner string, + owner types.ID, id types.ID, fields *types.UpdatableProjectFields, ) (*database.ProjectInfo, error) { @@ -428,8 +428,25 @@ func (c *Client) CreateUserInfo( return info, nil } -// FindUserInfo returns a user by username. -func (c *Client) FindUserInfo(ctx context.Context, username string) (*database.UserInfo, error) { +// FindUserInfoByID returns a user by ID. +func (c *Client) FindUserInfoByID(ctx context.Context, clientID types.ID) (*database.UserInfo, error) { + result := c.collection(ColUsers).FindOne(ctx, bson.M{ + "_id": clientID, + }) + + userInfo := database.UserInfo{} + if err := result.Decode(&userInfo); err != nil { + if err == mongo.ErrNoDocuments { + return nil, fmt.Errorf("%s: %w", clientID, database.ErrUserNotFound) + } + return nil, fmt.Errorf("decode user info: %w", err) + } + + return &userInfo, nil +} + +// FindUserInfoByName returns a user by username. +func (c *Client) FindUserInfoByName(ctx context.Context, username string) (*database.UserInfo, error) { result := c.collection(ColUsers).FindOne(ctx, bson.M{ "username": username, }) diff --git a/server/backend/database/mongo/client_test.go b/server/backend/database/mongo/client_test.go index 3f4108aa3..f59eca9b3 100644 --- a/server/backend/database/mongo/client_test.go +++ b/server/backend/database/mongo/client_test.go @@ -81,8 +81,12 @@ func TestClient(t *testing.T) { testcases.RunListUserInfosTest(t, cli) }) - t.Run("FindUserInfo test", func(t *testing.T) { - testcases.RunFindUserInfoTest(t, cli) + t.Run("FindUserInfoByID test", func(t *testing.T) { + testcases.RunFindUserInfoByIDTest(t, cli) + }) + + t.Run("FindUserInfoByName test", func(t *testing.T) { + testcases.RunFindUserInfoByNameTest(t, cli) }) t.Run("FindProjectInfoBySecretKey test", func(t *testing.T) { diff --git a/server/backend/database/project_info.go b/server/backend/database/project_info.go index cc2d76958..8c5fe7eca 100644 --- a/server/backend/database/project_info.go +++ b/server/backend/database/project_info.go @@ -43,7 +43,7 @@ type ProjectInfo struct { Name string `bson:"name"` // Owner is the owner of this project. - Owner string `bson:"owner"` + Owner types.ID `bson:"owner"` // PublicKey is the API key of this project. PublicKey string `bson:"public_key"` @@ -69,7 +69,7 @@ type ProjectInfo struct { } // NewProjectInfo creates a new ProjectInfo of the given name. -func NewProjectInfo(name string, owner string, clientDeactivateThreshold string) *ProjectInfo { +func NewProjectInfo(name string, owner types.ID, clientDeactivateThreshold string) *ProjectInfo { return &ProjectInfo{ Name: name, Owner: owner, diff --git a/server/backend/database/project_info_test.go b/server/backend/database/project_info_test.go index 2fbece09c..658e2bf33 100644 --- a/server/backend/database/project_info_test.go +++ b/server/backend/database/project_info_test.go @@ -27,9 +27,9 @@ import ( func TestProjectInfo(t *testing.T) { t.Run("update fields test", func(t *testing.T) { - dummyOwnerName := "dummy" + dummyOwnerID := types.ID("000000000000000000000000") clientDeactivateThreshold := "1h" - project := database.NewProjectInfo(t.Name(), dummyOwnerName, clientDeactivateThreshold) + project := database.NewProjectInfo(t.Name(), dummyOwnerID, clientDeactivateThreshold) testName := "testName" testURL := "testUrl" @@ -44,7 +44,7 @@ func TestProjectInfo(t *testing.T) { project.UpdateFields(&types.UpdatableProjectFields{AuthWebhookMethods: &testMethods}) assert.Equal(t, testMethods, project.AuthWebhookMethods) - assert.Equal(t, dummyOwnerName, project.Owner) + assert.Equal(t, dummyOwnerID, project.Owner) project.UpdateFields(&types.UpdatableProjectFields{ ClientDeactivateThreshold: &testClientDeactivateThreshold, diff --git a/server/backend/database/testcases/testcases.go b/server/backend/database/testcases/testcases.go index d704951e0..4fdc092b3 100644 --- a/server/backend/database/testcases/testcases.go +++ b/server/backend/database/testcases/testcases.go @@ -41,8 +41,8 @@ import ( ) const ( - dummyOwnerName = "dummy" - otherOwnerName = "other" + dummyOwnerID = types.ID("000000000000000000000000") + otherOwnerID = types.ID("000000000000000000000001") dummyClientID = types.ID("000000000000000000000000") clientDeactivateThreshold = "1h" ) @@ -107,44 +107,44 @@ func RunFindProjectInfoByNameTest( _, err := db.CreateProjectInfo( ctx, fmt.Sprintf("%s-%d", t.Name(), suffix), - dummyOwnerName, + dummyOwnerID, clientDeactivateThreshold, ) assert.NoError(t, err) } - _, err := db.CreateProjectInfo(ctx, t.Name(), otherOwnerName, clientDeactivateThreshold) + _, err := db.CreateProjectInfo(ctx, t.Name(), otherOwnerID, clientDeactivateThreshold) assert.NoError(t, err) // Lists all projects that the dummyOwnerID is the owner. - projects, err := db.ListProjectInfos(ctx, dummyOwnerName) + projects, err := db.ListProjectInfos(ctx, dummyOwnerID) assert.NoError(t, err) assert.Len(t, projects, len(suffixes)) - _, err = db.CreateProjectInfo(ctx, t.Name(), dummyOwnerName, clientDeactivateThreshold) + _, err = db.CreateProjectInfo(ctx, t.Name(), dummyOwnerID, clientDeactivateThreshold) assert.NoError(t, err) - project, err := db.FindProjectInfoByName(ctx, dummyOwnerName, t.Name()) + project, err := db.FindProjectInfoByName(ctx, dummyOwnerID, t.Name()) assert.NoError(t, err) assert.Equal(t, project.Name, t.Name()) newName := fmt.Sprintf("%s-%d", t.Name(), 3) fields := &types.UpdatableProjectFields{Name: &newName} - _, err = db.UpdateProjectInfo(ctx, dummyOwnerName, project.ID, fields) + _, err = db.UpdateProjectInfo(ctx, dummyOwnerID, project.ID, fields) assert.NoError(t, err) - _, err = db.FindProjectInfoByName(ctx, dummyOwnerName, newName) + _, err = db.FindProjectInfoByName(ctx, dummyOwnerID, newName) assert.NoError(t, err) }) t.Run("FindProjectInfoByName test", func(t *testing.T) { ctx := context.Background() - info1, err := db.CreateProjectInfo(ctx, t.Name(), dummyOwnerName, clientDeactivateThreshold) + info1, err := db.CreateProjectInfo(ctx, t.Name(), dummyOwnerID, clientDeactivateThreshold) assert.NoError(t, err) - _, err = db.CreateProjectInfo(ctx, t.Name(), otherOwnerName, clientDeactivateThreshold) + _, err = db.CreateProjectInfo(ctx, t.Name(), otherOwnerID, clientDeactivateThreshold) assert.NoError(t, err) - info2, err := db.FindProjectInfoByName(ctx, dummyOwnerName, t.Name()) + info2, err := db.FindProjectInfoByName(ctx, dummyOwnerID, t.Name()) assert.NoError(t, err) assert.Equal(t, info1.ID, info2.ID) }) @@ -305,9 +305,9 @@ func RunListUserInfosTest(t *testing.T, db database.Database) { }) } -// RunFindUserInfoTest runs the FindUserInfo test for the given db. -func RunFindUserInfoTest(t *testing.T, db database.Database) { - t.Run("RunFindUserInfo test", func(t *testing.T) { +// RunFindUserInfoByIDTest runs the FindUserInfoByID test for the given db. +func RunFindUserInfoByIDTest(t *testing.T, db database.Database) { + t.Run("RunFindUserInfoByID test", func(t *testing.T) { ctx := context.Background() username := "findUserInfoTestAccount" @@ -316,7 +316,25 @@ func RunFindUserInfoTest(t *testing.T, db database.Database) { user, _, err := db.EnsureDefaultUserAndProject(ctx, username, password, clientDeactivateThreshold) assert.NoError(t, err) - info1, err := db.FindUserInfo(ctx, user.Username) + info1, err := db.FindUserInfoByID(ctx, user.ID) + assert.NoError(t, err) + + assert.Equal(t, user.ID, info1.ID) + }) +} + +// RunFindUserInfoByNameTest runs the FindUserInfoByName test for the given db. +func RunFindUserInfoByNameTest(t *testing.T, db database.Database) { + t.Run("RunFindUserInfoByName test", func(t *testing.T) { + ctx := context.Background() + + username := "findUserInfoTestAccount" + password := "temporary-password" + + user, _, err := db.EnsureDefaultUserAndProject(ctx, username, password, clientDeactivateThreshold) + assert.NoError(t, err) + + info1, err := db.FindUserInfoByName(ctx, user.Username) assert.NoError(t, err) assert.Equal(t, user.ID, info1.ID) @@ -386,9 +404,9 @@ func RunUpdateProjectInfoTest(t *testing.T, db database.Database) { } newClientDeactivateThreshold := "1h" - info, err := db.CreateProjectInfo(ctx, t.Name(), dummyOwnerName, clientDeactivateThreshold) + info, err := db.CreateProjectInfo(ctx, t.Name(), dummyOwnerID, clientDeactivateThreshold) assert.NoError(t, err) - _, err = db.CreateProjectInfo(ctx, existName, dummyOwnerName, clientDeactivateThreshold) + _, err = db.CreateProjectInfo(ctx, existName, dummyOwnerID, clientDeactivateThreshold) assert.NoError(t, err) id := info.ID @@ -401,7 +419,7 @@ func RunUpdateProjectInfoTest(t *testing.T, db database.Database) { ClientDeactivateThreshold: &newClientDeactivateThreshold, } assert.NoError(t, fields.Validate()) - res, err := db.UpdateProjectInfo(ctx, dummyOwnerName, id, fields) + res, err := db.UpdateProjectInfo(ctx, dummyOwnerID, id, fields) assert.NoError(t, err) updateInfo, err := db.FindProjectInfoByID(ctx, id) assert.NoError(t, err) @@ -416,7 +434,7 @@ func RunUpdateProjectInfoTest(t *testing.T, db database.Database) { Name: &newName2, } assert.NoError(t, fields.Validate()) - res, err = db.UpdateProjectInfo(ctx, dummyOwnerName, id, fields) + res, err = db.UpdateProjectInfo(ctx, dummyOwnerID, id, fields) assert.NoError(t, err) updateInfo, err = db.FindProjectInfoByID(ctx, id) assert.NoError(t, err) @@ -432,7 +450,7 @@ func RunUpdateProjectInfoTest(t *testing.T, db database.Database) { AuthWebhookURL: &newAuthWebhookURL2, } assert.NoError(t, fields.Validate()) - res, err = db.UpdateProjectInfo(ctx, dummyOwnerName, id, fields) + res, err = db.UpdateProjectInfo(ctx, dummyOwnerID, id, fields) assert.NoError(t, err) updateInfo, err = db.FindProjectInfoByID(ctx, id) assert.NoError(t, err) @@ -448,7 +466,7 @@ func RunUpdateProjectInfoTest(t *testing.T, db database.Database) { ClientDeactivateThreshold: &clientDeactivateThreshold2, } assert.NoError(t, fields.Validate()) - res, err = db.UpdateProjectInfo(ctx, dummyOwnerName, id, fields) + res, err = db.UpdateProjectInfo(ctx, dummyOwnerID, id, fields) assert.NoError(t, err) updateInfo, err = db.FindProjectInfoByID(ctx, id) assert.NoError(t, err) @@ -460,12 +478,12 @@ func RunUpdateProjectInfoTest(t *testing.T, db database.Database) { // 05. Duplicated name test fields = &types.UpdatableProjectFields{Name: &existName} - _, err = db.UpdateProjectInfo(ctx, dummyOwnerName, id, fields) + _, err = db.UpdateProjectInfo(ctx, dummyOwnerID, id, fields) assert.ErrorIs(t, err, database.ErrProjectNameAlreadyExists) // 06. OwnerID not match test fields = &types.UpdatableProjectFields{Name: &existName} - _, err = db.UpdateProjectInfo(ctx, otherOwnerName, id, fields) + _, err = db.UpdateProjectInfo(ctx, otherOwnerID, id, fields) assert.ErrorIs(t, err, database.ErrProjectNotFound) }) } @@ -539,7 +557,7 @@ func RunFindDocInfosByPagingTest(t *testing.T, db database.Database, projectID t ctx := context.Background() // dummy project setup - testProjectInfo, err := db.CreateProjectInfo(ctx, t.Name(), dummyOwnerName, clientDeactivateThreshold) + testProjectInfo, err := db.CreateProjectInfo(ctx, t.Name(), dummyOwnerID, clientDeactivateThreshold) assert.NoError(t, err) // dummy document setup @@ -646,7 +664,7 @@ func RunFindDocInfosByPagingTest(t *testing.T, db database.Database, projectID t ctx := context.Background() // 01. Initialize a project and create documents. - projectInfo, err := db.CreateProjectInfo(ctx, t.Name(), dummyOwnerName, clientDeactivateThreshold) + projectInfo, err := db.CreateProjectInfo(ctx, t.Name(), dummyOwnerID, clientDeactivateThreshold) assert.NoError(t, err) var docInfos []*database.DocInfo @@ -1120,7 +1138,7 @@ func RunFindNextNCyclingProjectInfosTest(t *testing.T, db database.Database) { p, err := db.CreateProjectInfo( ctx, fmt.Sprintf("%s-%d-RunFindNextNCyclingProjectInfos", t.Name(), i), - otherOwnerName, + otherOwnerID, clientDeactivateThreshold, ) assert.NoError(t, err) @@ -1150,7 +1168,7 @@ func RunFindDeactivateCandidatesPerProjectTest(t *testing.T, db database.Databas p1, err := db.CreateProjectInfo( ctx, fmt.Sprintf("%s-FindDeactivateCandidatesPerProject", t.Name()), - otherOwnerName, + otherOwnerID, clientDeactivateThreshold, ) assert.NoError(t, err) @@ -1164,7 +1182,7 @@ func RunFindDeactivateCandidatesPerProjectTest(t *testing.T, db database.Databas p2, err := db.CreateProjectInfo( ctx, fmt.Sprintf("%s-FindDeactivateCandidatesPerProject-2", t.Name()), - otherOwnerName, + otherOwnerID, "0s", ) assert.NoError(t, err) diff --git a/server/projects/projects.go b/server/projects/projects.go index 5b9ce84c5..956ced116 100644 --- a/server/projects/projects.go +++ b/server/projects/projects.go @@ -29,7 +29,7 @@ import ( func CreateProject( ctx context.Context, be *backend.Backend, - owner string, + owner types.ID, name string, ) (*types.Project, error) { info, err := be.DB.CreateProjectInfo(ctx, name, owner, be.Config.ClientDeactivateThreshold) @@ -44,7 +44,7 @@ func CreateProject( func ListProjects( ctx context.Context, be *backend.Backend, - owner string, + owner types.ID, ) ([]*types.Project, error) { infos, err := be.DB.ListProjectInfos(ctx, owner) if err != nil { @@ -63,7 +63,7 @@ func ListProjects( func GetProject( ctx context.Context, be *backend.Backend, - owner string, + owner types.ID, name string, ) (*types.Project, error) { info, err := be.DB.FindProjectInfoByName(ctx, owner, name) @@ -78,7 +78,7 @@ func GetProject( func UpdateProject( ctx context.Context, be *backend.Backend, - owner string, + owner types.ID, id types.ID, fields *types.UpdatableProjectFields, ) (*types.Project, error) { diff --git a/server/rpc/admin_server.go b/server/rpc/admin_server.go index b8362522b..3a398c051 100644 --- a/server/rpc/admin_server.go +++ b/server/rpc/admin_server.go @@ -106,7 +106,7 @@ func (s *adminServer) CreateProject( } user := users.From(ctx) - project, err := projects.CreateProject(ctx, s.backend, user.Username, req.Msg.Name) + project, err := projects.CreateProject(ctx, s.backend, user.ID, req.Msg.Name) if err != nil { return nil, err } @@ -122,7 +122,7 @@ func (s *adminServer) ListProjects( _ *connect.Request[api.ListProjectsRequest], ) (*connect.Response[api.ListProjectsResponse], error) { user := users.From(ctx) - projectList, err := projects.ListProjects(ctx, s.backend, user.Username) + projectList, err := projects.ListProjects(ctx, s.backend, user.ID) if err != nil { return nil, err } @@ -138,7 +138,7 @@ func (s *adminServer) GetProject( req *connect.Request[api.GetProjectRequest], ) (*connect.Response[api.GetProjectResponse], error) { user := users.From(ctx) - project, err := projects.GetProject(ctx, s.backend, user.Username, req.Msg.Name) + project, err := projects.GetProject(ctx, s.backend, user.ID, req.Msg.Name) if err != nil { return nil, err } @@ -165,7 +165,7 @@ func (s *adminServer) UpdateProject( project, err := projects.UpdateProject( ctx, s.backend, - user.Username, + user.ID, types.ID(req.Msg.Id), fields, ) @@ -184,7 +184,7 @@ func (s *adminServer) GetDocument( req *connect.Request[api.GetDocumentRequest], ) (*connect.Response[api.GetDocumentResponse], error) { user := users.From(ctx) - project, err := projects.GetProject(ctx, s.backend, user.Username, req.Msg.ProjectName) + project, err := projects.GetProject(ctx, s.backend, user.ID, req.Msg.ProjectName) if err != nil { return nil, err } @@ -210,7 +210,7 @@ func (s *adminServer) GetSnapshotMeta( req *connect.Request[api.GetSnapshotMetaRequest], ) (*connect.Response[api.GetSnapshotMetaResponse], error) { user := users.From(ctx) - project, err := projects.GetProject(ctx, s.backend, user.Username, req.Msg.ProjectName) + project, err := projects.GetProject(ctx, s.backend, user.ID, req.Msg.ProjectName) if err != nil { return nil, err } @@ -243,7 +243,7 @@ func (s *adminServer) ListDocuments( req *connect.Request[api.ListDocumentsRequest], ) (*connect.Response[api.ListDocumentsResponse], error) { user := users.From(ctx) - project, err := projects.GetProject(ctx, s.backend, user.Username, req.Msg.ProjectName) + project, err := projects.GetProject(ctx, s.backend, user.ID, req.Msg.ProjectName) if err != nil { return nil, err } @@ -277,7 +277,7 @@ func (s *adminServer) SearchDocuments( req *connect.Request[api.SearchDocumentsRequest], ) (*connect.Response[api.SearchDocumentsResponse], error) { user := users.From(ctx) - project, err := projects.GetProject(ctx, s.backend, user.Username, req.Msg.ProjectName) + project, err := projects.GetProject(ctx, s.backend, user.ID, req.Msg.ProjectName) if err != nil { return nil, err } @@ -305,7 +305,7 @@ func (s *adminServer) RemoveDocumentByAdmin( req *connect.Request[api.RemoveDocumentByAdminRequest], ) (*connect.Response[api.RemoveDocumentByAdminResponse], error) { user := users.From(ctx) - project, err := projects.GetProject(ctx, s.backend, user.Username, req.Msg.ProjectName) + project, err := projects.GetProject(ctx, s.backend, user.ID, req.Msg.ProjectName) if err != nil { return nil, err } @@ -363,7 +363,7 @@ func (s *adminServer) ListChanges( req *connect.Request[api.ListChangesRequest], ) (*connect.Response[api.ListChangesResponse], error) { user := users.From(ctx) - project, err := projects.GetProject(ctx, s.backend, user.Username, req.Msg.ProjectName) + project, err := projects.GetProject(ctx, s.backend, user.ID, req.Msg.ProjectName) if err != nil { return nil, err } diff --git a/server/rpc/interceptors/admin_auth.go b/server/rpc/interceptors/admin_auth.go index 923561d98..39ee85119 100644 --- a/server/rpc/interceptors/admin_auth.go +++ b/server/rpc/interceptors/admin_auth.go @@ -165,7 +165,7 @@ func (i *AdminAuthInterceptor) authenticate( // NOTE(raararaara): If the token is access token, return the user of the token. claims, err := i.tokenManager.Verify(authorization) if err == nil { - user, err := users.GetUser(ctx, i.backend, claims.Username) + user, err := users.GetUserByName(ctx, i.backend, claims.Username) if err == nil { return user, nil } @@ -174,7 +174,7 @@ func (i *AdminAuthInterceptor) authenticate( // NOTE(raararaara): If the token is secret key, return the owner of the project. project, err := projects.GetProjectFromSecretKey(ctx, i.backend, authorization) if err == nil { - user, err := users.GetUser(ctx, i.backend, project.Owner) + user, err := users.GetUserByID(ctx, i.backend, project.Owner) if err == nil { return user, nil } diff --git a/server/users/users.go b/server/users/users.go index e24d383f0..e4adc6a66 100644 --- a/server/users/users.go +++ b/server/users/users.go @@ -53,7 +53,7 @@ func IsCorrectPassword( username, password string, ) (*types.User, error) { - info, err := be.DB.FindUserInfo(ctx, username) + info, err := be.DB.FindUserInfoByName(ctx, username) if err != nil { return nil, err } @@ -68,16 +68,29 @@ func IsCorrectPassword( return info.ToUser(), nil } -// GetUser returns a user by the given username. -func GetUser( +// GetUserByName returns a user by the given username. +func GetUserByName( ctx context.Context, be *backend.Backend, username string, ) (*types.User, error) { - info, err := be.DB.FindUserInfo(ctx, username) + info, err := be.DB.FindUserInfoByName(ctx, username) if err != nil { return nil, err } return info.ToUser(), nil } + +// GetUserByID returns a user by ID. +func GetUserByID( + ctx context.Context, + be *backend.Backend, + id types.ID, +) (*types.User, error) { + info, err := be.DB.FindUserInfoByID(ctx, id) + if err != nil { + return nil, err + } + return info.ToUser(), nil +} diff --git a/test/integration/housekeeping_test.go b/test/integration/housekeeping_test.go index 9ca974d68..9b56d916d 100644 --- a/test/integration/housekeeping_test.go +++ b/test/integration/housekeeping_test.go @@ -32,6 +32,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/server/backend/database" "github.com/yorkie-team/yorkie/server/backend/database/mongo" "github.com/yorkie-team/yorkie/server/backend/housekeeping" @@ -40,8 +41,8 @@ import ( ) const ( - dummyOwnerName = "dummy" - otherOwnerName = "other" + dummyOwnerID = types.ID("000000000000000000000000") + otherOwnerID = types.ID("000000000000000000000001") clientDeactivateThreshold = "23h" ) @@ -140,10 +141,10 @@ func createProjects(t *testing.T, db *mongo.Client) []*database.ProjectInfo { projects := make([]*database.ProjectInfo, 0) for i := 0; i < 10; i++ { - p, err := db.CreateProjectInfo(ctx, fmt.Sprintf("%d project", i), dummyOwnerName, clientDeactivateThreshold) + p, err := db.CreateProjectInfo(ctx, fmt.Sprintf("%d project", i), dummyOwnerID, clientDeactivateThreshold) assert.NoError(t, err) projects = append(projects, p) - p, err = db.CreateProjectInfo(ctx, fmt.Sprintf("%d project", i), otherOwnerName, clientDeactivateThreshold) + p, err = db.CreateProjectInfo(ctx, fmt.Sprintf("%d project", i), otherOwnerID, clientDeactivateThreshold) assert.NoError(t, err) projects = append(projects, p) }