diff --git a/roomserver/api/query.go b/roomserver/api/query.go index c70db65c1..599156bb1 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -181,11 +181,8 @@ type QueryServerJoinedToRoomRequest struct { type QueryServerJoinedToRoomResponse struct { // True if the room exists on the server RoomExists bool `json:"room_exists"` - // True if we still believe that we are participating in the room + // True if we still believe that the server is participating in the room IsInRoom bool `json:"is_in_room"` - // List of servers that are also in the room. This will not be populated - // if the queried ServerName is the local server name. - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` } // QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 4af0e6397..b80f08ab6 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -330,46 +330,17 @@ func (r *Queryer) QueryServerJoinedToRoom( response.RoomExists = true if request.ServerName == r.ServerName || request.ServerName == "" { - var joined bool - joined, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) + response.IsInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) if err != nil { return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) } - response.IsInRoom = joined - return nil - } - - eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) - if err != nil { - return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) - } - if len(eventNIDs) == 0 { - return nil - } - - events, err := r.DB.Events(ctx, eventNIDs) - if err != nil { - return fmt.Errorf("r.DB.Events: %w", err) - } - - servers := map[gomatrixserverlib.ServerName]struct{}{} - for _, e := range events { - if e.Type() == gomatrixserverlib.MRoomMember && e.StateKey() != nil { - _, serverName, err := gomatrixserverlib.SplitID('@', *e.StateKey()) - if err != nil { - continue - } - servers[serverName] = struct{}{} - if serverName == request.ServerName { - response.IsInRoom = true - } + } else { + response.IsInRoom, err = r.DB.GetServerInRoom(ctx, info.RoomNID, request.ServerName) + if err != nil { + return fmt.Errorf("r.DB.GetServerInRoom: %w", err) } } - for server := range servers { - response.ServerNames = append(response.ServerNames, server) - } - return nil } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index c25820aac..62aa73ad4 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -156,6 +156,8 @@ type Database interface { JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) + // GetServerInRoom returns true if we think a server is in a given room or false otherwise. + GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) // GetKnownRooms returns a list of all rooms we know about. diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 9102f26a3..b4a27900c 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" ) const membershipSchema = ` @@ -132,6 +133,16 @@ var selectKnownUsersSQL = "" + const selectLocalServerInRoomSQL = "" + "SELECT room_nid FROM roomserver_membership WHERE target_local = true AND membership_nid = $1 AND room_nid = $2 LIMIT 1" +// selectServerMembersInRoomSQL is an optimised case for checking for server members in a room. +// The JOIN is significantly leaner than the previous case of looking up event NIDs and reading the +// membership events from the database, as the JOIN query amounts to little more than two index +// scans which are very fast. The presence of a single row from this query suggests the server is +// in the room, no rows returned suggests they aren't. +const selectServerInRoomSQL = "" + + "SELECT room_nid FROM roomserver_membership" + + " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -146,6 +157,7 @@ type membershipStatements struct { selectKnownUsersStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt selectLocalServerInRoomStmt *sql.Stmt + selectServerInRoomStmt *sql.Stmt } func createMembershipTable(db *sql.DB) error { @@ -170,6 +182,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectKnownUsersStmt, selectKnownUsersSQL}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, + {&s.selectServerInRoomStmt, selectServerInRoomSQL}, }.Prepare(db) } @@ -347,3 +360,15 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room found := nid > 0 return found, nil } + +func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { + var nid types.RoomNID + err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return roomNID == nid, nil +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 8e787851b..4c1aae42d 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1068,6 +1068,11 @@ func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomN return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID) } +// GetServerInRoom returns true if we think a server is in a given room or false otherwise. +func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { + return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName) +} + // GetKnownUsers searches all users that userID knows about. func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 82babe0d2..911a25168 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" ) const membershipSchema = ` @@ -108,6 +109,16 @@ var selectKnownUsersSQL = "" + const selectLocalServerInRoomSQL = "" + "SELECT room_nid FROM roomserver_membership WHERE target_local = 1 AND membership_nid = $1 AND room_nid = $2 LIMIT 1" +// selectServerMembersInRoomSQL is an optimised case for checking for server members in a room. +// The JOIN is significantly leaner than the previous case of looking up event NIDs and reading the +// membership events from the database, as the JOIN query amounts to little more than two index +// scans which are very fast. The presence of a single row from this query suggests the server is +// in the room, no rows returned suggests they aren't. +const selectServerInRoomSQL = "" + + "SELECT room_nid FROM roomserver_membership" + + " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -122,6 +133,7 @@ type membershipStatements struct { selectKnownUsersStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt selectLocalServerInRoomStmt *sql.Stmt + selectServerInRoomStmt *sql.Stmt } func createMembershipTable(db *sql.DB) error { @@ -147,6 +159,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectKnownUsersStmt, selectKnownUsersSQL}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, + {&s.selectServerInRoomStmt, selectServerInRoomSQL}, }.Prepare(db) } @@ -327,3 +340,15 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room found := nid > 0 return found, nil } + +func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { + var nid types.RoomNID + err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return roomNID == nid, nil +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 4a893663f..f762cb712 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -136,6 +136,7 @@ type Membership interface { SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) + SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) } type Published interface {