mirror of
https://github.com/element-hq/dendrite.git
synced 2025-09-13 21:02:25 +03:00
Support for fallback keys (#3451)
Backports support for fallback keys from Harmony, which should make E2EE more reliable in the face of OTK exhaustion. Signed-off-by: Neil Alexander <git@neilalexander.dev> Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com> [skip ci]
This commit is contained in:
parent
c3d7a34c15
commit
78dbf21c5f
13 changed files with 446 additions and 20 deletions
|
@ -44,14 +44,22 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
|||
if len(req.DeviceKeys) > 0 {
|
||||
a.uploadLocalDeviceKeys(ctx, req, res)
|
||||
}
|
||||
if len(req.OneTimeKeys) > 0 {
|
||||
a.uploadOneTimeKeys(ctx, req, res)
|
||||
if len(req.OneTimeKeys) > 0 || len(req.FallbackKeys) > 0 {
|
||||
a.uploadOneTimeAndFallbackKeys(ctx, req, res)
|
||||
}
|
||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
||||
res.FallbackKeysUnusedAlgorithms = algos
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -169,7 +177,15 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
|
|||
}
|
||||
return nil
|
||||
}
|
||||
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
res.Count = *count
|
||||
res.UnusedFallbackAlgorithms = algos
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -507,6 +523,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
|
|||
for userID := range userIDsForAllDevices {
|
||||
err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
"user_id": userID,
|
||||
|
@ -520,6 +539,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
|
|||
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
"user_id": userID,
|
||||
|
@ -715,7 +737,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe
|
|||
}
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
func (a *UserInternalAPI) uploadOneTimeAndFallbackKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
if req.UserID == "" {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "user ID missing",
|
||||
|
@ -768,7 +790,32 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
|||
// collect counts
|
||||
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
||||
}
|
||||
|
||||
if len(req.FallbackKeys) > 0 {
|
||||
if err := a.KeyDatabase.DeleteFallbackKeys(ctx, req.UserID, req.DeviceID); err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s : failed to clear fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
for _, key := range req.FallbackKeys {
|
||||
// grab existing keys based on (user/device/algorithm/key ID)
|
||||
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
||||
i := 0
|
||||
for keyIDWithAlgo := range key.KeyJSON {
|
||||
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||
i++
|
||||
}
|
||||
unused, err := a.KeyDatabase.StoreFallbackKeys(ctx, key)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s : failed to store fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
})
|
||||
continue
|
||||
}
|
||||
// collect counts
|
||||
res.FallbackKeysUnusedAlgorithms = unused
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue