From ff63e7fa983386441709d84fc9df6ac33fbfaa59 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Mon, 30 Dec 2024 20:31:10 +0000 Subject: [PATCH] mas: modify PUT /profile/{userID}/displayname endpoint Extended logic of the endpoint in order to make it compatible with MAS --- clientapi/routing/admin.go | 48 +++++++++++++++++------------------- clientapi/routing/profile.go | 38 ++++++++++++++++++++-------- 2 files changed, 50 insertions(+), 36 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 0532a577..c51ea7e9 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -760,42 +760,38 @@ func AdminRetrieveAccount(req *http.Request, cfg *config.ClientAPI, userAPI user Deactivated bool `json:"deactivated"` }{} - { - var rs api.QueryAccountByLocalpartResponse - err := userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{Localpart: local, ServerName: domain}, &rs) - if err == sql.ErrNoRows { + var rs api.QueryAccountByLocalpartResponse + err = userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{Localpart: local, ServerName: domain}, &rs) + if err == sql.ErrNoRows { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("User '%s' not found", userID)), + } + } else if err != nil { + logger.WithError(err).Error("userAPI.QueryAccountByLocalpart") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(err.Error()), + } + } + body.Deactivated = rs.Account.Deactivated + + profile, err := userAPI.QueryProfile(req.Context(), userID) + if err != nil { + if err == appserviceAPI.ErrProfileNotExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: spec.NotFound(fmt.Sprintf("User '%s' not found", userID)), + JSON: spec.NotFound(err.Error()), } } else if err != nil { - logger.WithError(err).Error("userAPI.QueryAccountByLocalpart") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.Unknown(err.Error()), } } - body.Deactivated = rs.Account.Deactivated - } - - { - profile, err := userAPI.QueryProfile(req.Context(), userID) - if err != nil { - if err == appserviceAPI.ErrProfileNotExists { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound(err.Error()), - } - } else if err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.Unknown(err.Error()), - } - } - } - body.AvatarURL = profile.AvatarURL - body.DisplayName = profile.DisplayName } + body.AvatarURL = profile.AvatarURL + body.DisplayName = profile.DisplayName return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index b75d38a6..74bbddbc 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -172,24 +172,20 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( - req *http.Request, profileAPI userapi.ProfileAPI, + req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI, ) util.JSONResponse { - if userID != device.UserID { + if userID != device.UserID && device.AccountType != userapi.AccountTypeOIDCService { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("userID does not match the current user"), } } - var r eventutil.UserProfile - if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { - return *resErr - } - + logger := util.GetLogger(req.Context()) localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + logger.WithError(err).Error("gomatrixserverlib.SplitID failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, @@ -203,6 +199,28 @@ func SetDisplayName( } } + if device.AccountType == userapi.AccountTypeOIDCService { + // When a request is made on behalf of an OIDC provider service, the original device object refers + // to the provider's pseudo-device and includes only the AccountTypeOIDCService flag. To continue, + // we need to replace the admin's device with the user's device + var rs userapi.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &rs) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if len(rs.Devices) > 0 { + device = &rs.Devices[0] + } + } + + var r eventutil.UserProfile + if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { + return *resErr + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -211,9 +229,9 @@ func SetDisplayName( } } - profile, changed, err := profileAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) + profile, changed, err := userAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") + logger.WithError(err).Error("profileAPI.SetDisplayName failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{},