mirror of
https://github.com/element-hq/dendrite.git
synced 2025-09-15 05:32:25 +03:00
mas: store crossSigngingKeysReplacement period in sessionsDict struct instead of db
This commit is contained in:
parent
b5f34dfe47
commit
453445695c
3 changed files with 61 additions and 39 deletions
|
@ -35,10 +35,6 @@ import (
|
||||||
"github.com/element-hq/dendrite/userapi/storage/shared"
|
"github.com/element-hq/dendrite/userapi/storage/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
replacementPeriod time.Duration = 10 * time.Minute
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
|
validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
|
||||||
deviceDisplayName = "OIDC-native client"
|
deviceDisplayName = "OIDC-native client"
|
||||||
|
@ -807,27 +803,10 @@ func AdminAllowCrossSigningReplacementWithoutUIA(
|
||||||
|
|
||||||
switch req.Method {
|
switch req.Method {
|
||||||
case http.MethodPost:
|
case http.MethodPost:
|
||||||
rq := userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest{
|
ts := sessions.allowCrossSigningKeysReplacement(userID.String())
|
||||||
UserID: userID.String(),
|
|
||||||
Duration: replacementPeriod,
|
|
||||||
}
|
|
||||||
var rs userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse
|
|
||||||
err = userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA(req.Context(), &rq, &rs)
|
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA")
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
JSON: spec.Unknown(err.Error()),
|
|
||||||
}
|
|
||||||
} else if errors.Is(err, sql.ErrNoRows) {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusNotFound,
|
|
||||||
JSON: spec.NotFound("User not found."),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: map[string]int64{"updatable_without_uia_before_ms": rs.Timestamp},
|
JSON: map[string]int64{"updatable_without_uia_before_ms": ts},
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
|
@ -49,11 +49,6 @@ func UploadCrossSigningDeviceKeys(
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID := uploadReq.Auth.Session
|
|
||||||
if sessionID == "" {
|
|
||||||
sessionID = util.RandomString(sessionIDLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query existing keys to determine if UIA is required
|
// Query existing keys to determine if UIA is required
|
||||||
keyResp := api.QueryKeysResponse{}
|
keyResp := api.QueryKeysResponse{}
|
||||||
keyserverAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
|
keyserverAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
|
||||||
|
@ -68,7 +63,6 @@ func UploadCrossSigningDeviceKeys(
|
||||||
}
|
}
|
||||||
|
|
||||||
existingMasterKey, hasMasterKey := keyResp.MasterKeys[device.UserID]
|
existingMasterKey, hasMasterKey := keyResp.MasterKeys[device.UserID]
|
||||||
requireUIA := true
|
|
||||||
|
|
||||||
if hasMasterKey {
|
if hasMasterKey {
|
||||||
if !keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) {
|
if !keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) {
|
||||||
|
@ -89,10 +83,8 @@ func UploadCrossSigningDeviceKeys(
|
||||||
logger.WithError(masterKeyResp.Error).Error("Failed to query master key")
|
logger.WithError(masterKeyResp.Error).Error("Failed to query master key")
|
||||||
return convertKeyError(masterKeyResp.Error)
|
return convertKeyError(masterKeyResp.Error)
|
||||||
}
|
}
|
||||||
if k := masterKeyResp.Key; k != nil && k.UpdatableWithoutUIABeforeMs != nil {
|
|
||||||
requireUIA = !(time.Now().UnixMilli() < *k.UpdatableWithoutUIABeforeMs)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
requireUIA := !sessions.isCrossSigningKeysReplacementAllowed(device.UserID) && masterKeyResp.Key != nil
|
||||||
if requireUIA {
|
if requireUIA {
|
||||||
url := ""
|
url := ""
|
||||||
if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" {
|
if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" {
|
||||||
|
@ -122,9 +114,13 @@ func UploadCrossSigningDeviceKeys(
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// XXX: is it necessary?
|
sessions.restrictCrossSigningKeysReplacement(device.UserID)
|
||||||
sessions.addCompletedSessionStage(sessionID, CrossSigningResetStage)
|
|
||||||
} else {
|
} else {
|
||||||
|
sessionID := uploadReq.Auth.Session
|
||||||
|
if sessionID == "" {
|
||||||
|
sessionID = util.RandomString(sessionIDLength)
|
||||||
|
}
|
||||||
|
|
||||||
if uploadReq.Auth.Type != authtypes.LoginTypePassword {
|
if uploadReq.Auth.Type != authtypes.LoginTypePassword {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusUnauthorized,
|
Code: http.StatusUnauthorized,
|
||||||
|
|
|
@ -66,11 +66,17 @@ type sessionsDict struct {
|
||||||
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
|
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
|
||||||
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
|
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
|
||||||
deleteSessionToDeviceID map[string]string
|
deleteSessionToDeviceID map[string]string
|
||||||
|
// allowedForCrossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
|
||||||
|
// cross-signing keys without UIA.
|
||||||
|
allowedForCrossSigningKeysReplacement map[string]*time.Timer
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultTimeout is the timeout used to clean up sessions
|
// defaultTimeout is the timeout used to clean up sessions
|
||||||
const defaultTimeOut = time.Minute * 5
|
const defaultTimeOut = time.Minute * 5
|
||||||
|
|
||||||
|
// allowedForCrossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
|
||||||
|
const allowedForCrossSigningKeysReplacementDuration = time.Minute * 10
|
||||||
|
|
||||||
// getCompletedStages returns the completed stages for a session.
|
// getCompletedStages returns the completed stages for a session.
|
||||||
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
|
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
|
||||||
d.RLock()
|
d.RLock()
|
||||||
|
@ -119,6 +125,46 @@ func (d *sessionsDict) deleteSession(sessionID string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
ts := time.Now().Add(allowedForCrossSigningKeysReplacementDuration).UnixMilli()
|
||||||
|
t, ok := d.allowedForCrossSigningKeysReplacement[userID]
|
||||||
|
if ok {
|
||||||
|
t.Reset(allowedForCrossSigningKeysReplacementDuration)
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
d.allowedForCrossSigningKeysReplacement[userID] = time.AfterFunc(
|
||||||
|
allowedForCrossSigningKeysReplacementDuration,
|
||||||
|
func() {
|
||||||
|
d.restrictCrossSigningKeysReplacement(userID)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool {
|
||||||
|
d.RLock()
|
||||||
|
defer d.RUnlock()
|
||||||
|
_, ok := d.allowedForCrossSigningKeysReplacement[userID]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
t, ok := d.allowedForCrossSigningKeysReplacement[userID]
|
||||||
|
if ok {
|
||||||
|
if !t.Stop() {
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(d.allowedForCrossSigningKeysReplacement, userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newSessionsDict() *sessionsDict {
|
func newSessionsDict() *sessionsDict {
|
||||||
return &sessionsDict{
|
return &sessionsDict{
|
||||||
sessions: make(map[string][]authtypes.LoginType),
|
sessions: make(map[string][]authtypes.LoginType),
|
||||||
|
@ -126,6 +172,7 @@ func newSessionsDict() *sessionsDict {
|
||||||
params: make(map[string]registerRequest),
|
params: make(map[string]registerRequest),
|
||||||
timer: make(map[string]*time.Timer),
|
timer: make(map[string]*time.Timer),
|
||||||
deleteSessionToDeviceID: make(map[string]string),
|
deleteSessionToDeviceID: make(map[string]string),
|
||||||
|
allowedForCrossSigningKeysReplacement: make(map[string]*time.Timer),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue