Copy and simplify GetBulkStateContent for ACLs

This commit is contained in:
Till Faelligen 2024-12-19 20:06:39 +01:00
parent e0b153930c
commit e2fd591d9f
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
4 changed files with 65 additions and 9 deletions

View file

@ -17,7 +17,6 @@ import (
"time" "time"
"github.com/element-hq/dendrite/roomserver/storage/tables" "github.com/element-hq/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -28,9 +27,8 @@ type ServerACLDatabase interface {
// RoomsWithACLs returns all room IDs for rooms with ACLs // RoomsWithACLs returns all room IDs for rooms with ACLs
RoomsWithACLs(ctx context.Context) ([]string, error) RoomsWithACLs(ctx context.Context) ([]string, error)
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // GetBulkStateACLs returns all server ACLs for the given rooms.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error)
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
} }
type ServerACLs struct { type ServerACLs struct {
@ -59,7 +57,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
aclRegexCache: make(map[string]**regexp.Regexp, 100), aclRegexCache: make(map[string]**regexp.Regexp, 100),
} }
// Look up all of the rooms that the current state server knows about. // Look up all rooms with ACLs.
rooms, err := db.RoomsWithACLs(ctx) rooms, err := db.RoomsWithACLs(ctx)
if err != nil { if err != nil {
logrus.WithError(err).Fatalf("Failed to get known rooms") logrus.WithError(err).Fatalf("Failed to get known rooms")
@ -68,7 +66,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
// do then we'll process it into memory so that we have the regexes to // do then we'll process it into memory so that we have the regexes to
// hand. // hand.
events, err := db.GetBulkStateContent(ctx, rooms, []gomatrixserverlib.StateKeyTuple{{EventType: MRoomServerACL, StateKey: ""}}, false) events, err := db.GetBulkStateACLs(ctx, rooms)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to get server ACLs for all rooms: %q", err) logrus.WithError(err).Errorf("Failed to get server ACLs for all rooms: %q", err)
} }

View file

@ -12,7 +12,6 @@ import (
"testing" "testing"
"github.com/element-hq/dendrite/roomserver/storage/tables" "github.com/element-hq/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -108,11 +107,11 @@ var (
type dummyACLDB struct{} type dummyACLDB struct{}
func (d dummyACLDB) RoomsWithACLs(ctx context.Context) ([]string, error) { func (d dummyACLDB) RoomsWithACLs(_ context.Context) ([]string, error) {
return []string{"1", "2"}, nil return []string{"1", "2"}, nil
} }
func (d dummyACLDB) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { func (d dummyACLDB) GetBulkStateACLs(_ context.Context, _ []string) ([]tables.StrippedEvent, error) {
return []tables.StrippedEvent{ return []tables.StrippedEvent{
{ {
RoomID: "1", RoomID: "1",

View file

@ -187,6 +187,8 @@ type Database interface {
// RoomsWithACLs returns all room IDs for rooms with ACLs // RoomsWithACLs returns all room IDs for rooms with ACLs
RoomsWithACLs(ctx context.Context) ([]string, error) RoomsWithACLs(ctx context.Context) ([]string, error)
// GetBulkStateACLs returns all server ACLs for the given rooms.
GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error)
QueryAdminEventReports(ctx context.Context, from uint64, limit uint64, backwards bool, userID string, roomID string) ([]api.QueryAdminEventReportsResponse, int64, error) QueryAdminEventReports(ctx context.Context, from uint64, limit uint64, backwards bool, userID string, roomID string) ([]api.QueryAdminEventReportsResponse, int64, error)
QueryAdminEventReport(ctx context.Context, reportID uint64) (api.QueryAdminEventReportResponse, error) QueryAdminEventReport(ctx context.Context, reportID uint64) (api.QueryAdminEventReportResponse, error)
AdminDeleteEventReport(ctx context.Context, reportID uint64) error AdminDeleteEventReport(ctx context.Context, reportID uint64) error

View file

@ -1437,6 +1437,63 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID spec.UserID,
return roomIDs, nil return roomIDs, nil
} }
// GetBulkStateACLs is a lighter weight form of GetBulkStateContent, which only returns ACL state events.
func (d *Database) GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error) {
tuples := []gomatrixserverlib.StateKeyTuple{{EventType: "m.room.server_acl", StateKey: ""}}
var eventNIDs []types.EventNID
eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion)
// TODO: This feels like this is going to be really slow...
for _, roomID := range roomIDs {
roomInfo, err2 := d.roomInfo(ctx, nil, roomID)
if err2 != nil {
return nil, fmt.Errorf("GetBulkStateACLs: failed to load room info for room %s : %w", roomID, err2)
}
// for unknown rooms or rooms which we don't have the current state, skip them.
if roomInfo == nil || roomInfo.IsStub() {
continue
}
// No querier needed, as we don't actually do state resolution
stateRes := state.NewStateResolution(d, roomInfo, nil)
entries, err2 := stateRes.LoadStateAtSnapshotForStringTuples(ctx, roomInfo.StateSnapshotNID(), tuples)
if err2 != nil {
return nil, fmt.Errorf("GetBulkStateACLs: failed to load state for room %s : %w", roomID, err2)
}
for _, entry := range entries {
eventNIDs = append(eventNIDs, entry.EventNID)
eventNIDToVer[entry.EventNID] = roomInfo.RoomVersion
}
}
eventIDs, err := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil {
return nil, fmt.Errorf("GetBulkStateACLs: failed to load event JSON for event nids: %w", err)
}
result := make([]tables.StrippedEvent, len(events))
for i := range events {
roomVer := eventNIDToVer[events[i].EventNID]
verImpl, err := gomatrixserverlib.GetRoomVersion(roomVer)
if err != nil {
return nil, err
}
ev, err := verImpl.NewEventFromTrustedJSONWithEventID(eventIDs[events[i].EventNID], events[i].EventJSON, false)
if err != nil {
return nil, fmt.Errorf("GetBulkStateACLs: failed to load event JSON for event NID %v : %w", events[i].EventNID, err)
}
result[i] = tables.StrippedEvent{
EventType: ev.Type(),
RoomID: ev.RoomID().String(),
StateKey: *ev.StateKey(),
ContentValue: tables.ExtractContentValue(&types.HeaderedEvent{PDU: ev}),
}
}
return result, nil
}
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) {