mirror of
https://github.com/element-hq/dendrite.git
synced 2025-09-13 21:02:25 +03:00
Support for fallback keys (#3451)
Backports support for fallback keys from Harmony, which should make E2EE more reliable in the face of OTK exhaustion. Signed-off-by: Neil Alexander <git@neilalexander.dev> Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com> [skip ci]
This commit is contained in:
parent
c3d7a34c15
commit
78dbf21c5f
13 changed files with 446 additions and 20 deletions
134
userapi/storage/postgres/fallback_keys_table.go
Normal file
134
userapi/storage/postgres/fallback_keys_table.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/internal"
|
||||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/element-hq/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var fallbackKeysSchema = `
|
||||
-- Stores one-time public keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
used BOOLEAN NOT NULL,
|
||||
-- Clobber based on tuple of user/device/algorithm.
|
||||
CONSTRAINT keyserver_fallback_keys_unique UNIQUE (user_id, device_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
|
||||
`
|
||||
|
||||
const upsertFallbackKeysSQL = "" +
|
||||
"INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, false)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_fallback_keys_unique" +
|
||||
" DO UPDATE SET key_id = $3, key_json = $6, used = false"
|
||||
|
||||
const selectFallbackUnusedAlgorithmsSQL = "" +
|
||||
"SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
|
||||
|
||||
const selectFallbackKeysByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
|
||||
|
||||
const deleteFallbackKeysSQL = "" +
|
||||
"DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
|
||||
|
||||
const updateFallbackKeyUsedSQL = "" +
|
||||
"UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
|
||||
|
||||
type fallbackKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectUnusedAlgorithmsStmt *sql.Stmt
|
||||
selectKeyByAlgorithmStmt *sql.Stmt
|
||||
deleteFallbackKeysStmt *sql.Stmt
|
||||
updateFallbackKeyUsedStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
|
||||
s := &fallbackKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(fallbackKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeysStmt, upsertFallbackKeysSQL},
|
||||
{&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
|
||||
{&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
|
||||
{&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
|
||||
{&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
|
||||
rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
algos := []string{}
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
if err = rows.Scan(&algorithm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
algos = append(algos, algorithm)
|
||||
}
|
||||
return algos, rows.Err()
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
|
||||
now := time.Now().Unix()
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
|
@ -141,6 +141,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fk, err := NewPostgresFallbackKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewPostgresDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -164,6 +168,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
FallbackKeysTable: fk,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue