From 150be588f5b6719c0638cdae9ade5c3b2c967609 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Tue, 24 Dec 2024 03:06:26 +0000 Subject: [PATCH] mas: added localpart_external_ids table --- userapi/api/api.go | 8 ++ userapi/storage/interface.go | 7 ++ .../postgres/localpart_external_ids_table.go | 97 +++++++++++++++++++ userapi/storage/postgres/storage.go | 5 + userapi/storage/shared/storage.go | 13 +++ .../sqlite3/localpart_external_ids_table.go | 97 +++++++++++++++++++ userapi/storage/sqlite3/storage.go | 5 + userapi/storage/tables/interface.go | 10 ++ 8 files changed, 242 insertions(+) create mode 100644 userapi/storage/postgres/localpart_external_ids_table.go create mode 100644 userapi/storage/sqlite3/localpart_external_ids_table.go diff --git a/userapi/api/api.go b/userapi/api/api.go index 26482129..f0ef26bf 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -471,6 +471,14 @@ type OpenIDTokenAttributes struct { ExpiresAtMS int64 } +// LocalpartExternalID represents a connection between Matrix account and OpenID Connect provider +type LocalpartExternalID struct { + Localpart string + ExternalID string + AuthProvider string + CreatedTs int64 +} + // UserInfo is for returning information about the user an OpenID token was issued for type UserInfo struct { Sub string // The Matrix user's ID who generated the token diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 2a46a7fd..13d8c201 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -134,6 +134,12 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } +type LocalpartExternalID interface { + CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error + GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) + DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error +} + type UserDatabase interface { Account AccountData @@ -147,6 +153,7 @@ type UserDatabase interface { Statistics ThreePID RegistrationTokens + LocalpartExternalID } type KeyChangeDatabase interface { diff --git a/userapi/storage/postgres/localpart_external_ids_table.go b/userapi/storage/postgres/localpart_external_ids_table.go new file mode 100644 index 00000000..9bc47dbf --- /dev/null +++ b/userapi/storage/postgres/localpart_external_ids_table.go @@ -0,0 +1,97 @@ +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/tables" + log "github.com/sirupsen/logrus" +) + +const localpartExternalIDsSchema = ` +-- Stores data about connections between accounts and third-party auth providers +CREATE TABLE IF NOT EXISTS userapi_localpart_external_ids ( + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- The external ID + external_id TEXT NOT NULL, + -- Auth provider ID (see OIDCProvider.IDPID) + auth_provider TEXT NOT NULL, + -- When this connection was created, as a unix timestamp. + created_ts BIGINT NOT NULL, + + CONSTRAINT userapi_localpart_external_ids_external_id_auth_provider_unique UNIQUE(external_id, auth_provider), + CONSTRAINT userapi_localpart_external_ids_localpart_external_id_auth_provider_unique UNIQUE(localpart, external_id, auth_provider) +); + +-- This index allows efficient lookup of the local user by the external ID +CREATE INDEX IF NOT EXISTS userapi_external_id_auth_provider_idx ON userapi_localpart_external_ids(external_id, auth_provider); +` + +const insertUserExternalIDSQL = "" + + "INSERT INTO userapi_localpart_external_ids(localpart, external_id, auth_provider, created_ts) VALUES ($1, $2, $3, $4)" + +const selectUserExternalIDSQL = "" + + "SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +const deleteUserExternalIDSQL = "" + + "SELECT localpart, external_id, auth_provider, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +type localpartExternalIDStatements struct { + db *sql.DB + insertUserExternalIDStmt *sql.Stmt + selectUserExternalIDStmt *sql.Stmt + deleteUserExternalIDStmt *sql.Stmt +} + +func NewPostgresLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDsTable, error) { + s := &localpartExternalIDStatements{ + db: db, + } + _, err := db.Exec(localpartExternalIDsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertUserExternalIDStmt, insertUserExternalIDSQL}, + {&s.selectUserExternalIDStmt, selectUserExternalIDSQL}, + {&s.deleteUserExternalIDStmt, deleteUserExternalIDSQL}, + }.Prepare(db) +} + +// Select selects an existing OpenID Connect connection from the database +func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) { + ret := api.LocalpartExternalID{ + ExternalID: externalID, + AuthProvider: authProvider, + } + err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan( + &ret.Localpart, &ret.CreatedTs, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + log.WithError(err).Error("Unable to retrieve localpart from the db") + return nil, err + } + + return &ret, nil +} + +// Insert creates a new record representing an OpenID Connect connection between Matrix and external accounts. +func (u *localpartExternalIDStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix()) + return err +} + +// Delete deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account. +func (u *localpartExternalIDStatements) Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, externalID, authProvider) + return err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index c7fb9d29..eff12a64 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -97,6 +97,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties if err != nil { return nil, fmt.Errorf("NewPostgresStatsTable: %w", err) } + localpartExternalIDsTable, err := NewPostgresLocalpartExternalIDsTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteLocalpartExternalIDsTable: %w", err) + } m = sqlutil.NewMigrator(db) m.AddMigrations(sqlutil.Migration{ @@ -123,6 +127,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties Notifications: notificationsTable, RegistrationTokens: registationTokensTable, Stats: statsTable, + LocalpartExternalIDs: localpartExternalIDsTable, ServerName: serverName, DB: db, Writer: writer, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 44ace733..aade4be1 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -49,6 +49,7 @@ type Database struct { Notifications tables.NotificationTable Pushers tables.PusherTable Stats tables.StatsTable + LocalpartExternalIDs tables.LocalpartExternalIDsTable LoginTokenLifetime time.Duration ServerName spec.ServerName BcryptCost int @@ -870,6 +871,18 @@ func (d *Database) UpsertPusher( }) } +func (d *Database) CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error { + return d.LocalpartExternalIDs.Insert(ctx, nil, localpart, externalID, authProvider) +} + +func (d *Database) GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) { + return d.LocalpartExternalIDs.Select(ctx, nil, externalID, authProvider) +} + +func (d *Database) DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error { + return d.LocalpartExternalIDs.Delete(ctx, nil, externalID, authProvider) +} + // GetPushers returns the pushers matching the given localpart. func (d *Database) GetPushers( ctx context.Context, localpart string, serverName spec.ServerName, diff --git a/userapi/storage/sqlite3/localpart_external_ids_table.go b/userapi/storage/sqlite3/localpart_external_ids_table.go new file mode 100644 index 00000000..30f1fc60 --- /dev/null +++ b/userapi/storage/sqlite3/localpart_external_ids_table.go @@ -0,0 +1,97 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/tables" + log "github.com/sirupsen/logrus" +) + +const localpartExternalIDsSchema = ` +-- Stores data about connections between accounts and third-party auth providers +CREATE TABLE IF NOT EXISTS userapi_localpart_external_ids ( + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- The external ID + external_id TEXT NOT NULL, + -- Auth provider ID (see OIDCProvider.IDPID) + auth_provider TEXT NOT NULL, + -- When this connection was created, as a unix timestamp. + created_ts BIGINT NOT NULL, + + UNIQUE(external_id, auth_provider), + UNIQUE(localpart, external_id, auth_provider) +); + +-- This index allows efficient lookup of the local user by the external ID +CREATE INDEX IF NOT EXISTS userapi_external_id_auth_provider_idx ON userapi_localpart_external_ids(external_id, auth_provider); +` + +const insertLocalpartExternalIDSQL = "" + + "INSERT INTO userapi_localpart_external_ids(localpart, external_id, auth_provider, created_ts) VALUES ($1, $2, $3, $4)" + +const selectLocalpartExternalIDSQL = "" + + "SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +const deleteLocalpartExternalIDSQL = "" + + "SELECT localpart, external_id, auth_provider, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +type localpartExternalIDStatements struct { + db *sql.DB + insertUserExternalIDStmt *sql.Stmt + selectUserExternalIDStmt *sql.Stmt + deleteUserExternalIDStmt *sql.Stmt +} + +func NewSQLiteLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDsTable, error) { + s := &localpartExternalIDStatements{ + db: db, + } + _, err := db.Exec(localpartExternalIDsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertUserExternalIDStmt, insertLocalpartExternalIDSQL}, + {&s.selectUserExternalIDStmt, selectLocalpartExternalIDSQL}, + {&s.deleteUserExternalIDStmt, deleteLocalpartExternalIDSQL}, + }.Prepare(db) +} + +// Select selects an existing OpenID Connect connection from the database +func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) { + ret := api.LocalpartExternalID{ + ExternalID: externalID, + AuthProvider: authProvider, + } + err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan( + &ret.Localpart, &ret.CreatedTs, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + log.WithError(err).Error("Unable to retrieve localpart from the db") + return nil, err + } + + return &ret, nil +} + +// Insert creates a new record representing an OpenID Connect connection between Matrix and external accounts. +func (u *localpartExternalIDStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix()) + return err +} + +// Delete deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account. +func (u *localpartExternalIDStatements) Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, externalID, authProvider) + return err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 6d906191..80ecaf83 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -94,6 +94,10 @@ func NewUserDatabase(ctx context.Context, conMan *sqlutil.Connections, dbPropert if err != nil { return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) } + localpartExternalIDsTable, err := NewSQLiteLocalpartExternalIDsTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteUserExternalIDsTable: %w", err) + } m = sqlutil.NewMigrator(db) m.AddMigrations(sqlutil.Migration{ @@ -119,6 +123,7 @@ func NewUserDatabase(ctx context.Context, conMan *sqlutil.Connections, dbPropert Pushers: pusherTable, Notifications: notificationsTable, Stats: statsTable, + LocalpartExternalIDs: localpartExternalIDsTable, ServerName: serverName, DB: db, Writer: writer, diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 44f31a5c..7b141629 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -127,6 +127,16 @@ type StatsTable interface { UpsertDailyStats(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error } +type LocalpartExternalIDsTable interface { + Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) + Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error + Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error +} + +type UIAuthSessionsTable interface { + SelectByID(ctx context.Context, txn *sql.Tx, sessionID int) (*api.UIAuthSession, error) +} + type NotificationFilter uint32 const (