diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 40e42d72a..59c2db4e0 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -266,32 +266,14 @@ func sendServerNotice( } func getAllUserRooms(ctx context.Context, rsAPI api.RoomserverInternalAPI, userID string) ([]string, error) { - allUserRooms := []string{} userRooms := api.QueryRoomsForUserResponse{} if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ UserID: userID, - WantMembership: "join", + WantMembership: "all", }, &userRooms); err != nil { return nil, err } - allUserRooms = append(allUserRooms, userRooms.RoomIDs...) - // get invites for specified user - if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ - UserID: userID, - WantMembership: "invite", - }, &userRooms); err != nil { - return nil, err - } - allUserRooms = append(allUserRooms, userRooms.RoomIDs...) - // get left rooms for specified user - if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ - UserID: userID, - WantMembership: "leave", - }, &userRooms); err != nil { - return nil, err - } - allUserRooms = append(allUserRooms, userRooms.RoomIDs...) - return allUserRooms, nil + return userRooms.RoomIDs, nil } func (r sendServerNoticeRequest) valid() (ok bool) { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 48c2c35cd..0106d3e68 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -114,6 +114,9 @@ const updateMembershipForgetRoom = "" + const selectRoomsWithMembershipSQL = "" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" +const selectRoomsForUserSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE target_nid = $1 and forgotten = false" + // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // joined to. Since this information is used to populate the user directory, we will // only return users that the user would ordinarily be able to see anyway. @@ -157,6 +160,7 @@ type membershipStatements struct { updateMembershipForgetRoomStmt *sql.Stmt selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt + selectRoomsForUserStmt *sql.Stmt } func createMembershipTable(db *sql.DB) error { @@ -182,6 +186,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL}, + {&s.selectRoomsForUserStmt, selectRoomsForUserSQL}, }.Prepare(db) } @@ -286,8 +291,19 @@ func (s *membershipStatements) SelectRoomsWithMembership( ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) - rows, err := stmt.QueryContext(ctx, membershipState, userID) + var ( + rows *sql.Rows + err error + ) + + if membershipState == tables.MemberShipStateAll { + stmt := sqlutil.TxStmt(txn, s.selectRoomsForUserStmt) + rows, err = stmt.QueryContext(ctx, userID) + } else { + stmt := sqlutil.TxStmt(txn, s.selectMembershipsFromRoomStmt) + rows, err = stmt.QueryContext(ctx, membershipState, userID) + } + if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index b255cfb3f..45ebed0b0 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -967,6 +967,8 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership membershipState = tables.MembershipStateLeaveOrBan case "ban": membershipState = tables.MembershipStateLeaveOrBan + case "all": + membershipState = tables.MemberShipStateAll default: return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership) } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 181b4b4c9..e2b37da4d 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -90,6 +90,9 @@ const updateMembershipForgetRoom = "" + const selectRoomsWithMembershipSQL = "" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" +const selectRoomsForUserSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE target_nid = $1 and forgotten = false" + // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // joined to. Since this information is used to populate the user directory, we will // only return users that the user would ordinarily be able to see anyway. @@ -133,6 +136,7 @@ type membershipStatements struct { updateMembershipForgetRoomStmt *sql.Stmt selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt + selectRoomsForUserStmt *sql.Stmt } func createMembershipTable(db *sql.DB) error { @@ -159,6 +163,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL}, + {&s.selectRoomsForUserStmt, selectRoomsForUserSQL}, }.Prepare(db) } @@ -263,8 +268,19 @@ func (s *membershipStatements) UpdateMembership( func (s *membershipStatements) SelectRoomsWithMembership( ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) - rows, err := stmt.QueryContext(ctx, membershipState, userID) + var ( + rows *sql.Rows + err error + ) + + if membershipState == tables.MemberShipStateAll { + stmt := sqlutil.TxStmt(txn, s.selectRoomsForUserStmt) + rows, err = stmt.QueryContext(ctx, userID) + } else { + stmt := sqlutil.TxStmt(txn, s.selectMembershipsFromRoomStmt) + rows, err = stmt.QueryContext(ctx, membershipState, userID) + } + if err != nil { return nil, err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index e3fed700b..468a427ff 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -113,6 +113,7 @@ type Invites interface { type MembershipState int64 const ( + MemberShipStateAll MembershipState = 0 MembershipStateLeaveOrBan MembershipState = 1 MembershipStateInvite MembershipState = 2 MembershipStateJoin MembershipState = 3