Query all joined rooms instead of just one

This commit is contained in:
Till Faelligen 2020-10-17 13:47:42 +02:00
parent 9d480d58c9
commit 2f1d2f53b7
5 changed files with 41 additions and 31 deletions

View file

@ -152,5 +152,5 @@ type Database interface {
// StoreReceipt stores new receipt events // StoreReceipt stores new receipt events
StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
// GetRoomReceipts gets all receipts for a given roomID // GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
} }

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -53,7 +54,7 @@ const upsertReceipt = "" +
const selectRoomReceipts = "" + const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + "SELECT room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" + " FROM syncapi_receipts" +
" WHERE room_id = $1 AND id > $2" " WHERE room_id in $1 AND id > $2"
type receiptStatements struct { type receiptStatements struct {
db *sql.DB db *sql.DB
@ -84,8 +85,8 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
return return
} }
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) {
rows, err := r.selectRoomReceipts.QueryContext(ctx, roomId, streamPos) rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err) return nil, fmt.Errorf("unable to query room receipts: %w", err)
} }

View file

@ -574,20 +574,26 @@ func (d *Database) addReceiptDeltaToResponse(
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) error { ) error {
// TODO: pass joinedRoomIDs to SelectRoomReceiptsAfter instead of iterating over every room receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.EDUPosition())
// check all joinedRooms for receipts if err != nil {
for _, roomID := range joinedRoomIDs { return fmt.Errorf("unable to select receipts for rooms: %w", err)
}
// Group receipts by room, so we can create one ClientEvent for every room
receiptsByRoom := make(map[string][]eduAPI.OutputReceiptEvent)
for _, receipt := range receipts {
receiptsByRoom[receipt.RoomID] = append(receiptsByRoom[receipt.RoomID], receipt)
}
for roomID, receipts := range receiptsByRoom {
var jr types.JoinResponse var jr types.JoinResponse
var ok bool var ok bool
// Check if there's already a JoinResponse for this room,
// otherwise use a new one // Make sure we use an existing JoinResponse if there is one.
// If not, we'll create a new one
if jr, ok = res.Rooms.Join[roomID]; !ok { if jr, ok = res.Rooms.Join[roomID]; !ok {
jr = types.JoinResponse{} jr = types.JoinResponse{}
} }
receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), roomID, since.EDUPosition())
if err != nil {
return err
}
ev := gomatrixserverlib.ClientEvent{ ev := gomatrixserverlib.ClientEvent{
Type: "m.receipt", Type: "m.receipt",
@ -604,20 +610,15 @@ func (d *Database) addReceiptDeltaToResponse(
}, },
}, },
} }
}
ev.Content, err = json.Marshal(content) ev.Content, err = json.Marshal(content)
if err != nil { if err != nil {
return err return err
} }
jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev)
res.Rooms.Join[roomID] = jr res.Rooms.Join[roomID] = jr
} }
// Only add new events if we didn't find the room in the map.
// If we found the room, they should already be added
if !ok {
res.Rooms.Join[roomID] = jr
}
}
return nil return nil
} }
@ -1487,6 +1488,6 @@ func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId
return return
} }
func (d *Database) GetRoomReceipts(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) { func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) {
return d.Receipts.SelectRoomReceiptsAfter(ctx, roomId, streamPos) return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos)
} }

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -52,7 +53,7 @@ const upsertReceipt = "" +
const selectRoomReceipts = "" + const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + "SELECT room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" + " FROM syncapi_receipts" +
" WHERE room_id = $1 AND id > $2" " WHERE id > $1 and room_id in ($2)"
type receiptStatements struct { type receiptStatements struct {
db *sql.DB db *sql.DB
@ -91,8 +92,15 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
} }
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) {
rows, err := r.selectRoomReceipts.QueryContext(ctx, roomId, streamPos) selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
params := make([]interface{}, len(roomIDs)+1)
params[0] = streamPos
for k, v := range roomIDs {
params[k+1] = v
}
rows, err := r.db.QueryContext(ctx, selectSQL, params...)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err) return nil, fmt.Errorf("unable to query room receipts: %w", err)
} }

View file

@ -160,5 +160,5 @@ type Filter interface {
type Receipts interface { type Receipts interface {
UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
SelectRoomReceiptsAfter(ctx context.Context, roomId string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
} }