From b5ea31ff5c0666a94d3bb7dab9fc4acd640e1662 Mon Sep 17 00:00:00 2001 From: enaix Date: Fri, 17 Jan 2025 14:55:11 +0300 Subject: [PATCH] Fix federationclient whitelist checks, improve performance --- federationapi/internal/api.go | 7 +-- federationapi/internal/federationclient.go | 50 ++++++++++++---------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 4b165bd0..ba771b2d 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -112,10 +112,11 @@ func NewFederationInternalAPI( } } -// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled, so we can connect to any server +// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled (we can connect to any server) func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool { - stats := a.statistics.ForServer(s) - return stats.Whitelisted() || !a.cfg.EnableWhitelist + // Thread-safe, since DB access is performed in mutex and stats.Whitelisted is constant + stats := a.statistics.ForServer(s) // Calls mutex if the stats do not exist yet + return !a.cfg.EnableWhitelist || stats.Whitelisted() // Lazy eval } func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) { diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index 45272f8e..d3bc51fd 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -17,6 +17,9 @@ const defaultTimeout = time.Second * 30 func (a *FederationInternalAPI) MakeJoin( ctx context.Context, origin, s spec.ServerName, roomID, userID string, ) (res gomatrixserverlib.MakeJoinResponse, err error) { + if !a.IsWhitelistedOrAny(s) { + return &fclient.RespMakeJoin{}, nil + } // Is thread-safe, so we can omit ctx call ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.federation.MakeJoin(ctx, origin, s, roomID, userID) @@ -29,6 +32,9 @@ func (a *FederationInternalAPI) MakeJoin( func (a *FederationInternalAPI) SendJoin( ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, ) (res gomatrixserverlib.SendJoinResponse, err error) { + if !a.IsWhitelistedOrAny(s) { + return &fclient.RespSendJoin{}, nil + } ctx, cancel := context.WithTimeout(ctx, time.Minute*5) defer cancel() ires, err := a.federation.SendJoin(ctx, origin, s, event) @@ -42,11 +48,11 @@ func (a *FederationInternalAPI) GetEventAuth( ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (res fclient.RespEventAuth, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return fclient.RespEventAuth{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) }) @@ -59,11 +65,11 @@ func (a *FederationInternalAPI) GetEventAuth( func (a *FederationInternalAPI) GetUserDevices( ctx context.Context, origin, s spec.ServerName, userID string, ) (fclient.RespUserDevices, error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return fclient.RespUserDevices{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetUserDevices(ctx, origin, s, userID) }) @@ -76,11 +82,11 @@ func (a *FederationInternalAPI) GetUserDevices( func (a *FederationInternalAPI) ClaimKeys( ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string, ) (fclient.RespClaimKeys, error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return fclient.RespClaimKeys{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) }) @@ -108,11 +114,11 @@ func (a *FederationInternalAPI) QueryKeys( func (a *FederationInternalAPI) Backfill( ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string, ) (res gomatrixserverlib.Transaction, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return gomatrixserverlib.Transaction{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) }) @@ -125,11 +131,11 @@ func (a *FederationInternalAPI) Backfill( func (a *FederationInternalAPI) LookupState( ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.StateResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return &fclient.RespState{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) }) @@ -143,11 +149,11 @@ func (a *FederationInternalAPI) LookupState( func (a *FederationInternalAPI) LookupStateIDs( ctx context.Context, origin, s spec.ServerName, roomID, eventID string, ) (res gomatrixserverlib.StateIDResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return fclient.RespStateIDs{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) }) @@ -161,11 +167,11 @@ func (a *FederationInternalAPI) LookupMissingEvents( ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res fclient.RespMissingEvents, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return fclient.RespMissingEvents{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) }) @@ -178,11 +184,11 @@ func (a *FederationInternalAPI) LookupMissingEvents( func (a *FederationInternalAPI) GetEvent( ctx context.Context, origin, s spec.ServerName, eventID string, ) (res gomatrixserverlib.Transaction, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() if !a.IsWhitelistedOrAny(s) { return gomatrixserverlib.Transaction{}, nil } + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetEvent(ctx, origin, s, eventID) }) @@ -195,11 +201,11 @@ func (a *FederationInternalAPI) GetEvent( func (a *FederationInternalAPI) LookupServerKeys( ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) ([]gomatrixserverlib.ServerKeys, error) { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() if !a.IsWhitelistedOrAny(s) { return []gomatrixserverlib.ServerKeys{}, nil } + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupServerKeys(ctx, s, keyRequests) }) @@ -213,11 +219,11 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( ctx context.Context, origin, s spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res fclient.MSC2836EventRelationshipsResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() if !a.IsWhitelistedOrAny(s) { return res, nil } + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion) }) @@ -230,11 +236,11 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( func (a *FederationInternalAPI) RoomHierarchies( ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool, ) (res fclient.RoomHierarchyResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() if !a.IsWhitelistedOrAny(s) { return res, nil } + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly) })