mirror of
https://github.com/element-hq/dendrite.git
synced 2025-09-14 05:12:26 +03:00
Add account type (#2171)
* Add account_type for sqlite3 * Add account_type for postgres * Remove CreateGuestAccount from interface * Add new AccountTypes & update test * Use newly added AccountType for account creation * Add migrations * Reuse type * Add AccounnType to Device, so it can be verified on requests * Rename migration, add missing update for appservices * Rename sqlite3 migration * Add missing AccountType to return value * Update sqlite migration Change allowance check on /admin/whois * Fix migration, add IS NULL * Move accountType to completeRegistration * Fix migrations * Add passing test
This commit is contained in:
parent
e9b672a34e
commit
5a39512f5f
18 changed files with 230 additions and 117 deletions
|
@ -19,10 +19,11 @@ import (
|
|||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
@ -39,16 +40,18 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
|||
-- Identifies which application service this account belongs to, if any.
|
||||
appservice_id TEXT,
|
||||
-- If the account is currently active
|
||||
is_deactivated BOOLEAN DEFAULT FALSE
|
||||
is_deactivated BOOLEAN DEFAULT FALSE,
|
||||
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
||||
account_type SMALLINT NOT NULL
|
||||
-- TODO:
|
||||
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
||||
-- upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
-- Create sequence for autogenerated numeric usernames
|
||||
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
||||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
|
@ -57,7 +60,7 @@ const deactivateAccountSQL = "" +
|
|||
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
||||
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||
|
@ -96,16 +99,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) insertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
var err error
|
||||
if appserviceID == "" {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
} else {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -116,6 +119,7 @@ func (s *accountsStatements) insertAccount(
|
|||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
ServerName: s.serverName,
|
||||
AppServiceID: appserviceID,
|
||||
AccountType: accountType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -147,7 +151,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
|||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue