Implement server whitelist

This commit is contained in:
enaix 2025-01-16 12:31:37 +03:00
parent add73ec866
commit b2ecd5648c
14 changed files with 332 additions and 0 deletions

View file

@ -107,6 +107,22 @@ func NewInternalAPI(
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1, cfg.P2PFederationRetriesUntilAssumedOffline+1, cfg.EnableRelays) stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1, cfg.P2PFederationRetriesUntilAssumedOffline+1, cfg.EnableRelays)
// Add servers to whitelist if enabled
if cfg.EnableWhitelist {
// We need to clear the list of the whitelisted servers during init
err = stats.DB.RemoveAllServersFromWhitelist()
if err != nil {
logrus.WithError(err).Panic("failed to clear whitelisted servers")
}
// Add each whitelisted server to the database
for _, server := range cfg.WhitelistedServers {
err = stats.DB.AddServerToWhitelist(server)
if err != nil {
logrus.WithError(err).Panicf("failed to add server %s to whitelist", server)
}
}
}
js, nats := natsInstance.Prepare(processContext, &cfg.Matrix.JetStream) js, nats := natsInstance.Prepare(processContext, &cfg.Matrix.JetStream)
signingInfo := dendriteCfg.Global.SigningIdentities() signingInfo := dendriteCfg.Global.SigningIdentities()

View file

@ -112,6 +112,12 @@ func NewFederationInternalAPI(
} }
} }
// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled, so we can connect to any server
func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool {
stats := a.statistics.ForServer(s)
return stats.Whitelisted() || !a.cfg.EnableWhitelist
}
func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) { func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) {
stats := a.statistics.ForServer(s) stats := a.statistics.ForServer(s)
if stats.Blacklisted() { if stats.Blacklisted() {

View file

@ -44,6 +44,9 @@ func (a *FederationInternalAPI) GetEventAuth(
) (res fclient.RespEventAuth, err error) { ) (res fclient.RespEventAuth, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return fclient.RespEventAuth{}, nil
}
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
}) })
@ -58,6 +61,9 @@ func (a *FederationInternalAPI) GetUserDevices(
) (fclient.RespUserDevices, error) { ) (fclient.RespUserDevices, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return fclient.RespUserDevices{}, nil
}
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetUserDevices(ctx, origin, s, userID) return a.federation.GetUserDevices(ctx, origin, s, userID)
}) })
@ -72,6 +78,9 @@ func (a *FederationInternalAPI) ClaimKeys(
) (fclient.RespClaimKeys, error) { ) (fclient.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return fclient.RespClaimKeys{}, nil
}
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
}) })
@ -84,6 +93,9 @@ func (a *FederationInternalAPI) ClaimKeys(
func (a *FederationInternalAPI) QueryKeys( func (a *FederationInternalAPI) QueryKeys(
ctx context.Context, origin, s spec.ServerName, keys map[string][]string, ctx context.Context, origin, s spec.ServerName, keys map[string][]string,
) (fclient.RespQueryKeys, error) { ) (fclient.RespQueryKeys, error) {
if !a.IsWhitelistedOrAny(s) {
return fclient.RespQueryKeys{}, nil
}
ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) {
return a.federation.QueryKeys(ctx, origin, s, keys) return a.federation.QueryKeys(ctx, origin, s, keys)
}) })
@ -98,6 +110,9 @@ func (a *FederationInternalAPI) Backfill(
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return gomatrixserverlib.Transaction{}, nil
}
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
}) })
@ -112,6 +127,9 @@ func (a *FederationInternalAPI) LookupState(
) (res gomatrixserverlib.StateResponse, err error) { ) (res gomatrixserverlib.StateResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return &fclient.RespState{}, nil
} // TODO check &
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
}) })
@ -127,6 +145,9 @@ func (a *FederationInternalAPI) LookupStateIDs(
) (res gomatrixserverlib.StateIDResponse, err error) { ) (res gomatrixserverlib.StateIDResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return fclient.RespStateIDs{}, nil
}
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
}) })
@ -142,6 +163,9 @@ func (a *FederationInternalAPI) LookupMissingEvents(
) (res fclient.RespMissingEvents, err error) { ) (res fclient.RespMissingEvents, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return fclient.RespMissingEvents{}, nil
}
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
}) })
@ -156,6 +180,9 @@ func (a *FederationInternalAPI) GetEvent(
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return gomatrixserverlib.Transaction{}, nil
}
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEvent(ctx, origin, s, eventID) return a.federation.GetEvent(ctx, origin, s, eventID)
}) })
@ -170,6 +197,9 @@ func (a *FederationInternalAPI) LookupServerKeys(
) ([]gomatrixserverlib.ServerKeys, error) { ) ([]gomatrixserverlib.ServerKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return []gomatrixserverlib.ServerKeys{}, nil
}
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupServerKeys(ctx, s, keyRequests) return a.federation.LookupServerKeys(ctx, s, keyRequests)
}) })
@ -185,6 +215,9 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
) (res fclient.MSC2836EventRelationshipsResponse, err error) { ) (res fclient.MSC2836EventRelationshipsResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return res, nil
}
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion) return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion)
}) })
@ -199,6 +232,9 @@ func (a *FederationInternalAPI) RoomHierarchies(
) (res fclient.RoomHierarchyResponse, err error) { ) (res fclient.RoomHierarchyResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
if !a.IsWhitelistedOrAny(s) {
return res, nil
}
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly) return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly)
}) })

View file

@ -86,6 +86,10 @@ func (a *FederationInternalAPI) QueryServerKeys(
} }
util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to satisfy keys request entirely from cache, hitting direct") util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to satisfy keys request entirely from cache, hitting direct")
if !a.IsWhitelistedOrAny(req.ServerName) {
return nil
}
serverKeys, err := a.fetchServerKeysDirectly(ctx, req.ServerName) serverKeys, err := a.fetchServerKeysDirectly(ctx, req.ServerName)
if err != nil { if err != nil {
// try to load as much as we can from the cache in a best effort basis // try to load as much as we can from the cache in a best effort basis

View file

@ -78,6 +78,13 @@ func (s *Statistics) ForServer(serverName spec.ServerName) *ServerStatistics {
server.blacklisted.Store(blacklisted) server.blacklisted.Store(blacklisted)
} }
whitelisted, err := s.DB.IsServerWhitelisted(serverName)
if err != nil {
logrus.WithError(err).Errorf("Failed to get whitelist entry %q", serverName)
} else {
server.whitelisted.Store(whitelisted)
}
// Don't bother hitting the database 2 additional times // Don't bother hitting the database 2 additional times
// if we don't want to use relays. // if we don't want to use relays.
if !s.enableRelays { if !s.enableRelays {
@ -118,6 +125,7 @@ type ServerStatistics struct {
statistics *Statistics // statistics *Statistics //
serverName spec.ServerName // serverName spec.ServerName //
blacklisted atomic.Bool // is the node blacklisted blacklisted atomic.Bool // is the node blacklisted
whitelisted atomic.Bool // is the node whitelisted
assumedOffline atomic.Bool // is the node assumed to be offline assumedOffline atomic.Bool // is the node assumed to be offline
backoffStarted atomic.Bool // is the backoff started backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends backoffUntil atomic.Value // time.Time until this backoff interval ends
@ -281,6 +289,10 @@ func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load() return s.blacklisted.Load()
} }
// Whitelisted returns true if the server is whitelisted and false
// otherwise.
func (s *ServerStatistics) Whitelisted() bool { return s.whitelisted.Load() }
// AssumedOffline returns true if the server is assumed offline and false // AssumedOffline returns true if the server is assumed offline and false
// otherwise. // otherwise.
func (s *ServerStatistics) AssumedOffline() bool { func (s *ServerStatistics) AssumedOffline() bool {
@ -302,6 +314,16 @@ func (s *ServerStatistics) removeBlacklist() bool {
return wasBlacklisted return wasBlacklisted
} }
// removeWhitelist removes the whitelisted status from the server.
// Returns whether the server was whitelisted.
func (s *ServerStatistics) removeWhitelist() bool {
if s.Whitelisted() {
_ = s.statistics.DB.RemoveServerFromWhitelist(s.serverName)
return true
}
return false
}
// removeAssumedOffline removes the assumed offline status from the server. // removeAssumedOffline removes the assumed offline status from the server.
func (s *ServerStatistics) removeAssumedOffline() { func (s *ServerStatistics) removeAssumedOffline() {
if s.AssumedOffline() { if s.AssumedOffline() {

View file

@ -49,6 +49,11 @@ type Database interface {
RemoveAllServersFromBlacklist() error RemoveAllServersFromBlacklist() error
IsServerBlacklisted(serverName spec.ServerName) (bool, error) IsServerBlacklisted(serverName spec.ServerName) (bool, error)
AddServerToWhitelist(serverName spec.ServerName) error
RemoveServerFromWhitelist(serverName spec.ServerName) error
RemoveAllServersFromWhitelist() error
IsServerWhitelisted(serverName spec.ServerName) (bool, error)
// Adds the server to the list of assumed offline servers. // Adds the server to the list of assumed offline servers.
// If the server already exists in the table, nothing happens and returns success. // If the server already exists in the table, nothing happens and returns success.
SetServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error SetServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error

View file

@ -38,6 +38,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
if err != nil { if err != nil {
return nil, err return nil, err
} }
whitelist, err := NewPostgresWhitelistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewPostgresJoinedHostsTable(d.db) joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -104,6 +108,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
FederationQueueEDUs: queueEDUs, FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON, FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist, FederationBlacklist: blacklist,
FederationWhitelist: whitelist,
FederationAssumedOffline: assumedOffline, FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers, FederationRelayServers: relayServers,
FederationInboundPeeks: inboundPeeks, FederationInboundPeeks: inboundPeeks,

View file

@ -0,0 +1,93 @@
package postgres
import (
"context"
"database/sql"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib/spec"
)
const whitelistSchema = `
CREATE TABLE IF NOT EXISTS federationsender_whitelist (
-- The whitelisted server name
server_name TEXT NOT NULL,
UNIQUE (server_name)
`
const insertWhitelistSQL = "" +
"INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING"
const selectWhitelistSQL = "" +
"SELECT server_name FROM federationsender_whitelist WHERE server_name = $1"
const deleteWhitelistSQL = "" +
"DELETE FROM federationsender_whitelist WHERE server_name = $1"
const deleteAllWhitelistSQL = "" +
"TRUNCATE federationsender_whitelist"
type whitelistStatements struct {
db *sql.DB
insertWhitelistStmt *sql.Stmt
selectWhitelistStmt *sql.Stmt
deleteWhitelistStmt *sql.Stmt
deleteAllWhitelistStmt *sql.Stmt
}
func NewPostgresWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) {
s = &whitelistStatements{
db: db,
}
_, err = db.Exec(whitelistSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertWhitelistStmt, insertWhitelistSQL},
{&s.selectWhitelistStmt, selectWhitelistSQL},
{&s.deleteWhitelistStmt, deleteWhitelistSQL},
{&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL},
}.Prepare(db)
}
func (s *whitelistStatements) InsertWhitelist(
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *whitelistStatements) SelectWhitelist(
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt)
res, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return false, err
}
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is whitelisted, and
// will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself.
return res.Next(), nil
}
func (s *whitelistStatements) DeleteWhitelist(
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *whitelistStatements) DeleteAllWhitelist(
ctx context.Context, txn *sql.Tx,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt)
_, err := stmt.ExecContext(ctx)
return err
}

View file

@ -31,6 +31,7 @@ type Database struct {
FederationQueueJSON tables.FederationQueueJSON FederationQueueJSON tables.FederationQueueJSON
FederationJoinedHosts tables.FederationJoinedHosts FederationJoinedHosts tables.FederationJoinedHosts
FederationBlacklist tables.FederationBlacklist FederationBlacklist tables.FederationBlacklist
FederationWhitelist tables.FederationWhitelist
FederationAssumedOffline tables.FederationAssumedOffline FederationAssumedOffline tables.FederationAssumedOffline
FederationRelayServers tables.FederationRelayServers FederationRelayServers tables.FederationRelayServers
FederationOutboundPeeks tables.FederationOutboundPeeks FederationOutboundPeeks tables.FederationOutboundPeeks
@ -148,6 +149,14 @@ func (d *Database) AddServerToBlacklist(
}) })
} }
func (d *Database) AddServerToWhitelist(
serverName spec.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationWhitelist.InsertWhitelist(context.TODO(), txn, serverName)
})
}
func (d *Database) RemoveServerFromBlacklist( func (d *Database) RemoveServerFromBlacklist(
serverName spec.ServerName, serverName spec.ServerName,
) error { ) error {
@ -156,18 +165,38 @@ func (d *Database) RemoveServerFromBlacklist(
}) })
} }
func (d *Database) RemoveServerFromWhitelist(
serverName spec.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationWhitelist.DeleteWhitelist(context.TODO(), txn, serverName)
})
}
func (d *Database) RemoveAllServersFromBlacklist() error { func (d *Database) RemoveAllServersFromBlacklist() error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.DeleteAllBlacklist(context.TODO(), txn) return d.FederationBlacklist.DeleteAllBlacklist(context.TODO(), txn)
}) })
} }
func (d *Database) RemoveAllServersFromWhitelist() error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationWhitelist.DeleteAllWhitelist(context.TODO(), txn)
})
}
func (d *Database) IsServerBlacklisted( func (d *Database) IsServerBlacklisted(
serverName spec.ServerName, serverName spec.ServerName,
) (bool, error) { ) (bool, error) {
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
} }
func (d *Database) IsServerWhitelisted(
serverName spec.ServerName,
) (bool, error) {
return d.FederationWhitelist.SelectWhitelist(context.TODO(), nil, serverName)
}
func (d *Database) SetServerAssumedOffline( func (d *Database) SetServerAssumedOffline(
ctx context.Context, ctx context.Context,
serverName spec.ServerName, serverName spec.ServerName,

View file

@ -36,6 +36,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
if err != nil { if err != nil {
return nil, err return nil, err
} }
whitelist, err := NewSQLiteWhitelistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -102,6 +106,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
FederationQueueEDUs: queueEDUs, FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON, FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist, FederationBlacklist: blacklist,
FederationWhitelist: whitelist,
FederationAssumedOffline: assumedOffline, FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers, FederationRelayServers: relayServers,
FederationOutboundPeeks: outboundPeeks, FederationOutboundPeeks: outboundPeeks,

View file

@ -0,0 +1,94 @@
package sqlite3
import (
"context"
"database/sql"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib/spec"
)
const whitelistSchema = `
CREATE TABLE IF NOT EXISTS federationsender_whitelist (
-- The whitelisted server name
server_name TEXT NOT NULL,
UNIQUE (server_name)
);
`
const insertWhitelistSQL = "" +
"INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING"
const selectWhitelistSQL = "" +
"SELECT server_name FROM federationsender_whitelist WHERE server_name = $1"
const deleteWhitelistSQL = "" +
"DELETE FROM federationsender_whitelist WHERE server_name = $1"
const deleteAllWhitelistSQL = "" +
"DELETE FROM federationsender_whitelist"
type whitelistStatements struct {
db *sql.DB
insertWhitelistStmt *sql.Stmt
selectWhitelistStmt *sql.Stmt
deleteWhitelistStmt *sql.Stmt
deleteAllWhitelistStmt *sql.Stmt
}
func NewSQLiteWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) {
s = &whitelistStatements{
db: db,
}
_, err = db.Exec(whitelistSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertWhitelistStmt, insertWhitelistSQL},
{&s.selectWhitelistStmt, selectWhitelistSQL},
{&s.deleteWhitelistStmt, deleteWhitelistSQL},
{&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL},
}.Prepare(db)
}
func (s *whitelistStatements) InsertWhitelist(
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *whitelistStatements) SelectWhitelist(
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt)
res, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return false, err
}
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is whitelisted, and
// will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself.
return res.Next(), nil
}
func (s *whitelistStatements) DeleteWhitelist(
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *whitelistStatements) DeleteAllWhitelist(
ctx context.Context, txn *sql.Tx,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt)
_, err := stmt.ExecContext(ctx)
return err
}

View file

@ -72,6 +72,13 @@ type FederationBlacklist interface {
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
} }
type FederationWhitelist interface {
InsertWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error
SelectWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error)
DeleteWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error
DeleteAllWhitelist(ctx context.Context, txn *sql.Tx) error
}
type FederationAssumedOffline interface { type FederationAssumedOffline interface {
InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error
SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error) SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error)

View file

@ -46,6 +46,12 @@ type FederationAPI struct {
// Should we prefer direct key fetches over perspective ones? // Should we prefer direct key fetches over perspective ones?
PreferDirectFetch bool `yaml:"prefer_direct_fetch"` PreferDirectFetch bool `yaml:"prefer_direct_fetch"`
// Enable servers whitelist function
EnableWhitelist bool `yaml:"enable_whitelist"`
// The list of whitelisted servers
WhitelistedServers []spec.ServerName `yaml:"whitelisted_servers"`
} }
func (c *FederationAPI) Defaults(opts DefaultOpts) { func (c *FederationAPI) Defaults(opts DefaultOpts) {
@ -73,6 +79,8 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) {
c.Database.ConnectionString = "file:federationapi.db" c.Database.ConnectionString = "file:federationapi.db"
} }
} }
c.EnableWhitelist = false
c.WhitelistedServers = make([]spec.ServerName, 0)
} }
func (c *FederationAPI) Verify(configErrs *ConfigErrors) { func (c *FederationAPI) Verify(configErrs *ConfigErrors) {

View file

@ -96,6 +96,8 @@ client_api:
federation_api: federation_api:
database: database:
connection_string: file:federationapi.db connection_string: file:federationapi.db
enable_whitelist: true
whitelisted_servers: ["https://matrix.org"]
key_server: key_server:
database: database:
connection_string: file:keyserver.db connection_string: file:keyserver.db