mirror of
https://github.com/element-hq/dendrite.git
synced 2025-09-14 05:12:26 +03:00
tests for sessionsDict.crossSigningKeysReplacement
This commit is contained in:
parent
c1ad175178
commit
b8ea41b2ad
2 changed files with 72 additions and 18 deletions
|
@ -66,16 +66,16 @@ type sessionsDict struct {
|
||||||
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
|
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
|
||||||
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
|
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
|
||||||
deleteSessionToDeviceID map[string]string
|
deleteSessionToDeviceID map[string]string
|
||||||
// allowedForCrossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
|
// crossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
|
||||||
// cross-signing keys without UIA.
|
// cross-signing keys without UIA.
|
||||||
allowedForCrossSigningKeysReplacement map[string]*time.Timer
|
crossSigningKeysReplacement map[string]*time.Timer
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultTimeout is the timeout used to clean up sessions
|
// defaultTimeout is the timeout used to clean up sessions
|
||||||
const defaultTimeOut = time.Minute * 5
|
const defaultTimeOut = time.Minute * 5
|
||||||
|
|
||||||
// allowedForCrossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
|
// crossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
|
||||||
const allowedForCrossSigningKeysReplacementDuration = time.Minute * 10
|
const crossSigningKeysReplacementDuration = time.Minute * 10
|
||||||
|
|
||||||
// getCompletedStages returns the completed stages for a session.
|
// getCompletedStages returns the completed stages for a session.
|
||||||
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
|
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
|
||||||
|
@ -128,14 +128,14 @@ func (d *sessionsDict) deleteSession(sessionID string) {
|
||||||
func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
|
func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
ts := time.Now().Add(allowedForCrossSigningKeysReplacementDuration).UnixMilli()
|
ts := time.Now().Add(crossSigningKeysReplacementDuration).UnixMilli()
|
||||||
t, ok := d.allowedForCrossSigningKeysReplacement[userID]
|
t, ok := d.crossSigningKeysReplacement[userID]
|
||||||
if ok {
|
if ok {
|
||||||
t.Reset(allowedForCrossSigningKeysReplacementDuration)
|
t.Reset(crossSigningKeysReplacementDuration)
|
||||||
return ts
|
return ts
|
||||||
}
|
}
|
||||||
d.allowedForCrossSigningKeysReplacement[userID] = time.AfterFunc(
|
d.crossSigningKeysReplacement[userID] = time.AfterFunc(
|
||||||
allowedForCrossSigningKeysReplacementDuration,
|
crossSigningKeysReplacementDuration,
|
||||||
func() {
|
func() {
|
||||||
d.restrictCrossSigningKeysReplacement(userID)
|
d.restrictCrossSigningKeysReplacement(userID)
|
||||||
},
|
},
|
||||||
|
@ -146,14 +146,14 @@ func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
|
||||||
func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool {
|
func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool {
|
||||||
d.RLock()
|
d.RLock()
|
||||||
defer d.RUnlock()
|
defer d.RUnlock()
|
||||||
_, ok := d.allowedForCrossSigningKeysReplacement[userID]
|
_, ok := d.crossSigningKeysReplacement[userID]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
|
func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
t, ok := d.allowedForCrossSigningKeysReplacement[userID]
|
t, ok := d.crossSigningKeysReplacement[userID]
|
||||||
if ok {
|
if ok {
|
||||||
if !t.Stop() {
|
if !t.Stop() {
|
||||||
select {
|
select {
|
||||||
|
@ -161,18 +161,18 @@ func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(d.allowedForCrossSigningKeysReplacement, userID)
|
delete(d.crossSigningKeysReplacement, userID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSessionsDict() *sessionsDict {
|
func newSessionsDict() *sessionsDict {
|
||||||
return &sessionsDict{
|
return &sessionsDict{
|
||||||
sessions: make(map[string][]authtypes.LoginType),
|
sessions: make(map[string][]authtypes.LoginType),
|
||||||
sessionCompletedResult: make(map[string]registerResponse),
|
sessionCompletedResult: make(map[string]registerResponse),
|
||||||
params: make(map[string]registerRequest),
|
params: make(map[string]registerRequest),
|
||||||
timer: make(map[string]*time.Timer),
|
timer: make(map[string]*time.Timer),
|
||||||
deleteSessionToDeviceID: make(map[string]string),
|
deleteSessionToDeviceID: make(map[string]string),
|
||||||
allowedForCrossSigningKeysReplacement: make(map[string]*time.Timer),
|
crossSigningKeysReplacement: make(map[string]*time.Timer),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -669,3 +669,57 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) {
|
||||||
assert.Equal(t, expectedDisplayName, profile.DisplayName)
|
assert.Equal(t, expectedDisplayName, profile.DisplayName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCrossSigningKeysReplacement(t *testing.T) {
|
||||||
|
userID := "@user:example.com"
|
||||||
|
|
||||||
|
t.Run("Can add new session", func(t *testing.T) {
|
||||||
|
s := newSessionsDict()
|
||||||
|
assert.Empty(t, s.crossSigningKeysReplacement)
|
||||||
|
s.allowCrossSigningKeysReplacement(userID)
|
||||||
|
assert.Len(t, s.crossSigningKeysReplacement, 1)
|
||||||
|
assert.Contains(t, s.crossSigningKeysReplacement, userID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Can check if session exists or not", func(t *testing.T) {
|
||||||
|
s := newSessionsDict()
|
||||||
|
t.Run("exists", func(t *testing.T) {
|
||||||
|
s.allowCrossSigningKeysReplacement(userID)
|
||||||
|
assert.Len(t, s.crossSigningKeysReplacement, 1)
|
||||||
|
assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not exists", func(t *testing.T) {
|
||||||
|
assert.False(t, s.isCrossSigningKeysReplacementAllowed("@random:test.com"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Can deactivate session", func(t *testing.T) {
|
||||||
|
s := newSessionsDict()
|
||||||
|
assert.Empty(t, s.crossSigningKeysReplacement)
|
||||||
|
t.Run("not exists", func(t *testing.T) {
|
||||||
|
s.restrictCrossSigningKeysReplacement("@random:test.com")
|
||||||
|
assert.Empty(t, s.crossSigningKeysReplacement)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exists", func(t *testing.T) {
|
||||||
|
s.allowCrossSigningKeysReplacement(userID)
|
||||||
|
s.restrictCrossSigningKeysReplacement(userID)
|
||||||
|
assert.Empty(t, s.crossSigningKeysReplacement)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Can erase expired sessions", func(t *testing.T) {
|
||||||
|
s := newSessionsDict()
|
||||||
|
s.allowCrossSigningKeysReplacement(userID)
|
||||||
|
assert.Len(t, s.crossSigningKeysReplacement, 1)
|
||||||
|
assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID))
|
||||||
|
timer := s.crossSigningKeysReplacement[userID]
|
||||||
|
|
||||||
|
// pretending the timer is expired
|
||||||
|
timer.Reset(time.Millisecond)
|
||||||
|
time.Sleep(time.Millisecond * 500)
|
||||||
|
|
||||||
|
assert.Empty(t, s.crossSigningKeysReplacement)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue