From b8ea41b2adab75e64aa95116e76c59525474fcca Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sat, 25 Jan 2025 02:02:24 +0000 Subject: [PATCH] tests for sessionsDict.crossSigningKeysReplacement --- clientapi/routing/register.go | 36 ++++++++++---------- clientapi/routing/register_test.go | 54 ++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 18 deletions(-) diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 74140d4c..da43a6b0 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -66,16 +66,16 @@ type sessionsDict struct { // If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2, // the delete request will fail for device2 since the UIA was initiated by trying to delete device1. deleteSessionToDeviceID map[string]string - // allowedForCrossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating + // crossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating // cross-signing keys without UIA. - allowedForCrossSigningKeysReplacement map[string]*time.Timer + crossSigningKeysReplacement map[string]*time.Timer } // defaultTimeout is the timeout used to clean up sessions const defaultTimeOut = time.Minute * 5 -// allowedForCrossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA -const allowedForCrossSigningKeysReplacementDuration = time.Minute * 10 +// crossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA +const crossSigningKeysReplacementDuration = time.Minute * 10 // getCompletedStages returns the completed stages for a session. func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType { @@ -128,14 +128,14 @@ func (d *sessionsDict) deleteSession(sessionID string) { func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 { d.Lock() defer d.Unlock() - ts := time.Now().Add(allowedForCrossSigningKeysReplacementDuration).UnixMilli() - t, ok := d.allowedForCrossSigningKeysReplacement[userID] + ts := time.Now().Add(crossSigningKeysReplacementDuration).UnixMilli() + t, ok := d.crossSigningKeysReplacement[userID] if ok { - t.Reset(allowedForCrossSigningKeysReplacementDuration) + t.Reset(crossSigningKeysReplacementDuration) return ts } - d.allowedForCrossSigningKeysReplacement[userID] = time.AfterFunc( - allowedForCrossSigningKeysReplacementDuration, + d.crossSigningKeysReplacement[userID] = time.AfterFunc( + crossSigningKeysReplacementDuration, func() { d.restrictCrossSigningKeysReplacement(userID) }, @@ -146,14 +146,14 @@ func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 { func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool { d.RLock() defer d.RUnlock() - _, ok := d.allowedForCrossSigningKeysReplacement[userID] + _, ok := d.crossSigningKeysReplacement[userID] return ok } func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) { d.Lock() defer d.Unlock() - t, ok := d.allowedForCrossSigningKeysReplacement[userID] + t, ok := d.crossSigningKeysReplacement[userID] if ok { if !t.Stop() { select { @@ -161,18 +161,18 @@ func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) { default: } } - delete(d.allowedForCrossSigningKeysReplacement, userID) + delete(d.crossSigningKeysReplacement, userID) } } func newSessionsDict() *sessionsDict { return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), - sessionCompletedResult: make(map[string]registerResponse), - params: make(map[string]registerRequest), - timer: make(map[string]*time.Timer), - deleteSessionToDeviceID: make(map[string]string), - allowedForCrossSigningKeysReplacement: make(map[string]*time.Timer), + sessions: make(map[string][]authtypes.LoginType), + sessionCompletedResult: make(map[string]registerResponse), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + deleteSessionToDeviceID: make(map[string]string), + crossSigningKeysReplacement: make(map[string]*time.Timer), } } diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 71cc0ca6..8529d7c5 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -669,3 +669,57 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) { assert.Equal(t, expectedDisplayName, profile.DisplayName) }) } + +func TestCrossSigningKeysReplacement(t *testing.T) { + userID := "@user:example.com" + + t.Run("Can add new session", func(t *testing.T) { + s := newSessionsDict() + assert.Empty(t, s.crossSigningKeysReplacement) + s.allowCrossSigningKeysReplacement(userID) + assert.Len(t, s.crossSigningKeysReplacement, 1) + assert.Contains(t, s.crossSigningKeysReplacement, userID) + }) + + t.Run("Can check if session exists or not", func(t *testing.T) { + s := newSessionsDict() + t.Run("exists", func(t *testing.T) { + s.allowCrossSigningKeysReplacement(userID) + assert.Len(t, s.crossSigningKeysReplacement, 1) + assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID)) + }) + + t.Run("not exists", func(t *testing.T) { + assert.False(t, s.isCrossSigningKeysReplacementAllowed("@random:test.com")) + }) + }) + + t.Run("Can deactivate session", func(t *testing.T) { + s := newSessionsDict() + assert.Empty(t, s.crossSigningKeysReplacement) + t.Run("not exists", func(t *testing.T) { + s.restrictCrossSigningKeysReplacement("@random:test.com") + assert.Empty(t, s.crossSigningKeysReplacement) + }) + + t.Run("exists", func(t *testing.T) { + s.allowCrossSigningKeysReplacement(userID) + s.restrictCrossSigningKeysReplacement(userID) + assert.Empty(t, s.crossSigningKeysReplacement) + }) + }) + + t.Run("Can erase expired sessions", func(t *testing.T) { + s := newSessionsDict() + s.allowCrossSigningKeysReplacement(userID) + assert.Len(t, s.crossSigningKeysReplacement, 1) + assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID)) + timer := s.crossSigningKeysReplacement[userID] + + // pretending the timer is expired + timer.Reset(time.Millisecond) + time.Sleep(time.Millisecond * 500) + + assert.Empty(t, s.crossSigningKeysReplacement) + }) +}