Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Auto-generate username if none provided during registration #470

Merged
merged 3 commits into from
May 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
);
-- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
`

const insertAccountSQL = "" +
Expand All @@ -49,13 +51,17 @@ const selectAccountByLocalpartSQL = "" +
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1"

const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')"

// TODO: Update password

type accountsStatements struct {
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}

func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
Expand All @@ -72,6 +78,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
return
}
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
return
}
s.serverName = server
return
}
Expand Down Expand Up @@ -121,3 +130,10 @@ func (s *accountsStatements) selectAccountByLocalpart(
}
return
}

func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context,
) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id)
return
}
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,13 @@ func (d *Database) GetAccountDataByType(
)
}

// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx)
}

func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err
Expand Down
18 changes: 18 additions & 0 deletions src/github.com/matrix-org/dendrite/clientapi/routing/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -403,6 +404,23 @@ func Register(
sessionID = util.RandomString(sessionIDLength)
}

// Don't allow numeric usernames less than MAX_INT64.
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
}
}
// Auto generate a numeric username if r.Username is empty
if r.Username == "" {
id, err := accountDB.GetNewNumericLocalpart(req.Context())
if err != nil {
return httputil.LogThenError(req, err)
}

r.Username = strconv.FormatInt(id, 10)
}

// If no auth type is specified by the client, send back the list of available flows
if r.Auth.Type == "" {
return util.JSONResponse{
Expand Down