mas: handle 3pids from mas

This commit is contained in:
Roman Isaev 2025-01-06 03:19:47 +00:00
parent 5cffc2c257
commit 811a504e01
No known key found for this signature in database
GPG key ID: 7BE2B6A6C89AEC7F
5 changed files with 91 additions and 19 deletions

View file

@ -120,6 +120,7 @@ type Pusher interface {
type ThreePID interface {
SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName spec.ServerName, medium string) (err error)
BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName spec.ServerName, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error)

View file

@ -353,6 +353,41 @@ func (d *Database) SaveThreePIDAssociation(
})
}
// BulkSaveThreePIDAssociation recreates 3PIDs for a user.
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
oldThreePIDs, err := d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
if err != nil {
return err
}
for _, t := range oldThreePIDs {
if err := d.ThreePIDs.DeleteThreePID(ctx, txn, t.Address, t.Medium); err != nil {
return err
}
}
for _, t := range threePIDs {
// if 3PID is associated with another user, return Err3PIDInUse
user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
ctx, txn, t.Address, t.Medium,
)
if err != nil {
return err
}
if len(user) > 0 && user != localpart {
return Err3PIDInUse
}
if err = d.ThreePIDs.InsertThreePID(ctx, txn, t.Address, t.Medium, localpart, serverName); err != nil {
return err
}
}
return nil
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.