Support for fallback keys (#3451)

Backports support for fallback keys from Harmony, which should make E2EE
more reliable in the face of OTK exhaustion.

Signed-off-by: Neil Alexander <git@neilalexander.dev>

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>

[skip ci]
This commit is contained in:
Neil 2024-12-17 19:19:15 +01:00 committed by GitHub
parent c3d7a34c15
commit 78dbf21c5f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 446 additions and 20 deletions

View file

@ -167,6 +167,15 @@ type KeyDatabase interface {
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
// StoreFallbackKeys persists the given fallback keys.
StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) ([]string, error)
// UnusedFallbackKeyAlgorithms returns unused fallback algorithms for this user/device.
UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)
// DeleteFallbackKeys deletes all fallback keys for the user.
DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

View file

@ -0,0 +1,134 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2017 Vector Creations Ltd
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/element-hq/dendrite/internal"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/userapi/api"
"github.com/element-hq/dendrite/userapi/storage/tables"
)
var fallbackKeysSchema = `
-- Stores one-time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
used BOOLEAN NOT NULL,
-- Clobber based on tuple of user/device/algorithm.
CONSTRAINT keyserver_fallback_keys_unique UNIQUE (user_id, device_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
`
const upsertFallbackKeysSQL = "" +
"INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
" VALUES ($1, $2, $3, $4, $5, $6, false)" +
" ON CONFLICT ON CONSTRAINT keyserver_fallback_keys_unique" +
" DO UPDATE SET key_id = $3, key_json = $6, used = false"
const selectFallbackUnusedAlgorithmsSQL = "" +
"SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
const selectFallbackKeysByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
const deleteFallbackKeysSQL = "" +
"DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
const updateFallbackKeyUsedSQL = "" +
"UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
type fallbackKeysStatements struct {
db *sql.DB
upsertKeysStmt *sql.Stmt
selectUnusedAlgorithmsStmt *sql.Stmt
selectKeyByAlgorithmStmt *sql.Stmt
deleteFallbackKeysStmt *sql.Stmt
updateFallbackKeyUsedStmt *sql.Stmt
}
func NewPostgresFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
s := &fallbackKeysStatements{
db: db,
}
_, err := db.Exec(fallbackKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertKeysStmt, upsertFallbackKeysSQL},
{&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
{&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
{&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
{&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
}.Prepare(db)
}
func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
algos := []string{}
for rows.Next() {
var algorithm string
if err = rows.Scan(&algorithm); err != nil {
return nil, err
}
algos = append(algos, algorithm)
}
return algos, rows.Err()
}
func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
now := time.Now().Unix()
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
}
func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}
func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}

View file

@ -141,6 +141,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
fk, err := NewPostgresFallbackKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewPostgresDeviceKeysTable(db)
if err != nil {
return nil, err
@ -164,6 +168,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
FallbackKeysTable: fk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,

View file

@ -57,6 +57,7 @@ type Database struct {
type KeyDatabase struct {
OneTimeKeysTable tables.OneTimeKeys
FallbackKeysTable tables.FallbackKeys
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
@ -937,6 +938,22 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
func (d *KeyDatabase) StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) (unused []string, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
unused, err = d.FallbackKeysTable.InsertFallbackKeys(ctx, txn, keys)
return err
})
return
}
func (d *KeyDatabase) DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error {
return d.FallbackKeysTable.DeleteFallbackKeys(ctx, nil, userID, deviceID)
}
func (d *KeyDatabase) UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
return d.FallbackKeysTable.SelectUnusedFallbackKeyAlgorithms(ctx, userID, deviceID)
}
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
@ -999,6 +1016,12 @@ func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map
if err != nil {
return err
}
if len(keyJSON) == 0 {
keyJSON, err = d.FallbackKeysTable.SelectAndUpdateFallbackKey(ctx, txn, userID, deviceID, algo)
if err != nil {
return err
}
}
if keyJSON != nil {
result = append(result, api.OneTimeKeys{
UserID: userID,

View file

@ -0,0 +1,132 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2017 Vector Creations Ltd
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/element-hq/dendrite/internal"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/userapi/api"
"github.com/element-hq/dendrite/userapi/storage/tables"
)
var fallbackKeysSchema = `
-- Stores one-time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
used BOOLEAN NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS keyserver_fallback_keys_unique_idx ON keyserver_fallback_keys(user_id, device_id, algorithm);
CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
`
const upsertFallbackKeysSQL = "" +
"INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
" VALUES ($1, $2, $3, $4, $5, $6, false)" +
" ON CONFLICT (user_id, device_id, algorithm)" +
" DO UPDATE SET key_id = $3, key_json = $6, used = false"
const selectFallbackUnusedAlgorithmsSQL = "" +
"SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
const selectFallbackKeysByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
const deleteFallbackKeysSQL = "" +
"DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
const updateFallbackKeyUsedSQL = "" +
"UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
type fallbackKeysStatements struct {
db *sql.DB
upsertKeysStmt *sql.Stmt
selectUnusedAlgorithmsStmt *sql.Stmt
selectKeyByAlgorithmStmt *sql.Stmt
deleteFallbackKeysStmt *sql.Stmt
updateFallbackKeyUsedStmt *sql.Stmt
}
func NewSqliteFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
s := &fallbackKeysStatements{
db: db,
}
_, err := db.Exec(fallbackKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertKeysStmt, upsertFallbackKeysSQL},
{&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
{&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
{&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
{&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
}.Prepare(db)
}
func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
algos := []string{}
for rows.Next() {
var algorithm string
if err = rows.Scan(&algorithm); err != nil {
return nil, err
}
algos = append(algos, algorithm)
}
return algos, rows.Err()
}
func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
now := time.Now().Unix()
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
}
func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}
func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}

View file

@ -138,6 +138,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
fk, err := NewSqliteFallbackKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewSqliteDeviceKeysTable(db)
if err != nil {
return nil, err
@ -161,6 +165,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
FallbackKeysTable: fk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,

View file

@ -809,3 +809,42 @@ func TestOneTimeKeys(t *testing.T) {
}
})
}
func TestFallbackKeys(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
userID := "@alice:localhost"
deviceID := "alice_device"
fk := api.FallbackKeys{
UserID: userID,
DeviceID: deviceID,
KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)},
}
_, err := db.StoreFallbackKeys(ctx, fk)
MustNotError(t, err)
unused, err := db.UnusedFallbackKeyAlgorithms(ctx, userID, deviceID)
MustNotError(t, err)
if c := len(unused); c != 1 {
t.Fatalf("Expected 1 unused key algorithm, got %d", c)
}
if unused[0] != "curve25519" {
t.Fatalf("Expected unused key algorithm to be 'curve25519', got '%s'", unused[0])
}
// No other one-time keys have been uploaded so we expect to get the fallback key instead.
claimed, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}})
MustNotError(t, err)
switch {
case claimed[0].UserID != fk.UserID:
t.Fatalf("Claimed user ID ID doesn't match, got %q, want %q", claimed[0].UserID, fk.DeviceID)
case claimed[0].DeviceID != fk.DeviceID:
t.Fatalf("Claimed device ID doesn't match, got %q, want %q", claimed[0].DeviceID, fk.DeviceID)
case claimed[0].KeyJSON["curve25519:KEY1"] == nil:
t.Fatalf("Claimed key JSON for curve25519:KEY1 not found")
}
})
}

View file

@ -170,6 +170,13 @@ type DeviceKeys interface {
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}
type FallbackKeys interface {
SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)
InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error)
DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
SelectAndUpdateFallbackKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
}
type KeyChanges interface {
InsertKeyChange(ctx context.Context, userID string) (int64, error)
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.