From fd52c7eb1f0e62f0db0ea198a2a6c4ac5674429b Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 12 Feb 2025 17:45:31 +0000 Subject: [PATCH] msc3861: cr --- setup/monolith.go | 4 ++-- setup/mscs/msc3861/msc3861_user_verifier.go | 23 +++++++-------------- userapi/api/api.go | 20 ++---------------- userapi/internal/user_api.go | 9 ++++---- 4 files changed, 15 insertions(+), 41 deletions(-) diff --git a/setup/monolith.go b/setup/monolith.go index 7dd67705..b61633c1 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -89,11 +89,11 @@ func (m *Monolith) AddAllPublicRoutes( } type UserVerifierProvider struct { - UserVerifier httputil.UserVerifier + httputil.UserVerifier } func (u *UserVerifierProvider) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) { - return u.UserVerifier.VerifyUserFromRequest(req) + return u.VerifyUserFromRequest(req) } func NewUserVerifierProvider(userVerifier httputil.UserVerifier) *UserVerifierProvider { diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index 0de9b7fc..30f578dc 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -202,17 +202,12 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st } localpart := "" - { - var rs api.QueryLocalpartExternalIDResponse - if err = m.userAPI.QueryExternalUserIDByLocalpartAndProvider(ctx, &api.QueryLocalpartExternalIDRequest{ - ExternalID: sub, - AuthProvider: externalAuthProvider, - }, &rs); err != nil && err != sql.ErrNoRows { - return nil, err - } - if l := rs.LocalpartExternalID; l != nil { - localpart = l.Localpart - } + localpartExternalID, err := m.userAPI.QueryExternalUserIDByLocalpartAndProvider(ctx, sub, externalAuthProvider) + if err != nil && err != sql.ErrNoRows { + return nil, err + } + if localpartExternalID != nil { + localpart = localpartExternalID.Localpart } if localpart == "" { @@ -253,11 +248,7 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st } } - if err = m.userAPI.PerformLocalpartExternalUserIDCreation(ctx, &api.PerformLocalpartExternalUserIDCreationRequest{ - Localpart: userID.Local(), - ExternalID: sub, - AuthProvider: externalAuthProvider, - }); err != nil { + if err = m.userAPI.PerformLocalpartExternalUserIDCreation(ctx, userID.Local(), sub, externalAuthProvider); err != nil { logger.WithError(err).Error("PerformLocalpartExternalUserIDCreation") return nil, err } diff --git a/userapi/api/api.go b/userapi/api/api.go index 9b131998..31059f5a 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -31,8 +31,8 @@ type UserInternalAPI interface { FederationUserAPI QuerySearchProfilesAPI // used by p2p demos - QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, req *QueryLocalpartExternalIDRequest, res *QueryLocalpartExternalIDResponse) (err error) - PerformLocalpartExternalUserIDCreation(ctx context.Context, req *PerformLocalpartExternalUserIDCreationRequest) (err error) + QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, externalID, authProvider string) (*LocalpartExternalID, error) + PerformLocalpartExternalUserIDCreation(ctx context.Context, localpart, externalID, authProvider string) (error) } // api functions required by the appservice api @@ -667,22 +667,6 @@ type QueryAccountByLocalpartRequest struct { type QueryAccountByLocalpartResponse struct { Account *Account } - -type QueryLocalpartExternalIDRequest struct { - ExternalID string - AuthProvider string -} - -type QueryLocalpartExternalIDResponse struct { - LocalpartExternalID *LocalpartExternalID -} - -type PerformLocalpartExternalUserIDCreationRequest struct { - Localpart string - ExternalID string - AuthProvider string -} - // API functions required by the clientapi type ClientKeyAPI interface { UploadDeviceKeysAPI diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 2b500c95..e4c846e5 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -604,13 +604,12 @@ func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api. return } -func (a *UserInternalAPI) PerformLocalpartExternalUserIDCreation(ctx context.Context, req *api.PerformLocalpartExternalUserIDCreationRequest) (err error) { - return a.DB.CreateLocalpartExternalID(ctx, req.Localpart, req.ExternalID, req.AuthProvider) +func (a *UserInternalAPI) PerformLocalpartExternalUserIDCreation(ctx context.Context, localpart, externalID, authProvider string) (err error) { + return a.DB.CreateLocalpartExternalID(ctx, localpart, externalID, authProvider) } -func (a *UserInternalAPI) QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, req *api.QueryLocalpartExternalIDRequest, res *api.QueryLocalpartExternalIDResponse) (err error) { - res.LocalpartExternalID, err = a.DB.GetLocalpartForExternalID(ctx, req.ExternalID, req.AuthProvider) - return +func (a *UserInternalAPI) QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) { + return a.DB.GetLocalpartForExternalID(ctx, externalID, authProvider) } // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem