mas: modify PUT /profile/{userID}/displayname endpoint

Extended logic of the endpoint in order to make it compatible with
MAS
This commit is contained in:
Roman Isaev 2024-12-30 20:31:10 +00:00
parent 524f65cb0c
commit ff63e7fa98
No known key found for this signature in database
GPG key ID: 7BE2B6A6C89AEC7F
2 changed files with 50 additions and 36 deletions

View file

@ -760,9 +760,8 @@ func AdminRetrieveAccount(req *http.Request, cfg *config.ClientAPI, userAPI user
Deactivated bool `json:"deactivated"` Deactivated bool `json:"deactivated"`
}{} }{}
{
var rs api.QueryAccountByLocalpartResponse var rs api.QueryAccountByLocalpartResponse
err := userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{Localpart: local, ServerName: domain}, &rs) err = userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{Localpart: local, ServerName: domain}, &rs)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
@ -776,9 +775,7 @@ func AdminRetrieveAccount(req *http.Request, cfg *config.ClientAPI, userAPI user
} }
} }
body.Deactivated = rs.Account.Deactivated body.Deactivated = rs.Account.Deactivated
}
{
profile, err := userAPI.QueryProfile(req.Context(), userID) profile, err := userAPI.QueryProfile(req.Context(), userID)
if err != nil { if err != nil {
if err == appserviceAPI.ErrProfileNotExists { if err == appserviceAPI.ErrProfileNotExists {
@ -795,7 +792,6 @@ func AdminRetrieveAccount(req *http.Request, cfg *config.ClientAPI, userAPI user
} }
body.AvatarURL = profile.AvatarURL body.AvatarURL = profile.AvatarURL
body.DisplayName = profile.DisplayName body.DisplayName = profile.DisplayName
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -172,24 +172,20 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname // SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName( func SetDisplayName(
req *http.Request, profileAPI userapi.ProfileAPI, req *http.Request, userAPI userapi.ClientUserAPI,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID && device.AccountType != userapi.AccountTypeOIDCService {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("userID does not match the current user"), JSON: spec.Forbidden("userID does not match the current user"),
} }
} }
var r eventutil.UserProfile logger := util.GetLogger(req.Context())
if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr
}
localpart, domain, err := gomatrixserverlib.SplitID('@', userID) localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") logger.WithError(err).Error("gomatrixserverlib.SplitID failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
@ -203,6 +199,28 @@ func SetDisplayName(
} }
} }
if device.AccountType == userapi.AccountTypeOIDCService {
// When a request is made on behalf of an OIDC provider service, the original device object refers
// to the provider's pseudo-device and includes only the AccountTypeOIDCService flag. To continue,
// we need to replace the admin's device with the user's device
var rs userapi.QueryDevicesResponse
err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &rs)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if len(rs.Devices) > 0 {
device = &rs.Devices[0]
}
}
var r eventutil.UserProfile
if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr
}
evTime, err := httputil.ParseTSParam(req) evTime, err := httputil.ParseTSParam(req)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -211,9 +229,9 @@ func SetDisplayName(
} }
} }
profile, changed, err := profileAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) profile, changed, err := userAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") logger.WithError(err).Error("profileAPI.SetDisplayName failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},