diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index ff3ed7e5d..3953586b2 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -48,7 +48,7 @@ func (r *Inputer) updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := r.DB.Events(ctx, eventNIDs) + events, err := updater.Events(ctx, eventNIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index ece1d9e3c..c136f039a 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -74,7 +74,7 @@ const insertEventSQL = "" + "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" + - " SET is_rejected = $8 WHERE e.is_rejected = FALSE" + + " SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = FALSE" + " RETURNING event_nid, state_snapshot_nid" const selectEventSQL = "" + @@ -192,7 +192,8 @@ func (s *eventStatements) InsertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - err := s.insertEventStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertEventStmt) + err := stmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index f1f589a31..66ac2f5b6 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -136,7 +136,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -170,7 +170,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -196,7 +196,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er } if u.membership != tables.MembershipStateKnock { // Look up the NID of the new knock event - nIDs, err := u.d.EventNIDs(u.ctx, []string{event.EventID()}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) }