diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index b213d5aa..68112154 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -23,6 +23,7 @@ import ( appserviceAPI "github.com/element-hq/dendrite/appservice/api" clientapi "github.com/element-hq/dendrite/clientapi/api" + "github.com/element-hq/dendrite/clientapi/auth/authtypes" clienthttputil "github.com/element-hq/dendrite/clientapi/httputil" "github.com/element-hq/dendrite/clientapi/userutil" "github.com/element-hq/dendrite/internal/httputil" @@ -31,6 +32,7 @@ import ( "github.com/element-hq/dendrite/setup/jetstream" "github.com/element-hq/dendrite/userapi/api" userapi "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/shared" ) const ( @@ -842,11 +844,13 @@ type adminExternalID struct { type adminCreateOrModifyAccountRequest struct { DisplayName string `json:"displayname"` AvatarURL string `json:"avatar_url"` - // TODO: the following fields are not used here, but they are used in Synapse. Probably we should reproduce the logic of the - // endpoint fully compatible. + ThreePIDs []struct { + Medium string `json:"medium"` + Address string `json:"address"` + } `json:"threepids"` + // TODO: the following fields are not used here, but they are used in Synapse. // Password string `json:"password"` // LogoutDevices bool `json:"logout_devices"` - // Threepids json.RawMessage `json:"threepids"` // ExternalIDs []adminExternalID `json:"external_ids"` // Admin bool `json:"admin"` // Deactivated bool `json:"deactivated"` @@ -872,24 +876,45 @@ func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI logger.Debugf("UnmarshalJSONRequest failed: %+v", *resErr) return *resErr } - logger.Debugf("adminCreateOrModifyAccountRequest is: %+v", r) + logger.Debugf("adminCreateOrModifyAccountRequest is: %#v", r) statusCode := http.StatusOK - { - var res userapi.PerformAccountCreationResponse - err = userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ - AccountType: userapi.AccountTypeUser, - Localpart: local, - ServerName: domain, - OnConflict: api.ConflictUpdate, - AvatarURL: r.AvatarURL, - DisplayName: r.DisplayName, - }, &res) - if err != nil { - logger.WithError(err).Debugln("Failed creating account") - return util.MessageResponse(http.StatusBadRequest, err.Error()) + + // TODO: Ideally, the following commands should be executed in one transaction. + // can we propagate the tx object and pass it in context? + var res userapi.PerformAccountCreationResponse + err = userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeUser, + Localpart: local, + ServerName: domain, + OnConflict: api.ConflictUpdate, + AvatarURL: r.AvatarURL, + DisplayName: r.DisplayName, + }, &res) + if err != nil { + logger.WithError(err).Error("userAPI.PerformAccountCreation") + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if res.AccountCreated { + statusCode = http.StatusCreated + } + + if l := len(r.ThreePIDs); l > 0 { + logger.Debugf("Trying to bulk save 3PID associations: %+v", r.ThreePIDs) + threePIDs := make([]authtypes.ThreePID, 0, len(r.ThreePIDs)) + for i := range r.ThreePIDs { + tpid := &r.ThreePIDs[i] + threePIDs = append(threePIDs, authtypes.ThreePID{Medium: tpid.Medium, Address: tpid.Address}) } - if res.AccountCreated { - statusCode = http.StatusCreated + err = userAPI.PerformBulkSaveThreePIDAssociation(req.Context(), &userapi.PerformBulkSaveThreePIDAssociationRequest{ + ThreePIDs: threePIDs, + Localpart: local, + ServerName: domain, + }, &struct{}{}) + if err == shared.Err3PIDInUse { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } else if err != nil { + logger.WithError(err).Error("userAPI.PerformSaveThreePIDAssociation") + return util.ErrorResponse(err) } } diff --git a/userapi/api/api.go b/userapi/api/api.go index 2efa8976..08308bc3 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -111,6 +111,7 @@ type ClientUserAPI interface { QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error + PerformBulkSaveThreePIDAssociation(ctx context.Context, req *PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error } type KeyBackupAPI interface { @@ -653,6 +654,12 @@ type PerformSaveThreePIDAssociationRequest struct { Medium string } +type PerformBulkSaveThreePIDAssociationRequest struct { + ThreePIDs []authtypes.ThreePID + Localpart string + ServerName spec.ServerName +} + type QueryAccountByLocalpartRequest struct { Localpart string ServerName spec.ServerName diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index a7760c1b..b691496e 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -987,4 +987,8 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium) } +func (a *UserInternalAPI) PerformBulkSaveThreePIDAssociation(ctx context.Context, req *api.PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error { + return a.DB.BulkSaveThreePIDAssociation(ctx, req.ThreePIDs, req.Localpart, req.ServerName) +} + const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 11b36095..2c0b4bf2 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -120,6 +120,7 @@ type Pusher interface { type ThreePID interface { SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName spec.ServerName, medium string) (err error) + BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName spec.ServerName, err error) GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error) diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index e76cde61..3d6d51ed 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -353,6 +353,41 @@ func (d *Database) SaveThreePIDAssociation( }) } +// BulkSaveThreePIDAssociation recreates 3PIDs for a user. +// If the third-party identifier is already part of an association, returns Err3PIDInUse. +// Returns an error if there was a problem talking to the database. +func (d *Database) BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error) { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + oldThreePIDs, err := d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName) + if err != nil { + return err + } + for _, t := range oldThreePIDs { + if err := d.ThreePIDs.DeleteThreePID(ctx, txn, t.Address, t.Medium); err != nil { + return err + } + } + for _, t := range threePIDs { + // if 3PID is associated with another user, return Err3PIDInUse + user, _, err := d.ThreePIDs.SelectLocalpartForThreePID( + ctx, txn, t.Address, t.Medium, + ) + if err != nil { + return err + } + + if len(user) > 0 && user != localpart { + return Err3PIDInUse + } + + if err = d.ThreePIDs.InsertThreePID(ctx, txn, t.Address, t.Medium, localpart, serverName); err != nil { + return err + } + } + return nil + }) +} + // RemoveThreePIDAssociation removes the association involving a given third-party // identifier. // If no association exists involving this third-party identifier, returns nothing.