mirror of
https://github.com/element-hq/dendrite.git
synced 2025-09-13 21:02:25 +03:00
Implement server whitelist
This commit is contained in:
parent
add73ec866
commit
b2ecd5648c
14 changed files with 332 additions and 0 deletions
|
@ -107,6 +107,22 @@ func NewInternalAPI(
|
|||
|
||||
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1, cfg.P2PFederationRetriesUntilAssumedOffline+1, cfg.EnableRelays)
|
||||
|
||||
// Add servers to whitelist if enabled
|
||||
if cfg.EnableWhitelist {
|
||||
// We need to clear the list of the whitelisted servers during init
|
||||
err = stats.DB.RemoveAllServersFromWhitelist()
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panic("failed to clear whitelisted servers")
|
||||
}
|
||||
// Add each whitelisted server to the database
|
||||
for _, server := range cfg.WhitelistedServers {
|
||||
err = stats.DB.AddServerToWhitelist(server)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to add server %s to whitelist", server)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
js, nats := natsInstance.Prepare(processContext, &cfg.Matrix.JetStream)
|
||||
|
||||
signingInfo := dendriteCfg.Global.SigningIdentities()
|
||||
|
|
|
@ -112,6 +112,12 @@ func NewFederationInternalAPI(
|
|||
}
|
||||
}
|
||||
|
||||
// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled, so we can connect to any server
|
||||
func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool {
|
||||
stats := a.statistics.ForServer(s)
|
||||
return stats.Whitelisted() || !a.cfg.EnableWhitelist
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) {
|
||||
stats := a.statistics.ForServer(s)
|
||||
if stats.Blacklisted() {
|
||||
|
|
|
@ -44,6 +44,9 @@ func (a *FederationInternalAPI) GetEventAuth(
|
|||
) (res fclient.RespEventAuth, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return fclient.RespEventAuth{}, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
|
||||
})
|
||||
|
@ -58,6 +61,9 @@ func (a *FederationInternalAPI) GetUserDevices(
|
|||
) (fclient.RespUserDevices, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return fclient.RespUserDevices{}, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.GetUserDevices(ctx, origin, s, userID)
|
||||
})
|
||||
|
@ -72,6 +78,9 @@ func (a *FederationInternalAPI) ClaimKeys(
|
|||
) (fclient.RespClaimKeys, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return fclient.RespClaimKeys{}, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
|
||||
})
|
||||
|
@ -84,6 +93,9 @@ func (a *FederationInternalAPI) ClaimKeys(
|
|||
func (a *FederationInternalAPI) QueryKeys(
|
||||
ctx context.Context, origin, s spec.ServerName, keys map[string][]string,
|
||||
) (fclient.RespQueryKeys, error) {
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return fclient.RespQueryKeys{}, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.QueryKeys(ctx, origin, s, keys)
|
||||
})
|
||||
|
@ -98,6 +110,9 @@ func (a *FederationInternalAPI) Backfill(
|
|||
) (res gomatrixserverlib.Transaction, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return gomatrixserverlib.Transaction{}, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
|
||||
})
|
||||
|
@ -112,6 +127,9 @@ func (a *FederationInternalAPI) LookupState(
|
|||
) (res gomatrixserverlib.StateResponse, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return &fclient.RespState{}, nil
|
||||
} // TODO check &
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
|
||||
})
|
||||
|
@ -127,6 +145,9 @@ func (a *FederationInternalAPI) LookupStateIDs(
|
|||
) (res gomatrixserverlib.StateIDResponse, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return fclient.RespStateIDs{}, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
|
||||
})
|
||||
|
@ -142,6 +163,9 @@ func (a *FederationInternalAPI) LookupMissingEvents(
|
|||
) (res fclient.RespMissingEvents, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return fclient.RespMissingEvents{}, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
|
||||
})
|
||||
|
@ -156,6 +180,9 @@ func (a *FederationInternalAPI) GetEvent(
|
|||
) (res gomatrixserverlib.Transaction, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return gomatrixserverlib.Transaction{}, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.GetEvent(ctx, origin, s, eventID)
|
||||
})
|
||||
|
@ -170,6 +197,9 @@ func (a *FederationInternalAPI) LookupServerKeys(
|
|||
) ([]gomatrixserverlib.ServerKeys, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return []gomatrixserverlib.ServerKeys{}, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.LookupServerKeys(ctx, s, keyRequests)
|
||||
})
|
||||
|
@ -185,6 +215,9 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
|
|||
) (res fclient.MSC2836EventRelationshipsResponse, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return res, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion)
|
||||
})
|
||||
|
@ -199,6 +232,9 @@ func (a *FederationInternalAPI) RoomHierarchies(
|
|||
) (res fclient.RoomHierarchyResponse, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
if !a.IsWhitelistedOrAny(s) {
|
||||
return res, nil
|
||||
}
|
||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||
return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly)
|
||||
})
|
||||
|
|
|
@ -86,6 +86,10 @@ func (a *FederationInternalAPI) QueryServerKeys(
|
|||
}
|
||||
util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to satisfy keys request entirely from cache, hitting direct")
|
||||
|
||||
if !a.IsWhitelistedOrAny(req.ServerName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
serverKeys, err := a.fetchServerKeysDirectly(ctx, req.ServerName)
|
||||
if err != nil {
|
||||
// try to load as much as we can from the cache in a best effort basis
|
||||
|
|
|
@ -78,6 +78,13 @@ func (s *Statistics) ForServer(serverName spec.ServerName) *ServerStatistics {
|
|||
server.blacklisted.Store(blacklisted)
|
||||
}
|
||||
|
||||
whitelisted, err := s.DB.IsServerWhitelisted(serverName)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to get whitelist entry %q", serverName)
|
||||
} else {
|
||||
server.whitelisted.Store(whitelisted)
|
||||
}
|
||||
|
||||
// Don't bother hitting the database 2 additional times
|
||||
// if we don't want to use relays.
|
||||
if !s.enableRelays {
|
||||
|
@ -118,6 +125,7 @@ type ServerStatistics struct {
|
|||
statistics *Statistics //
|
||||
serverName spec.ServerName //
|
||||
blacklisted atomic.Bool // is the node blacklisted
|
||||
whitelisted atomic.Bool // is the node whitelisted
|
||||
assumedOffline atomic.Bool // is the node assumed to be offline
|
||||
backoffStarted atomic.Bool // is the backoff started
|
||||
backoffUntil atomic.Value // time.Time until this backoff interval ends
|
||||
|
@ -281,6 +289,10 @@ func (s *ServerStatistics) Blacklisted() bool {
|
|||
return s.blacklisted.Load()
|
||||
}
|
||||
|
||||
// Whitelisted returns true if the server is whitelisted and false
|
||||
// otherwise.
|
||||
func (s *ServerStatistics) Whitelisted() bool { return s.whitelisted.Load() }
|
||||
|
||||
// AssumedOffline returns true if the server is assumed offline and false
|
||||
// otherwise.
|
||||
func (s *ServerStatistics) AssumedOffline() bool {
|
||||
|
@ -302,6 +314,16 @@ func (s *ServerStatistics) removeBlacklist() bool {
|
|||
return wasBlacklisted
|
||||
}
|
||||
|
||||
// removeWhitelist removes the whitelisted status from the server.
|
||||
// Returns whether the server was whitelisted.
|
||||
func (s *ServerStatistics) removeWhitelist() bool {
|
||||
if s.Whitelisted() {
|
||||
_ = s.statistics.DB.RemoveServerFromWhitelist(s.serverName)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// removeAssumedOffline removes the assumed offline status from the server.
|
||||
func (s *ServerStatistics) removeAssumedOffline() {
|
||||
if s.AssumedOffline() {
|
||||
|
|
|
@ -49,6 +49,11 @@ type Database interface {
|
|||
RemoveAllServersFromBlacklist() error
|
||||
IsServerBlacklisted(serverName spec.ServerName) (bool, error)
|
||||
|
||||
AddServerToWhitelist(serverName spec.ServerName) error
|
||||
RemoveServerFromWhitelist(serverName spec.ServerName) error
|
||||
RemoveAllServersFromWhitelist() error
|
||||
IsServerWhitelisted(serverName spec.ServerName) (bool, error)
|
||||
|
||||
// Adds the server to the list of assumed offline servers.
|
||||
// If the server already exists in the table, nothing happens and returns success.
|
||||
SetServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error
|
||||
|
|
|
@ -38,6 +38,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
whitelist, err := NewPostgresWhitelistTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -104,6 +108,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
|
|||
FederationQueueEDUs: queueEDUs,
|
||||
FederationQueueJSON: queueJSON,
|
||||
FederationBlacklist: blacklist,
|
||||
FederationWhitelist: whitelist,
|
||||
FederationAssumedOffline: assumedOffline,
|
||||
FederationRelayServers: relayServers,
|
||||
FederationInboundPeeks: inboundPeeks,
|
||||
|
|
93
federationapi/storage/postgres/whitelist_table.go
Normal file
93
federationapi/storage/postgres/whitelist_table.go
Normal file
|
@ -0,0 +1,93 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
||||
const whitelistSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_whitelist (
|
||||
-- The whitelisted server name
|
||||
server_name TEXT NOT NULL,
|
||||
UNIQUE (server_name)
|
||||
`
|
||||
|
||||
const insertWhitelistSQL = "" +
|
||||
"INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectWhitelistSQL = "" +
|
||||
"SELECT server_name FROM federationsender_whitelist WHERE server_name = $1"
|
||||
|
||||
const deleteWhitelistSQL = "" +
|
||||
"DELETE FROM federationsender_whitelist WHERE server_name = $1"
|
||||
|
||||
const deleteAllWhitelistSQL = "" +
|
||||
"TRUNCATE federationsender_whitelist"
|
||||
|
||||
type whitelistStatements struct {
|
||||
db *sql.DB
|
||||
insertWhitelistStmt *sql.Stmt
|
||||
selectWhitelistStmt *sql.Stmt
|
||||
deleteWhitelistStmt *sql.Stmt
|
||||
deleteAllWhitelistStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) {
|
||||
s = &whitelistStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(whitelistSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertWhitelistStmt, insertWhitelistSQL},
|
||||
{&s.selectWhitelistStmt, selectWhitelistSQL},
|
||||
{&s.deleteWhitelistStmt, deleteWhitelistSQL},
|
||||
{&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *whitelistStatements) InsertWhitelist(
|
||||
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *whitelistStatements) SelectWhitelist(
|
||||
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||
) (bool, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt)
|
||||
res, err := stmt.QueryContext(ctx, serverName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer res.Close() // nolint:errcheck
|
||||
// The query will return the server name if the server is whitelisted, and
|
||||
// will return no rows if not. By calling Next, we find out if a row was
|
||||
// returned or not - we don't care about the value itself.
|
||||
return res.Next(), nil
|
||||
}
|
||||
|
||||
func (s *whitelistStatements) DeleteWhitelist(
|
||||
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *whitelistStatements) DeleteAllWhitelist(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt)
|
||||
_, err := stmt.ExecContext(ctx)
|
||||
return err
|
||||
}
|
|
@ -31,6 +31,7 @@ type Database struct {
|
|||
FederationQueueJSON tables.FederationQueueJSON
|
||||
FederationJoinedHosts tables.FederationJoinedHosts
|
||||
FederationBlacklist tables.FederationBlacklist
|
||||
FederationWhitelist tables.FederationWhitelist
|
||||
FederationAssumedOffline tables.FederationAssumedOffline
|
||||
FederationRelayServers tables.FederationRelayServers
|
||||
FederationOutboundPeeks tables.FederationOutboundPeeks
|
||||
|
@ -148,6 +149,14 @@ func (d *Database) AddServerToBlacklist(
|
|||
})
|
||||
}
|
||||
|
||||
func (d *Database) AddServerToWhitelist(
|
||||
serverName spec.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationWhitelist.InsertWhitelist(context.TODO(), txn, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RemoveServerFromBlacklist(
|
||||
serverName spec.ServerName,
|
||||
) error {
|
||||
|
@ -156,18 +165,38 @@ func (d *Database) RemoveServerFromBlacklist(
|
|||
})
|
||||
}
|
||||
|
||||
func (d *Database) RemoveServerFromWhitelist(
|
||||
serverName spec.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationWhitelist.DeleteWhitelist(context.TODO(), txn, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RemoveAllServersFromBlacklist() error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationBlacklist.DeleteAllBlacklist(context.TODO(), txn)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RemoveAllServersFromWhitelist() error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationWhitelist.DeleteAllWhitelist(context.TODO(), txn)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) IsServerBlacklisted(
|
||||
serverName spec.ServerName,
|
||||
) (bool, error) {
|
||||
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
|
||||
}
|
||||
|
||||
func (d *Database) IsServerWhitelisted(
|
||||
serverName spec.ServerName,
|
||||
) (bool, error) {
|
||||
return d.FederationWhitelist.SelectWhitelist(context.TODO(), nil, serverName)
|
||||
}
|
||||
|
||||
func (d *Database) SetServerAssumedOffline(
|
||||
ctx context.Context,
|
||||
serverName spec.ServerName,
|
||||
|
|
|
@ -36,6 +36,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
whitelist, err := NewSQLiteWhitelistTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -102,6 +106,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
|
|||
FederationQueueEDUs: queueEDUs,
|
||||
FederationQueueJSON: queueJSON,
|
||||
FederationBlacklist: blacklist,
|
||||
FederationWhitelist: whitelist,
|
||||
FederationAssumedOffline: assumedOffline,
|
||||
FederationRelayServers: relayServers,
|
||||
FederationOutboundPeeks: outboundPeeks,
|
||||
|
|
94
federationapi/storage/sqlite3/whitelist_table.go
Normal file
94
federationapi/storage/sqlite3/whitelist_table.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
||||
const whitelistSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_whitelist (
|
||||
-- The whitelisted server name
|
||||
server_name TEXT NOT NULL,
|
||||
UNIQUE (server_name)
|
||||
);
|
||||
`
|
||||
|
||||
const insertWhitelistSQL = "" +
|
||||
"INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectWhitelistSQL = "" +
|
||||
"SELECT server_name FROM federationsender_whitelist WHERE server_name = $1"
|
||||
|
||||
const deleteWhitelistSQL = "" +
|
||||
"DELETE FROM federationsender_whitelist WHERE server_name = $1"
|
||||
|
||||
const deleteAllWhitelistSQL = "" +
|
||||
"DELETE FROM federationsender_whitelist"
|
||||
|
||||
type whitelistStatements struct {
|
||||
db *sql.DB
|
||||
insertWhitelistStmt *sql.Stmt
|
||||
selectWhitelistStmt *sql.Stmt
|
||||
deleteWhitelistStmt *sql.Stmt
|
||||
deleteAllWhitelistStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) {
|
||||
s = &whitelistStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(whitelistSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertWhitelistStmt, insertWhitelistSQL},
|
||||
{&s.selectWhitelistStmt, selectWhitelistSQL},
|
||||
{&s.deleteWhitelistStmt, deleteWhitelistSQL},
|
||||
{&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *whitelistStatements) InsertWhitelist(
|
||||
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *whitelistStatements) SelectWhitelist(
|
||||
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||
) (bool, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt)
|
||||
res, err := stmt.QueryContext(ctx, serverName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer res.Close() // nolint:errcheck
|
||||
// The query will return the server name if the server is whitelisted, and
|
||||
// will return no rows if not. By calling Next, we find out if a row was
|
||||
// returned or not - we don't care about the value itself.
|
||||
return res.Next(), nil
|
||||
}
|
||||
|
||||
func (s *whitelistStatements) DeleteWhitelist(
|
||||
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *whitelistStatements) DeleteAllWhitelist(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt)
|
||||
_, err := stmt.ExecContext(ctx)
|
||||
return err
|
||||
}
|
|
@ -72,6 +72,13 @@ type FederationBlacklist interface {
|
|||
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
|
||||
}
|
||||
|
||||
type FederationWhitelist interface {
|
||||
InsertWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error
|
||||
SelectWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error)
|
||||
DeleteWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error
|
||||
DeleteAllWhitelist(ctx context.Context, txn *sql.Tx) error
|
||||
}
|
||||
|
||||
type FederationAssumedOffline interface {
|
||||
InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error
|
||||
SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error)
|
||||
|
|
|
@ -46,6 +46,12 @@ type FederationAPI struct {
|
|||
|
||||
// Should we prefer direct key fetches over perspective ones?
|
||||
PreferDirectFetch bool `yaml:"prefer_direct_fetch"`
|
||||
|
||||
// Enable servers whitelist function
|
||||
EnableWhitelist bool `yaml:"enable_whitelist"`
|
||||
|
||||
// The list of whitelisted servers
|
||||
WhitelistedServers []spec.ServerName `yaml:"whitelisted_servers"`
|
||||
}
|
||||
|
||||
func (c *FederationAPI) Defaults(opts DefaultOpts) {
|
||||
|
@ -73,6 +79,8 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) {
|
|||
c.Database.ConnectionString = "file:federationapi.db"
|
||||
}
|
||||
}
|
||||
c.EnableWhitelist = false
|
||||
c.WhitelistedServers = make([]spec.ServerName, 0)
|
||||
}
|
||||
|
||||
func (c *FederationAPI) Verify(configErrs *ConfigErrors) {
|
||||
|
|
|
@ -96,6 +96,8 @@ client_api:
|
|||
federation_api:
|
||||
database:
|
||||
connection_string: file:federationapi.db
|
||||
enable_whitelist: true
|
||||
whitelisted_servers: ["https://matrix.org"]
|
||||
key_server:
|
||||
database:
|
||||
connection_string: file:keyserver.db
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue