mas: handle 3pids from mas

This commit is contained in:
Roman Isaev 2025-01-06 03:19:47 +00:00
parent 5cffc2c257
commit 811a504e01
No known key found for this signature in database
GPG key ID: 7BE2B6A6C89AEC7F
5 changed files with 91 additions and 19 deletions

View file

@ -23,6 +23,7 @@ import (
appserviceAPI "github.com/element-hq/dendrite/appservice/api"
clientapi "github.com/element-hq/dendrite/clientapi/api"
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
clienthttputil "github.com/element-hq/dendrite/clientapi/httputil"
"github.com/element-hq/dendrite/clientapi/userutil"
"github.com/element-hq/dendrite/internal/httputil"
@ -31,6 +32,7 @@ import (
"github.com/element-hq/dendrite/setup/jetstream"
"github.com/element-hq/dendrite/userapi/api"
userapi "github.com/element-hq/dendrite/userapi/api"
"github.com/element-hq/dendrite/userapi/storage/shared"
)
const (
@ -842,11 +844,13 @@ type adminExternalID struct {
type adminCreateOrModifyAccountRequest struct {
DisplayName string `json:"displayname"`
AvatarURL string `json:"avatar_url"`
// TODO: the following fields are not used here, but they are used in Synapse. Probably we should reproduce the logic of the
// endpoint fully compatible.
ThreePIDs []struct {
Medium string `json:"medium"`
Address string `json:"address"`
} `json:"threepids"`
// TODO: the following fields are not used here, but they are used in Synapse.
// Password string `json:"password"`
// LogoutDevices bool `json:"logout_devices"`
// Threepids json.RawMessage `json:"threepids"`
// ExternalIDs []adminExternalID `json:"external_ids"`
// Admin bool `json:"admin"`
// Deactivated bool `json:"deactivated"`
@ -872,9 +876,11 @@ func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI
logger.Debugf("UnmarshalJSONRequest failed: %+v", *resErr)
return *resErr
}
logger.Debugf("adminCreateOrModifyAccountRequest is: %+v", r)
logger.Debugf("adminCreateOrModifyAccountRequest is: %#v", r)
statusCode := http.StatusOK
{
// TODO: Ideally, the following commands should be executed in one transaction.
// can we propagate the tx object and pass it in context?
var res userapi.PerformAccountCreationResponse
err = userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser,
@ -885,12 +891,31 @@ func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI
DisplayName: r.DisplayName,
}, &res)
if err != nil {
logger.WithError(err).Debugln("Failed creating account")
logger.WithError(err).Error("userAPI.PerformAccountCreation")
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if res.AccountCreated {
statusCode = http.StatusCreated
}
if l := len(r.ThreePIDs); l > 0 {
logger.Debugf("Trying to bulk save 3PID associations: %+v", r.ThreePIDs)
threePIDs := make([]authtypes.ThreePID, 0, len(r.ThreePIDs))
for i := range r.ThreePIDs {
tpid := &r.ThreePIDs[i]
threePIDs = append(threePIDs, authtypes.ThreePID{Medium: tpid.Medium, Address: tpid.Address})
}
err = userAPI.PerformBulkSaveThreePIDAssociation(req.Context(), &userapi.PerformBulkSaveThreePIDAssociationRequest{
ThreePIDs: threePIDs,
Localpart: local,
ServerName: domain,
}, &struct{}{})
if err == shared.Err3PIDInUse {
return util.MessageResponse(http.StatusBadRequest, err.Error())
} else if err != nil {
logger.WithError(err).Error("userAPI.PerformSaveThreePIDAssociation")
return util.ErrorResponse(err)
}
}
return util.JSONResponse{

View file

@ -111,6 +111,7 @@ type ClientUserAPI interface {
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
PerformBulkSaveThreePIDAssociation(ctx context.Context, req *PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error
}
type KeyBackupAPI interface {
@ -653,6 +654,12 @@ type PerformSaveThreePIDAssociationRequest struct {
Medium string
}
type PerformBulkSaveThreePIDAssociationRequest struct {
ThreePIDs []authtypes.ThreePID
Localpart string
ServerName spec.ServerName
}
type QueryAccountByLocalpartRequest struct {
Localpart string
ServerName spec.ServerName

View file

@ -987,4 +987,8 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
}
func (a *UserInternalAPI) PerformBulkSaveThreePIDAssociation(ctx context.Context, req *api.PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error {
return a.DB.BulkSaveThreePIDAssociation(ctx, req.ThreePIDs, req.Localpart, req.ServerName)
}
const pushRulesAccountDataType = "m.push_rules"

View file

@ -120,6 +120,7 @@ type Pusher interface {
type ThreePID interface {
SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName spec.ServerName, medium string) (err error)
BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName spec.ServerName, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error)

View file

@ -353,6 +353,41 @@ func (d *Database) SaveThreePIDAssociation(
})
}
// BulkSaveThreePIDAssociation recreates 3PIDs for a user.
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
oldThreePIDs, err := d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
if err != nil {
return err
}
for _, t := range oldThreePIDs {
if err := d.ThreePIDs.DeleteThreePID(ctx, txn, t.Address, t.Medium); err != nil {
return err
}
}
for _, t := range threePIDs {
// if 3PID is associated with another user, return Err3PIDInUse
user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
ctx, txn, t.Address, t.Medium,
)
if err != nil {
return err
}
if len(user) > 0 && user != localpart {
return Err3PIDInUse
}
if err = d.ThreePIDs.InsertThreePID(ctx, txn, t.Address, t.Medium, localpart, serverName); err != nil {
return err
}
}
return nil
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.