diff --git a/clientapi/routing/key_crosssigning_test.go b/clientapi/routing/key_crosssigning_test.go index 0ebb91e0..1339b45c 100644 --- a/clientapi/routing/key_crosssigning_test.go +++ b/clientapi/routing/key_crosssigning_test.go @@ -10,6 +10,8 @@ import ( "strings" "testing" + "github.com/element-hq/dendrite/userapi/types" + "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/test" "github.com/element-hq/dendrite/test/testrig" @@ -20,19 +22,28 @@ import ( ) type mockKeyAPI struct { - t *testing.T - userResponses map[string]api.QueryKeysResponse + t *testing.T + queryKeysData map[string]api.QueryKeysResponse + queryMasterKeysData map[string]api.QueryMasterKeysResponse } func (m mockKeyAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { - res.MasterKeys = m.userResponses[req.UserID].MasterKeys - res.SelfSigningKeys = m.userResponses[req.UserID].SelfSigningKeys - res.UserSigningKeys = m.userResponses[req.UserID].UserSigningKeys + res.MasterKeys = m.queryKeysData[req.UserID].MasterKeys + res.SelfSigningKeys = m.queryKeysData[req.UserID].SelfSigningKeys + res.UserSigningKeys = m.queryKeysData[req.UserID].UserSigningKeys if m.t != nil { m.t.Logf("QueryKeys: %+v => %+v", req, res) } } +func (m mockKeyAPI) QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) { + res.Key = m.queryMasterKeysData[req.UserID].Key + res.Error = m.queryMasterKeysData[req.UserID].Error + if m.t != nil { + m.t.Logf("QueryMasterKeys: %+v => %+v", req, res) + } +} + func (m mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { // Just a dummy upload which always succeeds } @@ -53,13 +64,19 @@ func Test_UploadCrossSigningDeviceKeys_ValidRequest(t *testing.T) { req.Header.Set("Content-Type", "application/json") keyserverAPI := &mockKeyAPI{ - userResponses: map[string]api.QueryKeysResponse{ + queryKeysData: map[string]api.QueryKeysResponse{ + "@user:example.com": {}, + }, + queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ "@user:example.com": {}, }, } device := &api.Device{UserID: "@user:example.com", ID: "device"} - cfg := &config.ClientAPI{} - + cfg := &config.ClientAPI{ + MSCs: &config.MSCs{ + MSCs: []string{}, + }, + } res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg) if res.Code != http.StatusOK { t.Fatalf("expected status %d, got %d", http.StatusOK, res.Code) @@ -101,18 +118,32 @@ func Test_UploadCrossSigningDeviceKeys_Unauthorised(t *testing.T) { keyserverAPI := &mockKeyAPI{ t: t, - userResponses: map[string]api.QueryKeysResponse{ + queryKeysData: map[string]api.QueryKeysResponse{ "@user:example.com": { MasterKeys: map[string]fclient.CrossSigningKey{ - "@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}}, + "@user:example.com": { + UserID: "@user:example.com", + Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster}, + Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}}, }, SelfSigningKeys: nil, UserSigningKeys: nil, }, }, + queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ + "@user:example.com": { + Key: &types.CrossSigningKey{ + KeyData: spec.Base64Bytes("key1"), + }, + }, + }, } device := &api.Device{UserID: "@user:example.com", ID: "device"} - cfg := &config.ClientAPI{} + cfg := &config.ClientAPI{ + MSCs: &config.MSCs{ + MSCs: []string{}, + }, + } res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg) if res.Code != http.StatusUnauthorized { @@ -132,8 +163,11 @@ func Test_UploadCrossSigningDeviceKeys_InvalidJSON(t *testing.T) { keyserverAPI := &mockKeyAPI{} device := &api.Device{UserID: "@user:example.com", ID: "device"} - cfg := &config.ClientAPI{} - + cfg := &config.ClientAPI{ + MSCs: &config.MSCs{ + MSCs: []string{}, + }, + } res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg) if res.Code != http.StatusBadRequest { t.Fatalf("expected status %d, got %d", http.StatusBadRequest, res.Code) @@ -151,10 +185,21 @@ func Test_UploadCrossSigningDeviceKeys_ExistingKeysMismatch(t *testing.T) { req.Header.Set("Content-Type", "application/json") keyserverAPI := &mockKeyAPI{ - userResponses: map[string]api.QueryKeysResponse{ + queryKeysData: map[string]api.QueryKeysResponse{ "@user:example.com": { MasterKeys: map[string]fclient.CrossSigningKey{ - "@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")}}, + "@user:example.com": { + UserID: "@user:example.com", + Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster}, + Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")}, + }, + }, + }, + }, + queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ + "@user:example.com": { + Key: &types.CrossSigningKey{ + KeyData: spec.Base64Bytes("different_key"), }, }, }, diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index 694bb351..c8f2249a 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -1,7 +1,5 @@ package config -import "slices" - type MSCs struct { Matrix *Global `yaml:"-"` @@ -46,7 +44,7 @@ func (c *MSCs) Verify(configErrs *ConfigErrors) { } func (c *MSCs) MSC3861Enabled() bool { - return slices.Contains(c.MSCs, "msc3861") && c.MSC3861 != nil + return c.Enabled("msc3861") && c.MSC3861 != nil } type MSC3861 struct {