Fix QuerySharedUsers if no UserIDs are supplied

This commit is contained in:
Till Faelligen 2022-07-05 11:02:51 +02:00
parent 0848c83501
commit d3c7b4195d
5 changed files with 113 additions and 5 deletions

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/producers"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
@ -69,6 +70,7 @@ import (
// or C.
type Inputer struct {
Cfg *config.RoomServer
Base *base.BaseDendrite
ProcessContext *process.ProcessContext
DB storage.Database
NATSClient *nats.Conn
@ -160,7 +162,9 @@ func (r *Inputer) startWorkerForRoom(roomID string) {
// will look to see if we have a worker for that room which has its
// own consumer. If we don't, we'll start one.
func (r *Inputer) Start() error {
prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
if r.Base.EnableMetrics {
prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
}
_, err := r.JetStream.Subscribe(
"", // This is blank because we specified it in BindStream.
func(m *nats.Msg) {

View file

@ -0,0 +1,69 @@
package roomserver_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches)
if err != nil {
t.Fatalf("failed to create Database: %v", err)
}
return base, db, close
}
func Test_SharedUsers(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
// Invite and join Bob
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, _, close := mustCreateDatabase(t, dbType)
defer close()
rsAPI := roomserver.NewInternalAPI(base)
// SetFederationAPI starts the room event input consumer
rsAPI.SetFederationAPI(nil, nil)
// Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// Query the shared users for Alice, there should only be Bob.
// This is used by the SyncAPI keychange consumer.
res := &api.QuerySharedUsersResponse{}
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
// Also verify that we get the expected result when specifying OtherUserIDs.
// This is used by the SyncAPI when getting device list changes.
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
})
}

View file

@ -65,12 +65,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
);
`
var selectJoinedUsersSetForRoomsSQL = "" +
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
@ -153,6 +159,7 @@ type membershipStatements struct {
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
selectRoomsWithMembershipStmt *sql.Stmt
selectJoinedUsersSetForRoomsAndUserStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
updateMembershipForgetRoomStmt *sql.Stmt
@ -178,6 +185,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
{&s.selectJoinedUsersSetForRoomsAndUserStmt, selectJoinedUsersSetForRoomsAndUserSQL},
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
@ -313,8 +321,18 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) {
var (
rows *sql.Rows
err error
)
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
if len(userNIDs) > 0 {
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
} else {
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs))
}
if err != nil {
return nil, err
}

View file

@ -1214,6 +1214,13 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
stateKeyNIDs[i] = nid
i++
}
// If we didn't have any userIDs to look up, get the UserIDs for the returned userNIDToCount now
if len(userIDs) == 0 {
nidToUserID, err = d.EventStateKeys(ctx, stateKeyNIDs)
if err != nil {
return nil, err
}
}
result := make(map[string]int, len(userNIDToCount))
for nid, count := range userNIDToCount {
result[nidToUserID[nid]] = count

View file

@ -41,12 +41,18 @@ const membershipSchema = `
);
`
var selectJoinedUsersSetForRoomsSQL = "" +
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND " +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
@ -293,8 +299,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
for _, v := range userNIDs {
params = append(params, v)
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
if len(userNIDs) > 0 {
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
}
var rows *sql.Rows
var err error
if txn != nil {