Optimise QuerySharedUsers
so that we can only work on local users (#2766)
Otherwise the sync API key change consumer wastes a lot of time trying to wake up the notifiers for non-local users.
This commit is contained in:
parent
6f602bb096
commit
c85bc3434f
|
@ -278,6 +278,7 @@ type QuerySharedUsersRequest struct {
|
||||||
OtherUserIDs []string
|
OtherUserIDs []string
|
||||||
ExcludeRoomIDs []string
|
ExcludeRoomIDs []string
|
||||||
IncludeRoomIDs []string
|
IncludeRoomIDs []string
|
||||||
|
LocalOnly bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type QuerySharedUsersResponse struct {
|
type QuerySharedUsersResponse struct {
|
||||||
|
|
|
@ -799,7 +799,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
|
||||||
}
|
}
|
||||||
roomIDs = roomIDs[:j]
|
roomIDs = roomIDs[:j]
|
||||||
|
|
||||||
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
|
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs, req.LocalOnly)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -157,7 +157,7 @@ type Database interface {
|
||||||
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||||
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
||||||
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
|
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
|
||||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
|
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string, localOnly bool) (map[string]int, error)
|
||||||
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
||||||
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
|
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
|
||||||
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
|
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
|
||||||
|
|
|
@ -68,14 +68,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
|
||||||
|
|
||||||
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
|
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
|
||||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||||
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
|
" WHERE (target_local OR $1 = false)" +
|
||||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
" AND room_nid = ANY($2) AND target_nid = ANY($3)" +
|
||||||
|
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||||
|
" AND forgotten = false" +
|
||||||
" GROUP BY target_nid"
|
" GROUP BY target_nid"
|
||||||
|
|
||||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||||
" WHERE room_nid = ANY($1) AND" +
|
" WHERE (target_local OR $1 = false) " +
|
||||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
" AND room_nid = ANY($2)" +
|
||||||
|
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||||
|
" AND forgotten = false" +
|
||||||
" GROUP BY target_nid"
|
" GROUP BY target_nid"
|
||||||
|
|
||||||
// Insert a row in to membership table so that it can be locked by the
|
// Insert a row in to membership table so that it can be locked by the
|
||||||
|
@ -334,6 +338,7 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNIDs []types.RoomNID,
|
roomNIDs []types.RoomNID,
|
||||||
userNIDs []types.EventStateKeyNID,
|
userNIDs []types.EventStateKeyNID,
|
||||||
|
localOnly bool,
|
||||||
) (map[types.EventStateKeyNID]int, error) {
|
) (map[types.EventStateKeyNID]int, error) {
|
||||||
var (
|
var (
|
||||||
rows *sql.Rows
|
rows *sql.Rows
|
||||||
|
@ -342,9 +347,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
||||||
if len(userNIDs) > 0 {
|
if len(userNIDs) > 0 {
|
||||||
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
|
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
|
||||||
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
|
rows, err = stmt.QueryContext(ctx, localOnly, pq.Array(roomNIDs), pq.Array(userNIDs))
|
||||||
} else {
|
} else {
|
||||||
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs))
|
rows, err = stmt.QueryContext(ctx, localOnly, pq.Array(roomNIDs))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1280,7 +1280,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
||||||
}
|
}
|
||||||
|
|
||||||
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
|
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
|
||||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
|
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string, localOnly bool) (map[string]int, error) {
|
||||||
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
|
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -1295,7 +1295,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
|
||||||
userNIDs = append(userNIDs, nid)
|
userNIDs = append(userNIDs, nid)
|
||||||
nidToUserID[nid] = id
|
nidToUserID[nid] = id
|
||||||
}
|
}
|
||||||
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
|
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs, localOnly)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,14 +44,18 @@ const membershipSchema = `
|
||||||
|
|
||||||
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
|
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
|
||||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||||
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
|
" WHERE (target_local OR $1 = false)" +
|
||||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
" AND room_nid IN ($2) AND target_nid IN ($3)" +
|
||||||
|
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||||
|
" AND forgotten = false" +
|
||||||
" GROUP BY target_nid"
|
" GROUP BY target_nid"
|
||||||
|
|
||||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||||
" WHERE room_nid IN ($1) AND " +
|
" WHERE (target_local OR $1 = false)" +
|
||||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
" AND room_nid IN ($2)" +
|
||||||
|
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||||
|
" AND forgotten = false" +
|
||||||
" GROUP BY target_nid"
|
" GROUP BY target_nid"
|
||||||
|
|
||||||
// Insert a row in to membership table so that it can be locked by the
|
// Insert a row in to membership table so that it can be locked by the
|
||||||
|
@ -305,8 +309,9 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
return roomNIDs, nil
|
return roomNIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
|
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error) {
|
||||||
params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
|
params := make([]interface{}, 0, 1+len(roomNIDs)+len(userNIDs))
|
||||||
|
params = append(params, localOnly)
|
||||||
for _, v := range roomNIDs {
|
for _, v := range roomNIDs {
|
||||||
params = append(params, v)
|
params = append(params, v)
|
||||||
}
|
}
|
||||||
|
@ -314,10 +319,10 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
|
||||||
params = append(params, v)
|
params = append(params, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($2)", sqlutil.QueryVariadicOffset(len(roomNIDs), 1), 1)
|
||||||
if len(userNIDs) > 0 {
|
if len(userNIDs) > 0 {
|
||||||
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(roomNIDs), 1), 1)
|
||||||
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
|
query = strings.Replace(query, "($3)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)+1), 1)
|
||||||
}
|
}
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
|
|
|
@ -137,7 +137,7 @@ type Membership interface {
|
||||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
|
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
|
||||||
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||||
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
||||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
|
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
|
||||||
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||||
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
|
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
|
||||||
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
||||||
|
|
|
@ -79,7 +79,7 @@ func TestMembershipTable(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, inRoom)
|
assert.True(t, inRoom)
|
||||||
|
|
||||||
userJoinedToRooms, err := tab.SelectJoinedUsersSetForRooms(ctx, nil, []types.RoomNID{1}, userNIDs)
|
userJoinedToRooms, err := tab.SelectJoinedUsersSetForRooms(ctx, nil, []types.RoomNID{1}, userNIDs, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 1, len(userJoinedToRooms))
|
assert.Equal(t, 1, len(userJoinedToRooms))
|
||||||
|
|
||||||
|
|
|
@ -111,7 +111,8 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d
|
||||||
// work out who we need to notify about the new key
|
// work out who we need to notify about the new key
|
||||||
var queryRes roomserverAPI.QuerySharedUsersResponse
|
var queryRes roomserverAPI.QuerySharedUsersResponse
|
||||||
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
|
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
|
||||||
UserID: output.UserID,
|
UserID: output.UserID,
|
||||||
|
LocalOnly: true,
|
||||||
}, &queryRes)
|
}, &queryRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
|
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
|
||||||
|
@ -135,7 +136,8 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
|
||||||
// work out who we need to notify about the new key
|
// work out who we need to notify about the new key
|
||||||
var queryRes roomserverAPI.QuerySharedUsersResponse
|
var queryRes roomserverAPI.QuerySharedUsersResponse
|
||||||
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
|
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
|
||||||
UserID: output.UserID,
|
UserID: output.UserID,
|
||||||
|
LocalOnly: true,
|
||||||
}, &queryRes)
|
}, &queryRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
|
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
|
||||||
|
|
Loading…
Reference in a new issue