Fix federationclient whitelist checks, improve performance

This commit is contained in:
enaix 2025-01-17 14:55:11 +03:00
parent 94deed77ec
commit b5ea31ff5c
2 changed files with 32 additions and 25 deletions

View file

@ -112,10 +112,11 @@ func NewFederationInternalAPI(
} }
} }
// IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled, so we can connect to any server // IsWhitelistedOrAny checks if the server is whitelisted or the whitelist is disabled (we can connect to any server)
func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool { func (a *FederationInternalAPI) IsWhitelistedOrAny(s spec.ServerName) bool {
stats := a.statistics.ForServer(s) // Thread-safe, since DB access is performed in mutex and stats.Whitelisted is constant
return stats.Whitelisted() || !a.cfg.EnableWhitelist stats := a.statistics.ForServer(s) // Calls mutex if the stats do not exist yet
return !a.cfg.EnableWhitelist || stats.Whitelisted() // Lazy eval
} }
func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) { func (a *FederationInternalAPI) IsBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) {

View file

@ -17,6 +17,9 @@ const defaultTimeout = time.Second * 30
func (a *FederationInternalAPI) MakeJoin( func (a *FederationInternalAPI) MakeJoin(
ctx context.Context, origin, s spec.ServerName, roomID, userID string, ctx context.Context, origin, s spec.ServerName, roomID, userID string,
) (res gomatrixserverlib.MakeJoinResponse, err error) { ) (res gomatrixserverlib.MakeJoinResponse, err error) {
if !a.IsWhitelistedOrAny(s) {
return &fclient.RespMakeJoin{}, nil
} // Is thread-safe, so we can omit ctx call
ctx, cancel := context.WithTimeout(ctx, defaultTimeout) ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel() defer cancel()
ires, err := a.federation.MakeJoin(ctx, origin, s, roomID, userID) ires, err := a.federation.MakeJoin(ctx, origin, s, roomID, userID)
@ -29,6 +32,9 @@ func (a *FederationInternalAPI) MakeJoin(
func (a *FederationInternalAPI) SendJoin( func (a *FederationInternalAPI) SendJoin(
ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU,
) (res gomatrixserverlib.SendJoinResponse, err error) { ) (res gomatrixserverlib.SendJoinResponse, err error) {
if !a.IsWhitelistedOrAny(s) {
return &fclient.RespSendJoin{}, nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute*5) ctx, cancel := context.WithTimeout(ctx, time.Minute*5)
defer cancel() defer cancel()
ires, err := a.federation.SendJoin(ctx, origin, s, event) ires, err := a.federation.SendJoin(ctx, origin, s, event)
@ -42,11 +48,11 @@ func (a *FederationInternalAPI) GetEventAuth(
ctx context.Context, origin, s spec.ServerName, ctx context.Context, origin, s spec.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (res fclient.RespEventAuth, err error) { ) (res fclient.RespEventAuth, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return fclient.RespEventAuth{}, nil return fclient.RespEventAuth{}, nil
} }
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
}) })
@ -59,11 +65,11 @@ func (a *FederationInternalAPI) GetEventAuth(
func (a *FederationInternalAPI) GetUserDevices( func (a *FederationInternalAPI) GetUserDevices(
ctx context.Context, origin, s spec.ServerName, userID string, ctx context.Context, origin, s spec.ServerName, userID string,
) (fclient.RespUserDevices, error) { ) (fclient.RespUserDevices, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return fclient.RespUserDevices{}, nil return fclient.RespUserDevices{}, nil
} }
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetUserDevices(ctx, origin, s, userID) return a.federation.GetUserDevices(ctx, origin, s, userID)
}) })
@ -76,11 +82,11 @@ func (a *FederationInternalAPI) GetUserDevices(
func (a *FederationInternalAPI) ClaimKeys( func (a *FederationInternalAPI) ClaimKeys(
ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string,
) (fclient.RespClaimKeys, error) { ) (fclient.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return fclient.RespClaimKeys{}, nil return fclient.RespClaimKeys{}, nil
} }
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
}) })
@ -108,11 +114,11 @@ func (a *FederationInternalAPI) QueryKeys(
func (a *FederationInternalAPI) Backfill( func (a *FederationInternalAPI) Backfill(
ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return gomatrixserverlib.Transaction{}, nil return gomatrixserverlib.Transaction{}, nil
} }
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
}) })
@ -125,11 +131,11 @@ func (a *FederationInternalAPI) Backfill(
func (a *FederationInternalAPI) LookupState( func (a *FederationInternalAPI) LookupState(
ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.StateResponse, err error) { ) (res gomatrixserverlib.StateResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return &fclient.RespState{}, nil return &fclient.RespState{}, nil
} }
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
}) })
@ -143,11 +149,11 @@ func (a *FederationInternalAPI) LookupState(
func (a *FederationInternalAPI) LookupStateIDs( func (a *FederationInternalAPI) LookupStateIDs(
ctx context.Context, origin, s spec.ServerName, roomID, eventID string, ctx context.Context, origin, s spec.ServerName, roomID, eventID string,
) (res gomatrixserverlib.StateIDResponse, err error) { ) (res gomatrixserverlib.StateIDResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return fclient.RespStateIDs{}, nil return fclient.RespStateIDs{}, nil
} }
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
}) })
@ -161,11 +167,11 @@ func (a *FederationInternalAPI) LookupMissingEvents(
ctx context.Context, origin, s spec.ServerName, roomID string, ctx context.Context, origin, s spec.ServerName, roomID string,
missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res fclient.RespMissingEvents, err error) { ) (res fclient.RespMissingEvents, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return fclient.RespMissingEvents{}, nil return fclient.RespMissingEvents{}, nil
} }
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
}) })
@ -178,11 +184,11 @@ func (a *FederationInternalAPI) LookupMissingEvents(
func (a *FederationInternalAPI) GetEvent( func (a *FederationInternalAPI) GetEvent(
ctx context.Context, origin, s spec.ServerName, eventID string, ctx context.Context, origin, s spec.ServerName, eventID string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return gomatrixserverlib.Transaction{}, nil return gomatrixserverlib.Transaction{}, nil
} }
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEvent(ctx, origin, s, eventID) return a.federation.GetEvent(ctx, origin, s, eventID)
}) })
@ -195,11 +201,11 @@ func (a *FederationInternalAPI) GetEvent(
func (a *FederationInternalAPI) LookupServerKeys( func (a *FederationInternalAPI) LookupServerKeys(
ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) { ) ([]gomatrixserverlib.ServerKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return []gomatrixserverlib.ServerKeys{}, nil return []gomatrixserverlib.ServerKeys{}, nil
} }
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupServerKeys(ctx, s, keyRequests) return a.federation.LookupServerKeys(ctx, s, keyRequests)
}) })
@ -213,11 +219,11 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, origin, s spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, ctx context.Context, origin, s spec.ServerName, r fclient.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) (res fclient.MSC2836EventRelationshipsResponse, err error) { ) (res fclient.MSC2836EventRelationshipsResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return res, nil return res, nil
} }
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion) return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion)
}) })
@ -230,11 +236,11 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
func (a *FederationInternalAPI) RoomHierarchies( func (a *FederationInternalAPI) RoomHierarchies(
ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool, ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool,
) (res fclient.RoomHierarchyResponse, err error) { ) (res fclient.RoomHierarchyResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
if !a.IsWhitelistedOrAny(s) { if !a.IsWhitelistedOrAny(s) {
return res, nil return res, nil
} }
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly) return a.federation.RoomHierarchy(ctx, origin, s, roomID, suggestedOnly)
}) })