diff --git a/clientapi/auth/default_user_verifier.go b/clientapi/auth/default_user_verifier.go index 6e6746d1..54147b77 100644 --- a/clientapi/auth/default_user_verifier.go +++ b/clientapi/auth/default_user_verifier.go @@ -20,7 +20,8 @@ type DefaultUserVerifier struct { // Note: For an AS user, AS dummy device is returned. // On failure returns an JSON error response which can be sent to the client. func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) { - util.GetLogger(req.Context()).Debug("Default VerifyUserFromRequest") + ctx := req.Context() + util.GetLogger(ctx).Debug("Default VerifyUserFromRequest") // Try to find the Application Service user token, err := ExtractAccessToken(req) if err != nil { @@ -30,12 +31,12 @@ func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Dev } } var res api.QueryAccessTokenResponse - err = d.UserAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{ + err = d.UserAPI.QueryAccessToken(ctx, &api.QueryAccessTokenRequest{ AccessToken: token, AppServiceUserID: req.URL.Query().Get("user_id"), }, &res) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") + util.GetLogger(ctx).WithError(err).Error("userAPI.QueryAccessToken failed") return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, diff --git a/setup/mscs/msc3861/msc3861.go b/setup/mscs/msc3861/msc3861.go index b3c458b1..8a0df647 100644 --- a/setup/mscs/msc3861/msc3861.go +++ b/setup/mscs/msc3861/msc3861.go @@ -2,13 +2,15 @@ package msc3861 import ( "github.com/element-hq/dendrite/setup" + "github.com/matrix-org/gomatrixserverlib/fclient" ) func Enable(m *setup.Monolith) error { + client := fclient.NewClient() userVerifier, err := newMSC3861UserVerifier( m.UserAPI, m.Config.Global.ServerName, m.Config.MSCs.MSC3861, !m.Config.ClientAPI.GuestsDisabled, - nil, + client, ) if err != nil { return err diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index fbfdcaa2..65cf5957 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -1,7 +1,6 @@ package msc3861 import ( - "cmp" "context" "database/sql" "encoding/json" @@ -15,6 +14,7 @@ import ( "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -45,7 +45,7 @@ type MSC3861UserVerifier struct { userAPI api.UserInternalAPI serverName spec.ServerName cfg *config.MSC3861 - httpClient *http.Client + httpClient *fclient.Client openIdConfig *OpenIDConfiguration allowGuest bool } @@ -55,13 +55,21 @@ func newMSC3861UserVerifier( serverName spec.ServerName, cfg *config.MSC3861, allowGuest bool, - httpClient *http.Client, + client *fclient.Client, ) (*MSC3861UserVerifier, error) { - client := cmp.Or(httpClient, http.DefaultClient) + if cfg == nil { + return nil, errors.New("unable to create MSC3861UserVerifier object as 'cfg' param is nil") + } + + if client == nil { + return nil, errors.New("unable to create MSC3861UserVerifier object as 'client' param is nil") + } + openIdConfig, err := fetchOpenIDConfiguration(client, cfg.Issuer) if err != nil { return nil, err } + return &MSC3861UserVerifier{ userAPI: userAPI, serverName: serverName, @@ -342,14 +350,14 @@ func (m *MSC3861UserVerifier) introspectToken(ctx context.Context, token string) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") req.SetBasicAuth(m.cfg.ClientID, m.cfg.ClientSecret) - resp, err := m.httpClient.Do(req) + resp, err := m.httpClient.DoHTTPRequest(ctx, req) if err != nil { return nil, err } body := resp.Body defer resp.Body.Close() // nolint: errcheck - if c := resp.StatusCode; c < 200 || c >= 300 { + if c := resp.StatusCode; c/100 != 2 { return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status, "' response"}, "")) } var ir introspectionResponse @@ -394,13 +402,17 @@ type OpenIDConfiguration struct { AccountManagementActionsSupported []string `json:"account_management_actions_supported"` } -func fetchOpenIDConfiguration(httpClient *http.Client, authHostURL string) (*OpenIDConfiguration, error) { +func fetchOpenIDConfiguration(httpClient *fclient.Client, authHostURL string) (*OpenIDConfiguration, error) { u, err := url.Parse(authHostURL) if err != nil { return nil, err } u = u.JoinPath(".well-known/openid-configuration") - resp, err := httpClient.Get(u.String()) + req, err := http.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + resp, err := httpClient.DoHTTPRequest(context.Background(), req) if err != nil { return nil, err }