tests for sessionsDict.crossSigningKeysReplacement

This commit is contained in:
Roman Isaev 2025-01-25 02:02:24 +00:00
parent c1ad175178
commit b8ea41b2ad
No known key found for this signature in database
GPG key ID: 7BE2B6A6C89AEC7F
2 changed files with 72 additions and 18 deletions

View file

@ -66,16 +66,16 @@ type sessionsDict struct {
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2, // If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1. // the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
deleteSessionToDeviceID map[string]string deleteSessionToDeviceID map[string]string
// allowedForCrossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating // crossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
// cross-signing keys without UIA. // cross-signing keys without UIA.
allowedForCrossSigningKeysReplacement map[string]*time.Timer crossSigningKeysReplacement map[string]*time.Timer
} }
// defaultTimeout is the timeout used to clean up sessions // defaultTimeout is the timeout used to clean up sessions
const defaultTimeOut = time.Minute * 5 const defaultTimeOut = time.Minute * 5
// allowedForCrossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA // crossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
const allowedForCrossSigningKeysReplacementDuration = time.Minute * 10 const crossSigningKeysReplacementDuration = time.Minute * 10
// getCompletedStages returns the completed stages for a session. // getCompletedStages returns the completed stages for a session.
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType { func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
@ -128,14 +128,14 @@ func (d *sessionsDict) deleteSession(sessionID string) {
func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 { func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
ts := time.Now().Add(allowedForCrossSigningKeysReplacementDuration).UnixMilli() ts := time.Now().Add(crossSigningKeysReplacementDuration).UnixMilli()
t, ok := d.allowedForCrossSigningKeysReplacement[userID] t, ok := d.crossSigningKeysReplacement[userID]
if ok { if ok {
t.Reset(allowedForCrossSigningKeysReplacementDuration) t.Reset(crossSigningKeysReplacementDuration)
return ts return ts
} }
d.allowedForCrossSigningKeysReplacement[userID] = time.AfterFunc( d.crossSigningKeysReplacement[userID] = time.AfterFunc(
allowedForCrossSigningKeysReplacementDuration, crossSigningKeysReplacementDuration,
func() { func() {
d.restrictCrossSigningKeysReplacement(userID) d.restrictCrossSigningKeysReplacement(userID)
}, },
@ -146,14 +146,14 @@ func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool { func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool {
d.RLock() d.RLock()
defer d.RUnlock() defer d.RUnlock()
_, ok := d.allowedForCrossSigningKeysReplacement[userID] _, ok := d.crossSigningKeysReplacement[userID]
return ok return ok
} }
func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) { func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
t, ok := d.allowedForCrossSigningKeysReplacement[userID] t, ok := d.crossSigningKeysReplacement[userID]
if ok { if ok {
if !t.Stop() { if !t.Stop() {
select { select {
@ -161,7 +161,7 @@ func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
default: default:
} }
} }
delete(d.allowedForCrossSigningKeysReplacement, userID) delete(d.crossSigningKeysReplacement, userID)
} }
} }
@ -172,7 +172,7 @@ func newSessionsDict() *sessionsDict {
params: make(map[string]registerRequest), params: make(map[string]registerRequest),
timer: make(map[string]*time.Timer), timer: make(map[string]*time.Timer),
deleteSessionToDeviceID: make(map[string]string), deleteSessionToDeviceID: make(map[string]string),
allowedForCrossSigningKeysReplacement: make(map[string]*time.Timer), crossSigningKeysReplacement: make(map[string]*time.Timer),
} }
} }

View file

@ -669,3 +669,57 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) {
assert.Equal(t, expectedDisplayName, profile.DisplayName) assert.Equal(t, expectedDisplayName, profile.DisplayName)
}) })
} }
func TestCrossSigningKeysReplacement(t *testing.T) {
userID := "@user:example.com"
t.Run("Can add new session", func(t *testing.T) {
s := newSessionsDict()
assert.Empty(t, s.crossSigningKeysReplacement)
s.allowCrossSigningKeysReplacement(userID)
assert.Len(t, s.crossSigningKeysReplacement, 1)
assert.Contains(t, s.crossSigningKeysReplacement, userID)
})
t.Run("Can check if session exists or not", func(t *testing.T) {
s := newSessionsDict()
t.Run("exists", func(t *testing.T) {
s.allowCrossSigningKeysReplacement(userID)
assert.Len(t, s.crossSigningKeysReplacement, 1)
assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID))
})
t.Run("not exists", func(t *testing.T) {
assert.False(t, s.isCrossSigningKeysReplacementAllowed("@random:test.com"))
})
})
t.Run("Can deactivate session", func(t *testing.T) {
s := newSessionsDict()
assert.Empty(t, s.crossSigningKeysReplacement)
t.Run("not exists", func(t *testing.T) {
s.restrictCrossSigningKeysReplacement("@random:test.com")
assert.Empty(t, s.crossSigningKeysReplacement)
})
t.Run("exists", func(t *testing.T) {
s.allowCrossSigningKeysReplacement(userID)
s.restrictCrossSigningKeysReplacement(userID)
assert.Empty(t, s.crossSigningKeysReplacement)
})
})
t.Run("Can erase expired sessions", func(t *testing.T) {
s := newSessionsDict()
s.allowCrossSigningKeysReplacement(userID)
assert.Len(t, s.crossSigningKeysReplacement, 1)
assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID))
timer := s.crossSigningKeysReplacement[userID]
// pretending the timer is expired
timer.Reset(time.Millisecond)
time.Sleep(time.Millisecond * 500)
assert.Empty(t, s.crossSigningKeysReplacement)
})
}