mirror of
https://github.com/element-hq/dendrite.git
synced 2025-09-13 21:02:25 +03:00
Implement server whitelist
This commit is contained in:
parent
add73ec866
commit
b2ecd5648c
14 changed files with 332 additions and 0 deletions
|
@ -107,6 +107,22 @@ func NewInternalAPI(
|
||||||
|
|
||||||
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1, cfg.P2PFederationRetriesUntilAssumedOffline+1, cfg.EnableRelays)
|
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1, cfg.P2PFederationRetriesUntilAssumedOffline+1, cfg.EnableRelays)
|
||||||
|
|
||||||
|
// Add servers to whitelist if enabled
|
||||||
|
if cfg.EnableWhitelist {
|
||||||
|
// We need to clear the list of the whitelisted servers during init
|
||||||
|
err = stats.DB.RemoveAllServersFromWhitelist()
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Panic("failed to clear whitelisted servers")
|
||||||
|
}
|
||||||
|
// Add each whitelisted server to the database
|
||||||
|
for _, server := range cfg.WhitelistedServers {
|
||||||
|
err = stats.DB.AddServerToWhitelist(server)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Panicf("failed to add server %s to whitelist", server)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
js, nats := natsInstance.Prepare(processContext, &cfg.Matrix.JetStream)
|
js, nats := natsInstance.Prepare(processContext, &cfg.Matrix.JetStream)
|
||||||
|
|
||||||
signingInfo := dendriteCfg.Global.SigningIdentities()
|
signingInfo := dendriteCfg.Global.SigningIdentities()
|
||||||
|
|
|
@ -112,6 +112,12 @@ func NewFederationInternalAPI(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled, so we can connect to any server
|
||||||
|
func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool {
|
||||||
|
stats := a.statistics.ForServer(s)
|
||||||
|
return stats.Whitelisted() || !a.cfg.EnableWhitelist
|
||||||
|
}
|
||||||
|
|
||||||
func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) {
|
func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) {
|
||||||
stats := a.statistics.ForServer(s)
|
stats := a.statistics.ForServer(s)
|
||||||
if stats.Blacklisted() {
|
if stats.Blacklisted() {
|
||||||
|
|
|
@ -44,6 +44,9 @@ func (a *FederationInternalAPI) GetEventAuth(
|
||||||
) (res fclient.RespEventAuth, err error) {
|
) (res fclient.RespEventAuth, err error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return fclient.RespEventAuth{}, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
|
return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
|
||||||
})
|
})
|
||||||
|
@ -58,6 +61,9 @@ func (a *FederationInternalAPI) GetUserDevices(
|
||||||
) (fclient.RespUserDevices, error) {
|
) (fclient.RespUserDevices, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return fclient.RespUserDevices{}, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.GetUserDevices(ctx, origin, s, userID)
|
return a.federation.GetUserDevices(ctx, origin, s, userID)
|
||||||
})
|
})
|
||||||
|
@ -72,6 +78,9 @@ func (a *FederationInternalAPI) ClaimKeys(
|
||||||
) (fclient.RespClaimKeys, error) {
|
) (fclient.RespClaimKeys, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return fclient.RespClaimKeys{}, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
|
return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
|
||||||
})
|
})
|
||||||
|
@ -84,6 +93,9 @@ func (a *FederationInternalAPI) ClaimKeys(
|
||||||
func (a *FederationInternalAPI) QueryKeys(
|
func (a *FederationInternalAPI) QueryKeys(
|
||||||
ctx context.Context, origin, s spec.ServerName, keys map[string][]string,
|
ctx context.Context, origin, s spec.ServerName, keys map[string][]string,
|
||||||
) (fclient.RespQueryKeys, error) {
|
) (fclient.RespQueryKeys, error) {
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return fclient.RespQueryKeys{}, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.QueryKeys(ctx, origin, s, keys)
|
return a.federation.QueryKeys(ctx, origin, s, keys)
|
||||||
})
|
})
|
||||||
|
@ -98,6 +110,9 @@ func (a *FederationInternalAPI) Backfill(
|
||||||
) (res gomatrixserverlib.Transaction, err error) {
|
) (res gomatrixserverlib.Transaction, err error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return gomatrixserverlib.Transaction{}, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
|
return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
|
||||||
})
|
})
|
||||||
|
@ -112,6 +127,9 @@ func (a *FederationInternalAPI) LookupState(
|
||||||
) (res gomatrixserverlib.StateResponse, err error) {
|
) (res gomatrixserverlib.StateResponse, err error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return &fclient.RespState{}, nil
|
||||||
|
} // TODO check &
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
|
return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
|
||||||
})
|
})
|
||||||
|
@ -127,6 +145,9 @@ func (a *FederationInternalAPI) LookupStateIDs(
|
||||||
) (res gomatrixserverlib.StateIDResponse, err error) {
|
) (res gomatrixserverlib.StateIDResponse, err error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return fclient.RespStateIDs{}, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
|
return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
|
||||||
})
|
})
|
||||||
|
@ -142,6 +163,9 @@ func (a *FederationInternalAPI) LookupMissingEvents(
|
||||||
) (res fclient.RespMissingEvents, err error) {
|
) (res fclient.RespMissingEvents, err error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return fclient.RespMissingEvents{}, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
|
return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
|
||||||
})
|
})
|
||||||
|
@ -156,6 +180,9 @@ func (a *FederationInternalAPI) GetEvent(
|
||||||
) (res gomatrixserverlib.Transaction, err error) {
|
) (res gomatrixserverlib.Transaction, err error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return gomatrixserverlib.Transaction{}, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.GetEvent(ctx, origin, s, eventID)
|
return a.federation.GetEvent(ctx, origin, s, eventID)
|
||||||
})
|
})
|
||||||
|
@ -170,6 +197,9 @@ func (a *FederationInternalAPI) LookupServerKeys(
|
||||||
) ([]gomatrixserverlib.ServerKeys, error) {
|
) ([]gomatrixserverlib.ServerKeys, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return []gomatrixserverlib.ServerKeys{}, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.LookupServerKeys(ctx, s, keyRequests)
|
return a.federation.LookupServerKeys(ctx, s, keyRequests)
|
||||||
})
|
})
|
||||||
|
@ -185,6 +215,9 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
|
||||||
) (res fclient.MSC2836EventRelationshipsResponse, err error) {
|
) (res fclient.MSC2836EventRelationshipsResponse, err error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion)
|
return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion)
|
||||||
})
|
})
|
||||||
|
@ -199,6 +232,9 @@ func (a *FederationInternalAPI) RoomHierarchies(
|
||||||
) (res fclient.RoomHierarchyResponse, err error) {
|
) (res fclient.RoomHierarchyResponse, err error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if !a.IsWhitelistedOrAny(s) {
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
|
||||||
return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly)
|
return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly)
|
||||||
})
|
})
|
||||||
|
|
|
@ -86,6 +86,10 @@ func (a *FederationInternalAPI) QueryServerKeys(
|
||||||
}
|
}
|
||||||
util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to satisfy keys request entirely from cache, hitting direct")
|
util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to satisfy keys request entirely from cache, hitting direct")
|
||||||
|
|
||||||
|
if !a.IsWhitelistedOrAny(req.ServerName) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
serverKeys, err := a.fetchServerKeysDirectly(ctx, req.ServerName)
|
serverKeys, err := a.fetchServerKeysDirectly(ctx, req.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// try to load as much as we can from the cache in a best effort basis
|
// try to load as much as we can from the cache in a best effort basis
|
||||||
|
|
|
@ -78,6 +78,13 @@ func (s *Statistics) ForServer(serverName spec.ServerName) *ServerStatistics {
|
||||||
server.blacklisted.Store(blacklisted)
|
server.blacklisted.Store(blacklisted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
whitelisted, err := s.DB.IsServerWhitelisted(serverName)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to get whitelist entry %q", serverName)
|
||||||
|
} else {
|
||||||
|
server.whitelisted.Store(whitelisted)
|
||||||
|
}
|
||||||
|
|
||||||
// Don't bother hitting the database 2 additional times
|
// Don't bother hitting the database 2 additional times
|
||||||
// if we don't want to use relays.
|
// if we don't want to use relays.
|
||||||
if !s.enableRelays {
|
if !s.enableRelays {
|
||||||
|
@ -118,6 +125,7 @@ type ServerStatistics struct {
|
||||||
statistics *Statistics //
|
statistics *Statistics //
|
||||||
serverName spec.ServerName //
|
serverName spec.ServerName //
|
||||||
blacklisted atomic.Bool // is the node blacklisted
|
blacklisted atomic.Bool // is the node blacklisted
|
||||||
|
whitelisted atomic.Bool // is the node whitelisted
|
||||||
assumedOffline atomic.Bool // is the node assumed to be offline
|
assumedOffline atomic.Bool // is the node assumed to be offline
|
||||||
backoffStarted atomic.Bool // is the backoff started
|
backoffStarted atomic.Bool // is the backoff started
|
||||||
backoffUntil atomic.Value // time.Time until this backoff interval ends
|
backoffUntil atomic.Value // time.Time until this backoff interval ends
|
||||||
|
@ -281,6 +289,10 @@ func (s *ServerStatistics) Blacklisted() bool {
|
||||||
return s.blacklisted.Load()
|
return s.blacklisted.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Whitelisted returns true if the server is whitelisted and false
|
||||||
|
// otherwise.
|
||||||
|
func (s *ServerStatistics) Whitelisted() bool { return s.whitelisted.Load() }
|
||||||
|
|
||||||
// AssumedOffline returns true if the server is assumed offline and false
|
// AssumedOffline returns true if the server is assumed offline and false
|
||||||
// otherwise.
|
// otherwise.
|
||||||
func (s *ServerStatistics) AssumedOffline() bool {
|
func (s *ServerStatistics) AssumedOffline() bool {
|
||||||
|
@ -302,6 +314,16 @@ func (s *ServerStatistics) removeBlacklist() bool {
|
||||||
return wasBlacklisted
|
return wasBlacklisted
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeWhitelist removes the whitelisted status from the server.
|
||||||
|
// Returns whether the server was whitelisted.
|
||||||
|
func (s *ServerStatistics) removeWhitelist() bool {
|
||||||
|
if s.Whitelisted() {
|
||||||
|
_ = s.statistics.DB.RemoveServerFromWhitelist(s.serverName)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// removeAssumedOffline removes the assumed offline status from the server.
|
// removeAssumedOffline removes the assumed offline status from the server.
|
||||||
func (s *ServerStatistics) removeAssumedOffline() {
|
func (s *ServerStatistics) removeAssumedOffline() {
|
||||||
if s.AssumedOffline() {
|
if s.AssumedOffline() {
|
||||||
|
|
|
@ -49,6 +49,11 @@ type Database interface {
|
||||||
RemoveAllServersFromBlacklist() error
|
RemoveAllServersFromBlacklist() error
|
||||||
IsServerBlacklisted(serverName spec.ServerName) (bool, error)
|
IsServerBlacklisted(serverName spec.ServerName) (bool, error)
|
||||||
|
|
||||||
|
AddServerToWhitelist(serverName spec.ServerName) error
|
||||||
|
RemoveServerFromWhitelist(serverName spec.ServerName) error
|
||||||
|
RemoveAllServersFromWhitelist() error
|
||||||
|
IsServerWhitelisted(serverName spec.ServerName) (bool, error)
|
||||||
|
|
||||||
// Adds the server to the list of assumed offline servers.
|
// Adds the server to the list of assumed offline servers.
|
||||||
// If the server already exists in the table, nothing happens and returns success.
|
// If the server already exists in the table, nothing happens and returns success.
|
||||||
SetServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error
|
SetServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error
|
||||||
|
|
|
@ -38,6 +38,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
whitelist, err := NewPostgresWhitelistTable(d.db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
|
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -104,6 +108,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
|
||||||
FederationQueueEDUs: queueEDUs,
|
FederationQueueEDUs: queueEDUs,
|
||||||
FederationQueueJSON: queueJSON,
|
FederationQueueJSON: queueJSON,
|
||||||
FederationBlacklist: blacklist,
|
FederationBlacklist: blacklist,
|
||||||
|
FederationWhitelist: whitelist,
|
||||||
FederationAssumedOffline: assumedOffline,
|
FederationAssumedOffline: assumedOffline,
|
||||||
FederationRelayServers: relayServers,
|
FederationRelayServers: relayServers,
|
||||||
FederationInboundPeeks: inboundPeeks,
|
FederationInboundPeeks: inboundPeeks,
|
||||||
|
|
93
federationapi/storage/postgres/whitelist_table.go
Normal file
93
federationapi/storage/postgres/whitelist_table.go
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
const whitelistSchema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS federationsender_whitelist (
|
||||||
|
-- The whitelisted server name
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
|
UNIQUE (server_name)
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertWhitelistSQL = "" +
|
||||||
|
"INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" +
|
||||||
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
|
const selectWhitelistSQL = "" +
|
||||||
|
"SELECT server_name FROM federationsender_whitelist WHERE server_name = $1"
|
||||||
|
|
||||||
|
const deleteWhitelistSQL = "" +
|
||||||
|
"DELETE FROM federationsender_whitelist WHERE server_name = $1"
|
||||||
|
|
||||||
|
const deleteAllWhitelistSQL = "" +
|
||||||
|
"TRUNCATE federationsender_whitelist"
|
||||||
|
|
||||||
|
type whitelistStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
insertWhitelistStmt *sql.Stmt
|
||||||
|
selectWhitelistStmt *sql.Stmt
|
||||||
|
deleteWhitelistStmt *sql.Stmt
|
||||||
|
deleteAllWhitelistStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) {
|
||||||
|
s = &whitelistStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err = db.Exec(whitelistSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.insertWhitelistStmt, insertWhitelistSQL},
|
||||||
|
{&s.selectWhitelistStmt, selectWhitelistSQL},
|
||||||
|
{&s.deleteWhitelistStmt, deleteWhitelistSQL},
|
||||||
|
{&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *whitelistStatements) InsertWhitelist(
|
||||||
|
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, serverName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *whitelistStatements) SelectWhitelist(
|
||||||
|
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||||
|
) (bool, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt)
|
||||||
|
res, err := stmt.QueryContext(ctx, serverName)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer res.Close() // nolint:errcheck
|
||||||
|
// The query will return the server name if the server is whitelisted, and
|
||||||
|
// will return no rows if not. By calling Next, we find out if a row was
|
||||||
|
// returned or not - we don't care about the value itself.
|
||||||
|
return res.Next(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *whitelistStatements) DeleteWhitelist(
|
||||||
|
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, serverName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *whitelistStatements) DeleteAllWhitelist(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -31,6 +31,7 @@ type Database struct {
|
||||||
FederationQueueJSON tables.FederationQueueJSON
|
FederationQueueJSON tables.FederationQueueJSON
|
||||||
FederationJoinedHosts tables.FederationJoinedHosts
|
FederationJoinedHosts tables.FederationJoinedHosts
|
||||||
FederationBlacklist tables.FederationBlacklist
|
FederationBlacklist tables.FederationBlacklist
|
||||||
|
FederationWhitelist tables.FederationWhitelist
|
||||||
FederationAssumedOffline tables.FederationAssumedOffline
|
FederationAssumedOffline tables.FederationAssumedOffline
|
||||||
FederationRelayServers tables.FederationRelayServers
|
FederationRelayServers tables.FederationRelayServers
|
||||||
FederationOutboundPeeks tables.FederationOutboundPeeks
|
FederationOutboundPeeks tables.FederationOutboundPeeks
|
||||||
|
@ -148,6 +149,14 @@ func (d *Database) AddServerToBlacklist(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) AddServerToWhitelist(
|
||||||
|
serverName spec.ServerName,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.FederationWhitelist.InsertWhitelist(context.TODO(), txn, serverName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) RemoveServerFromBlacklist(
|
func (d *Database) RemoveServerFromBlacklist(
|
||||||
serverName spec.ServerName,
|
serverName spec.ServerName,
|
||||||
) error {
|
) error {
|
||||||
|
@ -156,18 +165,38 @@ func (d *Database) RemoveServerFromBlacklist(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) RemoveServerFromWhitelist(
|
||||||
|
serverName spec.ServerName,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.FederationWhitelist.DeleteWhitelist(context.TODO(), txn, serverName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) RemoveAllServersFromBlacklist() error {
|
func (d *Database) RemoveAllServersFromBlacklist() error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.FederationBlacklist.DeleteAllBlacklist(context.TODO(), txn)
|
return d.FederationBlacklist.DeleteAllBlacklist(context.TODO(), txn)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) RemoveAllServersFromWhitelist() error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.FederationWhitelist.DeleteAllWhitelist(context.TODO(), txn)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) IsServerBlacklisted(
|
func (d *Database) IsServerBlacklisted(
|
||||||
serverName spec.ServerName,
|
serverName spec.ServerName,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
|
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) IsServerWhitelisted(
|
||||||
|
serverName spec.ServerName,
|
||||||
|
) (bool, error) {
|
||||||
|
return d.FederationWhitelist.SelectWhitelist(context.TODO(), nil, serverName)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) SetServerAssumedOffline(
|
func (d *Database) SetServerAssumedOffline(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
serverName spec.ServerName,
|
serverName spec.ServerName,
|
||||||
|
|
|
@ -36,6 +36,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
whitelist, err := NewSQLiteWhitelistTable(d.db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
|
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -102,6 +106,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
|
||||||
FederationQueueEDUs: queueEDUs,
|
FederationQueueEDUs: queueEDUs,
|
||||||
FederationQueueJSON: queueJSON,
|
FederationQueueJSON: queueJSON,
|
||||||
FederationBlacklist: blacklist,
|
FederationBlacklist: blacklist,
|
||||||
|
FederationWhitelist: whitelist,
|
||||||
FederationAssumedOffline: assumedOffline,
|
FederationAssumedOffline: assumedOffline,
|
||||||
FederationRelayServers: relayServers,
|
FederationRelayServers: relayServers,
|
||||||
FederationOutboundPeeks: outboundPeeks,
|
FederationOutboundPeeks: outboundPeeks,
|
||||||
|
|
94
federationapi/storage/sqlite3/whitelist_table.go
Normal file
94
federationapi/storage/sqlite3/whitelist_table.go
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
const whitelistSchema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS federationsender_whitelist (
|
||||||
|
-- The whitelisted server name
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
|
UNIQUE (server_name)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertWhitelistSQL = "" +
|
||||||
|
"INSERT INTO federationsender_whitelist (server_name) VALUES ($1)" +
|
||||||
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
|
const selectWhitelistSQL = "" +
|
||||||
|
"SELECT server_name FROM federationsender_whitelist WHERE server_name = $1"
|
||||||
|
|
||||||
|
const deleteWhitelistSQL = "" +
|
||||||
|
"DELETE FROM federationsender_whitelist WHERE server_name = $1"
|
||||||
|
|
||||||
|
const deleteAllWhitelistSQL = "" +
|
||||||
|
"DELETE FROM federationsender_whitelist"
|
||||||
|
|
||||||
|
type whitelistStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
insertWhitelistStmt *sql.Stmt
|
||||||
|
selectWhitelistStmt *sql.Stmt
|
||||||
|
deleteWhitelistStmt *sql.Stmt
|
||||||
|
deleteAllWhitelistStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSQLiteWhitelistTable(db *sql.DB) (s *whitelistStatements, err error) {
|
||||||
|
s = &whitelistStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err = db.Exec(whitelistSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.insertWhitelistStmt, insertWhitelistSQL},
|
||||||
|
{&s.selectWhitelistStmt, selectWhitelistSQL},
|
||||||
|
{&s.deleteWhitelistStmt, deleteWhitelistSQL},
|
||||||
|
{&s.deleteAllWhitelistStmt, deleteAllWhitelistSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *whitelistStatements) InsertWhitelist(
|
||||||
|
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertWhitelistStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, serverName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *whitelistStatements) SelectWhitelist(
|
||||||
|
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||||
|
) (bool, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectWhitelistStmt)
|
||||||
|
res, err := stmt.QueryContext(ctx, serverName)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer res.Close() // nolint:errcheck
|
||||||
|
// The query will return the server name if the server is whitelisted, and
|
||||||
|
// will return no rows if not. By calling Next, we find out if a row was
|
||||||
|
// returned or not - we don't care about the value itself.
|
||||||
|
return res.Next(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *whitelistStatements) DeleteWhitelist(
|
||||||
|
ctx context.Context, txn *sql.Tx, serverName spec.ServerName,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteWhitelistStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, serverName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *whitelistStatements) DeleteAllWhitelist(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteAllWhitelistStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -72,6 +72,13 @@ type FederationBlacklist interface {
|
||||||
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
|
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FederationWhitelist interface {
|
||||||
|
InsertWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error
|
||||||
|
SelectWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error)
|
||||||
|
DeleteWhitelist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error
|
||||||
|
DeleteAllWhitelist(ctx context.Context, txn *sql.Tx) error
|
||||||
|
}
|
||||||
|
|
||||||
type FederationAssumedOffline interface {
|
type FederationAssumedOffline interface {
|
||||||
InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error
|
InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error
|
||||||
SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error)
|
SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error)
|
||||||
|
|
|
@ -46,6 +46,12 @@ type FederationAPI struct {
|
||||||
|
|
||||||
// Should we prefer direct key fetches over perspective ones?
|
// Should we prefer direct key fetches over perspective ones?
|
||||||
PreferDirectFetch bool `yaml:"prefer_direct_fetch"`
|
PreferDirectFetch bool `yaml:"prefer_direct_fetch"`
|
||||||
|
|
||||||
|
// Enable servers whitelist function
|
||||||
|
EnableWhitelist bool `yaml:"enable_whitelist"`
|
||||||
|
|
||||||
|
// The list of whitelisted servers
|
||||||
|
WhitelistedServers []spec.ServerName `yaml:"whitelisted_servers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *FederationAPI) Defaults(opts DefaultOpts) {
|
func (c *FederationAPI) Defaults(opts DefaultOpts) {
|
||||||
|
@ -73,6 +79,8 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) {
|
||||||
c.Database.ConnectionString = "file:federationapi.db"
|
c.Database.ConnectionString = "file:federationapi.db"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
c.EnableWhitelist = false
|
||||||
|
c.WhitelistedServers = make([]spec.ServerName, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *FederationAPI) Verify(configErrs *ConfigErrors) {
|
func (c *FederationAPI) Verify(configErrs *ConfigErrors) {
|
||||||
|
|
|
@ -96,6 +96,8 @@ client_api:
|
||||||
federation_api:
|
federation_api:
|
||||||
database:
|
database:
|
||||||
connection_string: file:federationapi.db
|
connection_string: file:federationapi.db
|
||||||
|
enable_whitelist: true
|
||||||
|
whitelisted_servers: ["https://matrix.org"]
|
||||||
key_server:
|
key_server:
|
||||||
database:
|
database:
|
||||||
connection_string: file:keyserver.db
|
connection_string: file:keyserver.db
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue