mirror of
https://github.com/element-hq/dendrite.git
synced 2025-09-14 05:12:26 +03:00
Add clientapi tests (#2916)
This PR - adds several tests for the clientapi, mostly around `/register` and auth fallback. - removes the now deprecated `homeserver` field from responses to `/register` and `/login` - slightly refactors auth fallback handling
This commit is contained in:
parent
f47515e38b
commit
f762ce1050
15 changed files with 838 additions and 220 deletions
|
@ -18,12 +18,12 @@ package routing
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -60,10 +60,7 @@ var (
|
|||
)
|
||||
)
|
||||
|
||||
const (
|
||||
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
|
||||
sessionIDLength = 24
|
||||
)
|
||||
const sessionIDLength = 24
|
||||
|
||||
// sessionsDict keeps track of completed auth stages for each session.
|
||||
// It shouldn't be passed by value because it contains a mutex.
|
||||
|
@ -198,8 +195,7 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
|
|||
}
|
||||
|
||||
var (
|
||||
sessions = newSessionsDict()
|
||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||
sessions = newSessionsDict()
|
||||
)
|
||||
|
||||
// registerRequest represents the submitted registration request.
|
||||
|
@ -262,10 +258,9 @@ func newUserInteractiveResponse(
|
|||
|
||||
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
|
||||
type registerResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
HomeServer gomatrixserverlib.ServerName `json:"home_server"`
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
UserID string `json:"user_id"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
}
|
||||
|
||||
// recaptchaResponse represents the HTTP response from a Google Recaptcha server
|
||||
|
@ -276,66 +271,28 @@ type recaptchaResponse struct {
|
|||
ErrorCodes []int `json:"error-codes"`
|
||||
}
|
||||
|
||||
// validateUsername returns an error response if the username is invalid
|
||||
func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
|
||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
|
||||
}
|
||||
} else if !validUsernameRegex.MatchString(localpart) {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
||||
}
|
||||
} else if localpart[0] == '_' { // Regex checks its not a zero length string
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateApplicationServiceUsername returns an error response if the username is invalid for an application service
|
||||
func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
|
||||
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
|
||||
}
|
||||
} else if !validUsernameRegex.MatchString(localpart) {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
ErrInvalidCaptcha = errors.New("invalid captcha response")
|
||||
ErrMissingResponse = errors.New("captcha response is required")
|
||||
ErrCaptchaDisabled = errors.New("captcha registration is disabled")
|
||||
)
|
||||
|
||||
// validateRecaptcha returns an error response if the captcha response is invalid
|
||||
func validateRecaptcha(
|
||||
cfg *config.ClientAPI,
|
||||
response string,
|
||||
clientip string,
|
||||
) *util.JSONResponse {
|
||||
) error {
|
||||
ip, _, _ := net.SplitHostPort(clientip)
|
||||
if !cfg.RecaptchaEnabled {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusConflict,
|
||||
JSON: jsonerror.Unknown("Captcha registration is disabled"),
|
||||
}
|
||||
return ErrCaptchaDisabled
|
||||
}
|
||||
|
||||
if response == "" {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("Captcha response is required"),
|
||||
}
|
||||
return ErrMissingResponse
|
||||
}
|
||||
|
||||
// Make a POST request to Google's API to check the captcha response
|
||||
// Make a POST request to the captcha provider API to check the captcha response
|
||||
resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI,
|
||||
url.Values{
|
||||
"secret": {cfg.RecaptchaPrivateKey},
|
||||
|
@ -345,10 +302,7 @@ func validateRecaptcha(
|
|||
)
|
||||
|
||||
if err != nil {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"),
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close the request once we're finishing reading from it
|
||||
|
@ -358,25 +312,16 @@ func validateRecaptcha(
|
|||
var r recaptchaResponse
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusGatewayTimeout,
|
||||
JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()),
|
||||
}
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(body, &r)
|
||||
if err != nil {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()),
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Check that we received a "success"
|
||||
if !r.Success {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."),
|
||||
}
|
||||
return ErrInvalidCaptcha
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -508,8 +453,8 @@ func validateApplicationService(
|
|||
}
|
||||
|
||||
// Check username application service is trying to register is valid
|
||||
if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
|
||||
return "", err
|
||||
if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
|
||||
return "", internal.UsernameResponse(err)
|
||||
}
|
||||
|
||||
// No errors, registration valid
|
||||
|
@ -564,15 +509,12 @@ func Register(
|
|||
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
|
||||
r.Username, r.ServerName = l, d
|
||||
}
|
||||
if req.URL.Query().Get("kind") == "guest" {
|
||||
return handleGuestRegistration(req, r, cfg, userAPI)
|
||||
}
|
||||
|
||||
// Don't allow numeric usernames less than MAX_INT64.
|
||||
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil {
|
||||
if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
|
||||
|
@ -584,7 +526,7 @@ func Register(
|
|||
ServerName: r.ServerName,
|
||||
}
|
||||
nres := &userapi.QueryNumericLocalpartResponse{}
|
||||
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||
if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
@ -601,8 +543,8 @@ func Register(
|
|||
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
|
||||
// Spec-compliant case (the access_token is specified and the login type
|
||||
// is correctly set, so it's an appservice registration)
|
||||
if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil {
|
||||
return *internal.UsernameResponse(err)
|
||||
}
|
||||
case accessTokenErr == nil:
|
||||
// Non-spec-compliant case (the access_token is specified but the login
|
||||
|
@ -614,12 +556,12 @@ func Register(
|
|||
default:
|
||||
// Spec-compliant case (neither the access_token nor the login type are
|
||||
// specified, so it's a normal user registration)
|
||||
if resErr := validateUsername(r.Username, r.ServerName); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil {
|
||||
return *internal.UsernameResponse(err)
|
||||
}
|
||||
}
|
||||
if resErr := internal.ValidatePassword(r.Password); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidatePassword(r.Password); err != nil {
|
||||
return *internal.PasswordResponse(err)
|
||||
}
|
||||
|
||||
logger := util.GetLogger(req.Context())
|
||||
|
@ -697,7 +639,6 @@ func handleGuestRegistration(
|
|||
JSON: registerResponse{
|
||||
UserID: devRes.Device.UserID,
|
||||
AccessToken: devRes.Device.AccessToken,
|
||||
HomeServer: res.Account.ServerName,
|
||||
DeviceID: devRes.Device.ID,
|
||||
},
|
||||
}
|
||||
|
@ -761,9 +702,18 @@ func handleRegistrationFlow(
|
|||
switch r.Auth.Type {
|
||||
case authtypes.LoginTypeRecaptcha:
|
||||
// Check given captcha response
|
||||
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
||||
switch err {
|
||||
case ErrCaptchaDisabled:
|
||||
return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())}
|
||||
case ErrMissingResponse:
|
||||
return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())}
|
||||
case ErrInvalidCaptcha:
|
||||
return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())}
|
||||
case nil:
|
||||
default:
|
||||
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
|
||||
return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}
|
||||
}
|
||||
|
||||
// Add Recaptcha to the list of completed registration stages
|
||||
|
@ -924,8 +874,7 @@ func completeRegistration(
|
|||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: registerResponse{
|
||||
UserID: userutil.MakeUserID(username, accRes.Account.ServerName),
|
||||
HomeServer: accRes.Account.ServerName,
|
||||
UserID: userutil.MakeUserID(username, accRes.Account.ServerName),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -958,7 +907,6 @@ func completeRegistration(
|
|||
result := registerResponse{
|
||||
UserID: devRes.Device.UserID,
|
||||
AccessToken: devRes.Device.AccessToken,
|
||||
HomeServer: accRes.Account.ServerName,
|
||||
DeviceID: devRes.Device.ID,
|
||||
}
|
||||
sessions.addCompletedRegistration(sessionID, result)
|
||||
|
@ -1054,8 +1002,8 @@ func RegisterAvailable(
|
|||
}
|
||||
}
|
||||
|
||||
if err := validateUsername(username, domain); err != nil {
|
||||
return *err
|
||||
if err := internal.ValidateUsername(username, domain); err != nil {
|
||||
return *internal.UsernameResponse(err)
|
||||
}
|
||||
|
||||
// Check if this username is reserved by an application service
|
||||
|
@ -1117,11 +1065,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
|
|||
// downcase capitals
|
||||
ssrr.User = strings.ToLower(ssrr.User)
|
||||
|
||||
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil {
|
||||
return *internal.UsernameResponse(err)
|
||||
}
|
||||
if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidatePassword(ssrr.Password); err != nil {
|
||||
return *internal.PasswordResponse(err)
|
||||
}
|
||||
deviceID := "shared_secret_registration"
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue