Consistent *sql.Tx usage across sync API (#2744)

This tidies up the `storage` package so that everything takes a
transaction parameter instead of something things that do and some that
don't.
This commit is contained in:
Neil Alexander 2022-09-28 10:18:03 +01:00 committed by GitHub
parent a574ed5369
commit 3f9e38e80a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 99 additions and 77 deletions

View file

@ -99,14 +99,15 @@ func (s *accountDataStatements) InsertAccountData(
}
func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
userID string,
r types.Range,
accountDataEventFilter *gomatrixserverlib.EventFilter,
) (data map[string][]string, pos types.StreamPosition, err error) {
data = make(map[string][]string)
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(),
rows, err := sqlutil.TxStmt(txn, s.selectAccountDataInRangeStmt).QueryContext(
ctx, userID, r.Low(), r.High(),
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)),
accountDataEventFilter.Limit,

View file

@ -79,9 +79,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
}
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
ctx context.Context, roomID string,
ctx context.Context, txn *sql.Tx, roomID string,
) (bwExtrems map[string][]string, err error) {
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID)
if err != nil {
return
}

View file

@ -185,9 +185,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
func (s *currentRoomStateStatements) SelectJoinedUsers(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
) (map[string][]string, error) {
rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx)
if err != nil {
return nil, err
}
@ -209,9 +209,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
ctx context.Context, roomIDs []string,
ctx context.Context, txn *sql.Tx, roomIDs []string,
) (map[string][]string, error) {
rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs))
rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersInRoomStmt).QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil {
return nil, err
}
@ -387,9 +387,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
}
func (s *currentRoomStateStatements) SelectStateEvent(
ctx context.Context, roomID, evType, stateKey string,
ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) {
stmt := s.selectStateEventStmt
stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt)
var res []byte
err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
if err == sql.ErrNoRows {

View file

@ -19,6 +19,7 @@ import (
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@ -73,11 +74,11 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) {
}
func (s *filterStatements) SelectFilter(
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string,
) error {
// Retrieve filter from database (stored as canonical JSON)
var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil {
return err
}
@ -90,7 +91,7 @@ func (s *filterStatements) SelectFilter(
}
func (s *filterStatements) InsertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) {
var existingFilterID string
@ -111,8 +112,9 @@ func (s *filterStatements) InsertFilter(
// This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time, however this is not a
// problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID)
err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext(
ctx, localpart, filterJSON,
).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows {
return "", err
}
@ -122,7 +124,7 @@ func (s *filterStatements) InsertFilter(
}
// Otherwise insert the filter and return the new ID
err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart).
err = sqlutil.TxStmt(txn, s.insertFilterStmt).QueryRowContext(ctx, filterJSON, localpart).
Scan(&filterID)
return
}

View file

@ -99,7 +99,7 @@ func (s *inviteEventsStatements) InsertInviteEvent(
return
}
err = s.insertInviteEventStmt.QueryRowContext(
err = sqlutil.TxStmt(txn, s.insertInviteEventStmt).QueryRowContext(
ctx,
inviteEvent.RoomID(),
inviteEvent.EventID(),

View file

@ -222,12 +222,12 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
}.Prepare(db)
}
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error {
headeredJSON, err := json.Marshal(event)
if err != nil {
return err
}
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
_, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID())
return err
}

View file

@ -173,7 +173,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology(
ctx context.Context, txn *sql.Tx, eventID string,
) (pos, spos types.StreamPosition, err error) {
err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos)
err = sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt).QueryRowContext(ctx, eventID).Scan(&pos, &spos)
return
}
@ -183,9 +183,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (topoPos types.StreamPosition, err error) {
if backwardOrdering {
err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
} else {
err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
}
return
}
@ -193,6 +193,6 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
ctx context.Context, txn *sql.Tx, roomID string,
) (pos types.StreamPosition, spos types.StreamPosition, err error) {
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return
}

View file

@ -152,9 +152,9 @@ func (s *peekStatements) SelectPeeksInRange(
}
func (s *peekStatements) SelectPeekingDevices(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
) (peekingDevices map[string][]types.PeekingDevice, err error) {
rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx)
rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx)
if err != nil {
return nil, err
}

View file

@ -104,9 +104,9 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
return
}
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
var lastPos types.StreamPosition
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
rows, err := sqlutil.TxStmt(txn, r.selectRoomReceipts).QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil {
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
}