diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index caac076f5..0d332c73a 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -152,5 +152,5 @@ type Database interface { // StoreReceipt stores new receipt events StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) // GetRoomReceipts gets all receipts for a given roomID - GetRoomReceipts(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) + GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) } diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index 0ceafc818..5b98c7251 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "fmt" + "github.com/lib/pq" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -53,7 +54,7 @@ const upsertReceipt = "" + const selectRoomReceipts = "" + "SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + " FROM syncapi_receipts" + - " WHERE room_id = $1 AND id > $2" + " WHERE room_id in $1 AND id > $2" type receiptStatements struct { db *sql.DB @@ -84,8 +85,8 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room return } -func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { - rows, err := r.selectRoomReceipts.QueryContext(ctx, roomId, streamPos) +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { + rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) if err != nil { return nil, fmt.Errorf("unable to query room receipts: %w", err) } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index e994b0d4e..899a81655 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -574,20 +574,26 @@ func (d *Database) addReceiptDeltaToResponse( joinedRoomIDs []string, res *types.Response, ) error { - // TODO: pass joinedRoomIDs to SelectRoomReceiptsAfter instead of iterating over every room - // check all joinedRooms for receipts - for _, roomID := range joinedRoomIDs { + receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.EDUPosition()) + if err != nil { + return fmt.Errorf("unable to select receipts for rooms: %w", err) + } + + // Group receipts by room, so we can create one ClientEvent for every room + receiptsByRoom := make(map[string][]eduAPI.OutputReceiptEvent) + for _, receipt := range receipts { + receiptsByRoom[receipt.RoomID] = append(receiptsByRoom[receipt.RoomID], receipt) + } + + for roomID, receipts := range receiptsByRoom { var jr types.JoinResponse var ok bool - // Check if there's already a JoinResponse for this room, - // otherwise use a new one + + // Make sure we use an existing JoinResponse if there is one. + // If not, we'll create a new one if jr, ok = res.Rooms.Join[roomID]; !ok { jr = types.JoinResponse{} } - receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), roomID, since.EDUPosition()) - if err != nil { - return err - } ev := gomatrixserverlib.ClientEvent{ Type: "m.receipt", @@ -604,19 +610,14 @@ func (d *Database) addReceiptDeltaToResponse( }, }, } + } + ev.Content, err = json.Marshal(content) + if err != nil { + return err + } - ev.Content, err = json.Marshal(content) - if err != nil { - return err - } - jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) - res.Rooms.Join[roomID] = jr - } - // Only add new events if we didn't find the room in the map. - // If we found the room, they should already be added - if !ok { - res.Rooms.Join[roomID] = jr - } + jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) + res.Rooms.Join[roomID] = jr } return nil @@ -1487,6 +1488,6 @@ func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId return } -func (d *Database) GetRoomReceipts(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) { - return d.Receipts.SelectRoomReceiptsAfter(ctx, roomId, streamPos) +func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) { + return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) } diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index 961af79fd..b1770e801 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "strings" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/internal" @@ -52,7 +53,7 @@ const upsertReceipt = "" + const selectRoomReceipts = "" + "SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + " FROM syncapi_receipts" + - " WHERE room_id = $1 AND id > $2" + " WHERE id > $1 and room_id in ($2)" type receiptStatements struct { db *sql.DB @@ -91,8 +92,15 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room } // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp -func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { - rows, err := r.selectRoomReceipts.QueryContext(ctx, roomId, streamPos) +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { + selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) + + params := make([]interface{}, len(roomIDs)+1) + params[0] = streamPos + for k, v := range roomIDs { + params[k+1] = v + } + rows, err := r.db.QueryContext(ctx, selectSQL, params...) if err != nil { return nil, fmt.Errorf("unable to query room receipts: %w", err) } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index b57eef6df..f8e7a224a 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -160,5 +160,5 @@ type Filter interface { type Receipts interface { UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) - SelectRoomReceiptsAfter(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) + SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) }