more fixes

This commit is contained in:
Roman Isaev 2025-01-23 01:58:49 +00:00
parent ea875b3b6b
commit bf31c44298
No known key found for this signature in database
GPG key ID: 7BE2B6A6C89AEC7F
3 changed files with 27 additions and 12 deletions

View file

@ -20,7 +20,8 @@ type DefaultUserVerifier struct {
// Note: For an AS user, AS dummy device is returned. // Note: For an AS user, AS dummy device is returned.
// On failure returns an JSON error response which can be sent to the client. // On failure returns an JSON error response which can be sent to the client.
func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) { func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) {
util.GetLogger(req.Context()).Debug("Default VerifyUserFromRequest") ctx := req.Context()
util.GetLogger(ctx).Debug("Default VerifyUserFromRequest")
// Try to find the Application Service user // Try to find the Application Service user
token, err := ExtractAccessToken(req) token, err := ExtractAccessToken(req)
if err != nil { if err != nil {
@ -30,12 +31,12 @@ func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Dev
} }
} }
var res api.QueryAccessTokenResponse var res api.QueryAccessTokenResponse
err = d.UserAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{ err = d.UserAPI.QueryAccessToken(ctx, &api.QueryAccessTokenRequest{
AccessToken: token, AccessToken: token,
AppServiceUserID: req.URL.Query().Get("user_id"), AppServiceUserID: req.URL.Query().Get("user_id"),
}, &res) }, &res)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") util.GetLogger(ctx).WithError(err).Error("userAPI.QueryAccessToken failed")
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},

View file

@ -2,13 +2,15 @@ package msc3861
import ( import (
"github.com/element-hq/dendrite/setup" "github.com/element-hq/dendrite/setup"
"github.com/matrix-org/gomatrixserverlib/fclient"
) )
func Enable(m *setup.Monolith) error { func Enable(m *setup.Monolith) error {
client := fclient.NewClient()
userVerifier, err := newMSC3861UserVerifier( userVerifier, err := newMSC3861UserVerifier(
m.UserAPI, m.Config.Global.ServerName, m.UserAPI, m.Config.Global.ServerName,
m.Config.MSCs.MSC3861, !m.Config.ClientAPI.GuestsDisabled, m.Config.MSCs.MSC3861, !m.Config.ClientAPI.GuestsDisabled,
nil, client,
) )
if err != nil { if err != nil {
return err return err

View file

@ -1,7 +1,6 @@
package msc3861 package msc3861
import ( import (
"cmp"
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
@ -15,6 +14,7 @@ import (
"github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/auth"
"github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/setup/config"
"github.com/element-hq/dendrite/userapi/api" "github.com/element-hq/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -45,7 +45,7 @@ type MSC3861UserVerifier struct {
userAPI api.UserInternalAPI userAPI api.UserInternalAPI
serverName spec.ServerName serverName spec.ServerName
cfg *config.MSC3861 cfg *config.MSC3861
httpClient *http.Client httpClient *fclient.Client
openIdConfig *OpenIDConfiguration openIdConfig *OpenIDConfiguration
allowGuest bool allowGuest bool
} }
@ -55,13 +55,21 @@ func newMSC3861UserVerifier(
serverName spec.ServerName, serverName spec.ServerName,
cfg *config.MSC3861, cfg *config.MSC3861,
allowGuest bool, allowGuest bool,
httpClient *http.Client, client *fclient.Client,
) (*MSC3861UserVerifier, error) { ) (*MSC3861UserVerifier, error) {
client := cmp.Or(httpClient, http.DefaultClient) if cfg == nil {
return nil, errors.New("unable to create MSC3861UserVerifier object as 'cfg' param is nil")
}
if client == nil {
return nil, errors.New("unable to create MSC3861UserVerifier object as 'client' param is nil")
}
openIdConfig, err := fetchOpenIDConfiguration(client, cfg.Issuer) openIdConfig, err := fetchOpenIDConfiguration(client, cfg.Issuer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &MSC3861UserVerifier{ return &MSC3861UserVerifier{
userAPI: userAPI, userAPI: userAPI,
serverName: serverName, serverName: serverName,
@ -342,14 +350,14 @@ func (m *MSC3861UserVerifier) introspectToken(ctx context.Context, token string)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded") req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(m.cfg.ClientID, m.cfg.ClientSecret) req.SetBasicAuth(m.cfg.ClientID, m.cfg.ClientSecret)
resp, err := m.httpClient.Do(req) resp, err := m.httpClient.DoHTTPRequest(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
body := resp.Body body := resp.Body
defer resp.Body.Close() // nolint: errcheck defer resp.Body.Close() // nolint: errcheck
if c := resp.StatusCode; c < 200 || c >= 300 { if c := resp.StatusCode; c/100 != 2 {
return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status, "' response"}, "")) return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status, "' response"}, ""))
} }
var ir introspectionResponse var ir introspectionResponse
@ -394,13 +402,17 @@ type OpenIDConfiguration struct {
AccountManagementActionsSupported []string `json:"account_management_actions_supported"` AccountManagementActionsSupported []string `json:"account_management_actions_supported"`
} }
func fetchOpenIDConfiguration(httpClient *http.Client, authHostURL string) (*OpenIDConfiguration, error) { func fetchOpenIDConfiguration(httpClient *fclient.Client, authHostURL string) (*OpenIDConfiguration, error) {
u, err := url.Parse(authHostURL) u, err := url.Parse(authHostURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
u = u.JoinPath(".well-known/openid-configuration") u = u.JoinPath(".well-known/openid-configuration")
resp, err := httpClient.Get(u.String()) req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}
resp, err := httpClient.DoHTTPRequest(context.Background(), req)
if err != nil { if err != nil {
return nil, err return nil, err
} }