Query all joined rooms instead of just one

This commit is contained in:
Till Faelligen 2020-10-17 13:47:42 +02:00
parent 9d480d58c9
commit 2f1d2f53b7
5 changed files with 41 additions and 31 deletions

View file

@ -152,5 +152,5 @@ type Database interface {
// StoreReceipt stores new receipt events
StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
// GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
}

View file

@ -19,6 +19,7 @@ import (
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
@ -53,7 +54,7 @@ const upsertReceipt = "" +
const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" +
" WHERE room_id = $1 AND id > $2"
" WHERE room_id in $1 AND id > $2"
type receiptStatements struct {
db *sql.DB
@ -84,8 +85,8 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
return
}
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) {
rows, err := r.selectRoomReceipts.QueryContext(ctx, roomId, streamPos)
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) {
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err)
}

View file

@ -574,20 +574,26 @@ func (d *Database) addReceiptDeltaToResponse(
joinedRoomIDs []string,
res *types.Response,
) error {
// TODO: pass joinedRoomIDs to SelectRoomReceiptsAfter instead of iterating over every room
// check all joinedRooms for receipts
for _, roomID := range joinedRoomIDs {
receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.EDUPosition())
if err != nil {
return fmt.Errorf("unable to select receipts for rooms: %w", err)
}
// Group receipts by room, so we can create one ClientEvent for every room
receiptsByRoom := make(map[string][]eduAPI.OutputReceiptEvent)
for _, receipt := range receipts {
receiptsByRoom[receipt.RoomID] = append(receiptsByRoom[receipt.RoomID], receipt)
}
for roomID, receipts := range receiptsByRoom {
var jr types.JoinResponse
var ok bool
// Check if there's already a JoinResponse for this room,
// otherwise use a new one
// Make sure we use an existing JoinResponse if there is one.
// If not, we'll create a new one
if jr, ok = res.Rooms.Join[roomID]; !ok {
jr = types.JoinResponse{}
}
receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), roomID, since.EDUPosition())
if err != nil {
return err
}
ev := gomatrixserverlib.ClientEvent{
Type: "m.receipt",
@ -604,19 +610,14 @@ func (d *Database) addReceiptDeltaToResponse(
},
},
}
}
ev.Content, err = json.Marshal(content)
if err != nil {
return err
}
ev.Content, err = json.Marshal(content)
if err != nil {
return err
}
jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev)
res.Rooms.Join[roomID] = jr
}
// Only add new events if we didn't find the room in the map.
// If we found the room, they should already be added
if !ok {
res.Rooms.Join[roomID] = jr
}
jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev)
res.Rooms.Join[roomID] = jr
}
return nil
@ -1487,6 +1488,6 @@ func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId
return
}
func (d *Database) GetRoomReceipts(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) {
return d.Receipts.SelectRoomReceiptsAfter(ctx, roomId, streamPos)
func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) {
return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos)
}

View file

@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal"
@ -52,7 +53,7 @@ const upsertReceipt = "" +
const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" +
" WHERE room_id = $1 AND id > $2"
" WHERE id > $1 and room_id in ($2)"
type receiptStatements struct {
db *sql.DB
@ -91,8 +92,15 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
}
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) {
rows, err := r.selectRoomReceipts.QueryContext(ctx, roomId, streamPos)
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) {
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
params := make([]interface{}, len(roomIDs)+1)
params[0] = streamPos
for k, v := range roomIDs {
params[k+1] = v
}
rows, err := r.db.QueryContext(ctx, selectSQL, params...)
if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err)
}

View file

@ -160,5 +160,5 @@ type Filter interface {
type Receipts interface {
UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
SelectRoomReceiptsAfter(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
}