Expand RoomInfo to cover more DB storage functions (#1377)

* Factor more things to RoomInfo

* Factor out remaining bits for RoomInfo

* Linting for now
This commit is contained in:
Kegsay 2020-09-02 10:02:48 +01:00 committed by GitHub
parent 82a9617659
commit 02a73f29f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 148 additions and 179 deletions

View file

@ -38,27 +38,22 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState(
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID)
roomInfo, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return err
}
if roomInfo == nil || roomInfo.IsStub {
response.RoomExists = false
return nil
}
roomState := state.NewStateResolution(r.DB)
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return err
}
if info.IsStub {
return nil
}
roomState := state.NewStateResolution(r.DB, *roomInfo)
response.RoomExists = true
response.RoomVersion = roomVersion
response.RoomVersion = roomInfo.RoomVersion
var currentStateSnapshotNID types.StateSnapshotNID
response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
r.DB.LatestEventIDs(ctx, info.RoomNID)
r.DB.LatestEventIDs(ctx, roomInfo.RoomNID)
if err != nil {
return err
}
@ -85,7 +80,7 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState(
}
for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion))
response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion))
}
return nil
@ -97,23 +92,17 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents(
request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse,
) error {
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID)
if err != nil {
response.RoomExists = false
return nil
}
roomState := state.NewStateResolution(r.DB)
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return err
}
if info.IsStub {
if info == nil || info.IsStub {
return nil
}
roomState := state.NewStateResolution(r.DB, *info)
response.RoomExists = true
response.RoomVersion = roomVersion
response.RoomVersion = info.RoomVersion
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
if err != nil {
@ -128,7 +117,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents(
// Look up the currrent state for the requested tuples.
stateEntries, err := roomState.LoadStateAfterEventsForStringTuples(
ctx, info.RoomNID, prevStates, request.StateToFetch,
ctx, prevStates, request.StateToFetch,
)
if err != nil {
return err
@ -140,7 +129,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents(
}
for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion))
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion))
}
return nil
@ -168,7 +157,7 @@ func (r *RoomserverInternalAPI) QueryEventsByID(
}
for _, event := range events {
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID())
roomVersion, verr := r.roomVersion(event.RoomID())
if verr != nil {
return verr
}
@ -277,7 +266,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
events, err = r.DB.Events(ctx, eventNIDs)
} else {
stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID)
stateEntries, err = stateBeforeEvent(ctx, r.DB, *info, membershipEventNID)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
@ -297,8 +286,8 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
return nil
}
func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db)
func stateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info)
// Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil {
@ -370,20 +359,28 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see
return
}
isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, events[0].RoomID())
roomID := events[0].RoomID()
isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, roomID)
if err != nil {
return
}
info, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
return err
}
if info == nil {
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
}
response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent(
ctx, request.EventID, request.ServerName, isServerInRoom,
ctx, *info, request.EventID, request.ServerName, isServerInRoom,
)
return
}
func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent(
ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
ctx context.Context, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
) (bool, error) {
roomState := state.NewStateResolution(r.DB)
roomState := state.NewStateResolution(r.DB, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil {
return false, err
@ -400,6 +397,7 @@ func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent(
}
// QueryMissingEvents implements api.RoomserverInternalAPI
// nolint:gocyclo
func (r *RoomserverInternalAPI) QueryMissingEvents(
ctx context.Context,
request *api.QueryMissingEventsRequest,
@ -418,8 +416,22 @@ func (r *RoomserverInternalAPI) QueryMissingEvents(
eventsToFilter[id] = true
}
}
events, err := r.DB.EventsFromIDs(ctx, front)
if err != nil {
return err
}
if len(events) == 0 {
return nil // we are missing the events being asked to search from, give up.
}
info, err := r.DB.RoomInfo(ctx, events[0].RoomID())
if err != nil {
return err
}
if info == nil || info.IsStub {
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
}
resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName)
resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName)
if err != nil {
return err
}
@ -432,7 +444,7 @@ func (r *RoomserverInternalAPI) QueryMissingEvents(
response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter))
for _, event := range loadedEvents {
if !eventsToFilter[event.EventID()] {
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID())
roomVersion, verr := r.roomVersion(event.RoomID())
if verr != nil {
return verr
}
@ -467,8 +479,16 @@ func (r *RoomserverInternalAPI) PerformBackfill(
// this will include these events which is what we want
front = request.PrevEventIDs()
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return err
}
if info == nil || info.IsStub {
return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID)
}
// Scan the event tree for events to send back.
resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName)
resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName)
if err != nil {
return err
}
@ -481,19 +501,14 @@ func (r *RoomserverInternalAPI) PerformBackfill(
}
for _, event := range loadedEvents {
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID())
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
response.Events = append(response.Events, event.Headered(info.RoomVersion))
}
return err
}
func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error {
roomVer, err := r.DB.GetRoomVersionForRoom(ctx, req.RoomID)
roomVer, err := r.roomVersion(req.RoomID)
if err != nil {
return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err)
}
@ -642,7 +657,7 @@ func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context,
// TODO: Remove this when we have tests to assert correctness of this function
// nolint:gocyclo
func (r *RoomserverInternalAPI) scanEventTree(
ctx context.Context, front []string, visited map[string]bool, limit int,
ctx context.Context, info types.RoomInfo, front []string, visited map[string]bool, limit int,
serverName gomatrixserverlib.ServerName,
) ([]types.EventNID, error) {
var resultNIDs []types.EventNID
@ -708,7 +723,7 @@ BFSLoop:
// hasn't been seen before.
if !visited[pre] {
visited[pre] = true
allowed, err = r.checkServerAllowedToSeeEvent(ctx, pre, serverName, isServerInRoom)
allowed, err = r.checkServerAllowedToSeeEvent(ctx, info, pre, serverName, isServerInRoom)
if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event",
@ -744,13 +759,13 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain(
if err != nil {
return err
}
if info.IsStub {
if info == nil || info.IsStub {
return nil
}
response.RoomExists = true
response.RoomVersion = info.RoomVersion
stateEvents, err := r.loadStateAtEventIDs(ctx, request.PrevEventIDs)
stateEvents, err := r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs)
if err != nil {
return err
}
@ -788,8 +803,8 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain(
return err
}
func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
roomState := state.NewStateResolution(r.DB)
func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.Event, error) {
roomState := state.NewStateResolution(r.DB, roomInfo)
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil {
switch err.(type) {
@ -937,15 +952,26 @@ func (r *RoomserverInternalAPI) QueryRoomVersionForRoom(
return nil
}
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID)
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return err
}
response.RoomVersion = roomVersion
if info == nil {
return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID)
}
response.RoomVersion = info.RoomVersion
r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion)
return nil
}
func (r *RoomserverInternalAPI) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) {
var res api.QueryRoomVersionForRoomResponse
err := r.QueryRoomVersionForRoom(context.Background(), &api.QueryRoomVersionForRoomRequest{
RoomID: roomID,
}, &res)
return res.RoomVersion, err
}
func (r *RoomserverInternalAPI) QueryPublishedRooms(
ctx context.Context,
req *api.QueryPublishedRoomsRequest,