Add possibility to query all membership states

This commit is contained in:
Till Faelligen 2022-02-21 16:40:59 +01:00
parent 61cdb714df
commit c2b6019c35
5 changed files with 41 additions and 24 deletions

View file

@ -266,32 +266,14 @@ func sendServerNotice(
}
func getAllUserRooms(ctx context.Context, rsAPI api.RoomserverInternalAPI, userID string) ([]string, error) {
allUserRooms := []string{}
userRooms := api.QueryRoomsForUserResponse{}
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: userID,
WantMembership: "join",
WantMembership: "all",
}, &userRooms); err != nil {
return nil, err
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
// get invites for specified user
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: userID,
WantMembership: "invite",
}, &userRooms); err != nil {
return nil, err
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
// get left rooms for specified user
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: userID,
WantMembership: "leave",
}, &userRooms); err != nil {
return nil, err
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
return allUserRooms, nil
return userRooms.RoomIDs, nil
}
func (r sendServerNoticeRequest) valid() (ok bool) {

View file

@ -114,6 +114,9 @@ const updateMembershipForgetRoom = "" +
const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
const selectRoomsForUserSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE target_nid = $1 and forgotten = false"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will
// only return users that the user would ordinarily be able to see anyway.
@ -157,6 +160,7 @@ type membershipStatements struct {
updateMembershipForgetRoomStmt *sql.Stmt
selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt
selectRoomsForUserStmt *sql.Stmt
}
func createMembershipTable(db *sql.DB) error {
@ -182,6 +186,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
{&s.selectRoomsForUserStmt, selectRoomsForUserSQL},
}.Prepare(db)
}
@ -286,8 +291,19 @@ func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
var (
rows *sql.Rows
err error
)
if membershipState == tables.MemberShipStateAll {
stmt := sqlutil.TxStmt(txn, s.selectRoomsForUserStmt)
rows, err = stmt.QueryContext(ctx, userID)
} else {
stmt := sqlutil.TxStmt(txn, s.selectMembershipsFromRoomStmt)
rows, err = stmt.QueryContext(ctx, membershipState, userID)
}
if err != nil {
return nil, err
}

View file

@ -967,6 +967,8 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
membershipState = tables.MembershipStateLeaveOrBan
case "ban":
membershipState = tables.MembershipStateLeaveOrBan
case "all":
membershipState = tables.MemberShipStateAll
default:
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
}

View file

@ -90,6 +90,9 @@ const updateMembershipForgetRoom = "" +
const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
const selectRoomsForUserSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE target_nid = $1 and forgotten = false"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will
// only return users that the user would ordinarily be able to see anyway.
@ -133,6 +136,7 @@ type membershipStatements struct {
updateMembershipForgetRoomStmt *sql.Stmt
selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt
selectRoomsForUserStmt *sql.Stmt
}
func createMembershipTable(db *sql.DB) error {
@ -159,6 +163,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
{&s.selectRoomsForUserStmt, selectRoomsForUserSQL},
}.Prepare(db)
}
@ -263,8 +268,19 @@ func (s *membershipStatements) UpdateMembership(
func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
var (
rows *sql.Rows
err error
)
if membershipState == tables.MemberShipStateAll {
stmt := sqlutil.TxStmt(txn, s.selectRoomsForUserStmt)
rows, err = stmt.QueryContext(ctx, userID)
} else {
stmt := sqlutil.TxStmt(txn, s.selectMembershipsFromRoomStmt)
rows, err = stmt.QueryContext(ctx, membershipState, userID)
}
if err != nil {
return nil, err
}

View file

@ -113,6 +113,7 @@ type Invites interface {
type MembershipState int64
const (
MemberShipStateAll MembershipState = 0
MembershipStateLeaveOrBan MembershipState = 1
MembershipStateInvite MembershipState = 2
MembershipStateJoin MembershipState = 3