mirror of
https://github.com/element-hq/dendrite.git
synced 2025-09-14 13:22:26 +03:00
Refactor user API storage (#2202)
* Refactor User API database * Fix migration bugs
This commit is contained in:
parent
9bd5e414c9
commit
9f4a39e8e0
22 changed files with 1165 additions and 1671 deletions
|
@ -21,18 +21,11 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
type loginTokenStatements struct {
|
||||
insertStmt *sql.Stmt
|
||||
deleteStmt *sql.Stmt
|
||||
selectByTokenStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// execSchema ensures tables and indices exist.
|
||||
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(`
|
||||
const loginTokenSchema = `
|
||||
CREATE TABLE IF NOT EXISTS login_tokens (
|
||||
-- The random value of the token issued to a user
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
|
@ -45,24 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens (
|
|||
|
||||
-- This index allows efficient garbage collection of expired tokens.
|
||||
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
|
||||
`)
|
||||
return err
|
||||
`
|
||||
|
||||
const insertLoginTokenSQL = "" +
|
||||
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteLoginTokenSQL = "" +
|
||||
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
|
||||
|
||||
const selectLoginTokenSQL = "" +
|
||||
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
|
||||
|
||||
type loginTokenStatements struct {
|
||||
insertStmt *sql.Stmt
|
||||
deleteStmt *sql.Stmt
|
||||
selectStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// prepare runs statement preparation.
|
||||
func (s *loginTokenStatements) prepare(db *sql.DB) error {
|
||||
if err := s.execSchema(db); err != nil {
|
||||
return err
|
||||
func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
|
||||
s := &loginTokenStatements{}
|
||||
_, err := db.Exec(loginTokenSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sqlutil.StatementList{
|
||||
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
|
||||
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
|
||||
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertStmt, insertLoginTokenSQL},
|
||||
{&s.deleteStmt, deleteLoginTokenSQL},
|
||||
{&s.selectStmt, selectLoginTokenSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// insert adds an already generated token to the database.
|
||||
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
|
||||
func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertStmt)
|
||||
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
|
||||
return err
|
||||
|
@ -72,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
|
|||
//
|
||||
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
|
||||
// The login_tokens_expiration_idx index should make that efficient.
|
||||
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error {
|
||||
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
|
||||
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
|
||||
if err != nil {
|
||||
|
@ -85,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
|
|||
}
|
||||
|
||||
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
|
||||
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||
func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||
var data api.LoginTokenData
|
||||
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
|
||||
err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue