Support for fallback keys (#3451)

Backports support for fallback keys from Harmony, which should make E2EE
more reliable in the face of OTK exhaustion.

Signed-off-by: Neil Alexander <git@neilalexander.dev>

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>

[skip ci]
This commit is contained in:
Neil 2024-12-17 19:19:15 +01:00 committed by GitHub
parent c3d7a34c15
commit 78dbf21c5f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 446 additions and 20 deletions

View file

@ -44,14 +44,22 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
if len(req.DeviceKeys) > 0 {
a.uploadLocalDeviceKeys(ctx, req, res)
}
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
if len(req.OneTimeKeys) > 0 || len(req.FallbackKeys) > 0 {
a.uploadOneTimeAndFallbackKeys(ctx, req, res)
}
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
return err
}
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
}
return nil
}
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
res.FallbackKeysUnusedAlgorithms = algos
return nil
}
@ -169,7 +177,15 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
}
return nil
}
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
}
return nil
}
res.Count = *count
res.UnusedFallbackAlgorithms = algos
return nil
}
@ -507,6 +523,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
for userID := range userIDsForAllDevices {
err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"user_id": userID,
@ -520,6 +539,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"user_id": userID,
@ -715,7 +737,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe
}
}
func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
func (a *UserInternalAPI) uploadOneTimeAndFallbackKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" {
res.Error = &api.KeyError{
Err: "user ID missing",
@ -768,7 +790,32 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
// collect counts
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
}
if len(req.FallbackKeys) > 0 {
if err := a.KeyDatabase.DeleteFallbackKeys(ctx, req.UserID, req.DeviceID); err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to clear fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
})
return
}
for _, key := range req.FallbackKeys {
// grab existing keys based on (user/device/algorithm/key ID)
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
i := 0
for keyIDWithAlgo := range key.KeyJSON {
keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++
}
unused, err := a.KeyDatabase.StoreFallbackKeys(ctx, key)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
})
continue
}
// collect counts
res.FallbackKeysUnusedAlgorithms = unused
}
}
}
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {