mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-01-18 09:54:27 -06:00
Optimise QuerySharedUsers
so that we can only work on local users (#2766)
Otherwise the sync API key change consumer wastes a lot of time trying to wake up the notifiers for non-local users.
This commit is contained in:
parent
6f602bb096
commit
c85bc3434f
|
@ -278,6 +278,7 @@ type QuerySharedUsersRequest struct {
|
|||
OtherUserIDs []string
|
||||
ExcludeRoomIDs []string
|
||||
IncludeRoomIDs []string
|
||||
LocalOnly bool
|
||||
}
|
||||
|
||||
type QuerySharedUsersResponse struct {
|
||||
|
|
|
@ -799,7 +799,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
|
|||
}
|
||||
roomIDs = roomIDs[:j]
|
||||
|
||||
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
|
||||
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs, req.LocalOnly)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -157,7 +157,7 @@ type Database interface {
|
|||
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
||||
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
|
||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
|
||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string, localOnly bool) (map[string]int, error)
|
||||
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
||||
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
|
||||
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
|
||||
|
|
|
@ -68,14 +68,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
|
|||
|
||||
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" WHERE (target_local OR $1 = false)" +
|
||||
" AND room_nid = ANY($2) AND target_nid = ANY($3)" +
|
||||
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||
" AND forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid = ANY($1) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" WHERE (target_local OR $1 = false) " +
|
||||
" AND room_nid = ANY($2)" +
|
||||
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||
" AND forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
// Insert a row in to membership table so that it can be locked by the
|
||||
|
@ -334,6 +338,7 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
roomNIDs []types.RoomNID,
|
||||
userNIDs []types.EventStateKeyNID,
|
||||
localOnly bool,
|
||||
) (map[types.EventStateKeyNID]int, error) {
|
||||
var (
|
||||
rows *sql.Rows
|
||||
|
@ -342,9 +347,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
|||
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
||||
if len(userNIDs) > 0 {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
|
||||
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
|
||||
rows, err = stmt.QueryContext(ctx, localOnly, pq.Array(roomNIDs), pq.Array(userNIDs))
|
||||
} else {
|
||||
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs))
|
||||
rows, err = stmt.QueryContext(ctx, localOnly, pq.Array(roomNIDs))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -1280,7 +1280,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
|||
}
|
||||
|
||||
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
|
||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
|
||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string, localOnly bool) (map[string]int, error) {
|
||||
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -1295,7 +1295,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
|
|||
userNIDs = append(userNIDs, nid)
|
||||
nidToUserID[nid] = id
|
||||
}
|
||||
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
|
||||
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs, localOnly)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -44,14 +44,18 @@ const membershipSchema = `
|
|||
|
||||
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" WHERE (target_local OR $1 = false)" +
|
||||
" AND room_nid IN ($2) AND target_nid IN ($3)" +
|
||||
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||
" AND forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid IN ($1) AND " +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" WHERE (target_local OR $1 = false)" +
|
||||
" AND room_nid IN ($2)" +
|
||||
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||
" AND forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
// Insert a row in to membership table so that it can be locked by the
|
||||
|
@ -305,8 +309,9 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
|||
return roomNIDs, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
|
||||
params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error) {
|
||||
params := make([]interface{}, 0, 1+len(roomNIDs)+len(userNIDs))
|
||||
params = append(params, localOnly)
|
||||
for _, v := range roomNIDs {
|
||||
params = append(params, v)
|
||||
}
|
||||
|
@ -314,10 +319,10 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
|
|||
params = append(params, v)
|
||||
}
|
||||
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($2)", sqlutil.QueryVariadicOffset(len(roomNIDs), 1), 1)
|
||||
if len(userNIDs) > 0 {
|
||||
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
|
||||
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(roomNIDs), 1), 1)
|
||||
query = strings.Replace(query, "($3)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)+1), 1)
|
||||
}
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
|
|
@ -137,7 +137,7 @@ type Membership interface {
|
|||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
|
||||
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
|
||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
|
||||
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
|
||||
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
||||
|
|
|
@ -79,7 +79,7 @@ func TestMembershipTable(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.True(t, inRoom)
|
||||
|
||||
userJoinedToRooms, err := tab.SelectJoinedUsersSetForRooms(ctx, nil, []types.RoomNID{1}, userNIDs)
|
||||
userJoinedToRooms, err := tab.SelectJoinedUsersSetForRooms(ctx, nil, []types.RoomNID{1}, userNIDs, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(userJoinedToRooms))
|
||||
|
||||
|
|
|
@ -111,7 +111,8 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d
|
|||
// work out who we need to notify about the new key
|
||||
var queryRes roomserverAPI.QuerySharedUsersResponse
|
||||
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
|
||||
UserID: output.UserID,
|
||||
UserID: output.UserID,
|
||||
LocalOnly: true,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
|
||||
|
@ -135,7 +136,8 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
|
|||
// work out who we need to notify about the new key
|
||||
var queryRes roomserverAPI.QuerySharedUsersResponse
|
||||
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
|
||||
UserID: output.UserID,
|
||||
UserID: output.UserID,
|
||||
LocalOnly: true,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
|
||||
|
|
Loading…
Reference in a new issue