diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 46cd206a..3f0b19c3 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -264,7 +264,11 @@ func (w *worker) _next() { }) msgs, err := w.subscription.Fetch(1, nats.Context(ctx)) switch err { - case nil: + case nil, nats.ErrTimeout, context.DeadlineExceeded, context.Canceled: + // Is the server shutting down? If so, stop processing. + if w.r.ProcessContext.Context().Err() != nil { + return + } // Make sure that once we're done here, we queue up another call // to _next in the inbox. defer w.Act(nil, w._next) @@ -275,11 +279,11 @@ func (w *worker) _next() { return } - case context.DeadlineExceeded, context.Canceled: - // The context exceeded, so we've been waiting for more than a - // minute for activity in this room. At this point we will shut - // down the subscriber to free up resources. It'll get started - // again if new activity happens. + case nats.ErrConsumerDeleted, nats.ErrConsumerNotFound: + // The consumer is gone, therefore it's reached the inactivity + // threshold. Clean up and stop processing at this point, if a + // new event comes in for this room then the ordered consumer + // over the entire stream will recreate this anyway. if err = w.subscription.Unsubscribe(); err != nil { logrus.WithError(err).Errorf("Failed to unsubscribe to stream for room %q", w.roomID) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index e5ee8f83..37de303b 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -13,7 +13,7 @@ import ( "errors" "fmt" - //"github.com/element-hq/dendrite/roomserver/internal" + // "github.com/element-hq/dendrite/roomserver/internal" "github.com/element-hq/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -747,7 +747,7 @@ func GetAuthChain( // from the database and the `eventsToFetch` will be updated with any new // events that we have learned about and need to find. When `eventsToFetch` // is eventually empty, we should have reached the end of the chain. - eventsToFetch := authEventIDs + eventsToFetch := append([]string{}, authEventIDs...) authEventsMap := make(map[string]gomatrixserverlib.PDU) for len(eventsToFetch) > 0 { @@ -779,7 +779,7 @@ func GetAuthChain( // We've now retrieved all of the events we can. Flatten them down into an // array and return them. - var authEvents []gomatrixserverlib.PDU + authEvents := make([]gomatrixserverlib.PDU, 0, len(authEventsMap)) for _, event := range authEventsMap { authEvents = append(authEvents, event) } diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index e4b558ff..c8feaa6a 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -92,6 +92,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( accountDataEventFilter *synctypes.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) + pos = r.Low() rows, err := sqlutil.TxStmt(txn, s.selectAccountDataInRangeStmt).QueryContext( ctx, userID, r.Low(), r.High(), @@ -122,7 +123,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( pos = id } } - if pos == 0 { + if len(data) == 0 { pos = r.High() } return data, pos, rows.Err() diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index b98cf227..a84deb87 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -84,6 +84,8 @@ func (s *accountDataStatements) SelectAccountDataInRange( filter *synctypes.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) + pos = r.Low() + stmt, params, err := prepareWithFilters( s.db, txn, selectAccountDataInRangeSQL, []interface{}{ @@ -119,7 +121,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( pos = id } } - if pos == 0 { + if len(data) == 0 { pos = r.High() } return data, pos, rows.Err()