Add possibility to query all membership states
This commit is contained in:
parent
61cdb714df
commit
c2b6019c35
|
@ -266,32 +266,14 @@ func sendServerNotice(
|
|||
}
|
||||
|
||||
func getAllUserRooms(ctx context.Context, rsAPI api.RoomserverInternalAPI, userID string) ([]string, error) {
|
||||
allUserRooms := []string{}
|
||||
userRooms := api.QueryRoomsForUserResponse{}
|
||||
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
|
||||
UserID: userID,
|
||||
WantMembership: "join",
|
||||
WantMembership: "all",
|
||||
}, &userRooms); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
|
||||
// get invites for specified user
|
||||
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
|
||||
UserID: userID,
|
||||
WantMembership: "invite",
|
||||
}, &userRooms); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
|
||||
// get left rooms for specified user
|
||||
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
|
||||
UserID: userID,
|
||||
WantMembership: "leave",
|
||||
}, &userRooms); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
|
||||
return allUserRooms, nil
|
||||
return userRooms.RoomIDs, nil
|
||||
}
|
||||
|
||||
func (r sendServerNoticeRequest) valid() (ok bool) {
|
||||
|
|
|
@ -114,6 +114,9 @@ const updateMembershipForgetRoom = "" +
|
|||
const selectRoomsWithMembershipSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
||||
|
||||
const selectRoomsForUserSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_membership WHERE target_nid = $1 and forgotten = false"
|
||||
|
||||
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||
// joined to. Since this information is used to populate the user directory, we will
|
||||
// only return users that the user would ordinarily be able to see anyway.
|
||||
|
@ -157,6 +160,7 @@ type membershipStatements struct {
|
|||
updateMembershipForgetRoomStmt *sql.Stmt
|
||||
selectLocalServerInRoomStmt *sql.Stmt
|
||||
selectServerInRoomStmt *sql.Stmt
|
||||
selectRoomsForUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createMembershipTable(db *sql.DB) error {
|
||||
|
@ -182,6 +186,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
|
|||
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
|
||||
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
|
||||
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
|
||||
{&s.selectRoomsForUserStmt, selectRoomsForUserSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
|
@ -286,8 +291,19 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
||||
var (
|
||||
rows *sql.Rows
|
||||
err error
|
||||
)
|
||||
|
||||
if membershipState == tables.MemberShipStateAll {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsForUserStmt)
|
||||
rows, err = stmt.QueryContext(ctx, userID)
|
||||
} else {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipsFromRoomStmt)
|
||||
rows, err = stmt.QueryContext(ctx, membershipState, userID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -967,6 +967,8 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
|
|||
membershipState = tables.MembershipStateLeaveOrBan
|
||||
case "ban":
|
||||
membershipState = tables.MembershipStateLeaveOrBan
|
||||
case "all":
|
||||
membershipState = tables.MemberShipStateAll
|
||||
default:
|
||||
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
|
||||
}
|
||||
|
|
|
@ -90,6 +90,9 @@ const updateMembershipForgetRoom = "" +
|
|||
const selectRoomsWithMembershipSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
||||
|
||||
const selectRoomsForUserSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_membership WHERE target_nid = $1 and forgotten = false"
|
||||
|
||||
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||
// joined to. Since this information is used to populate the user directory, we will
|
||||
// only return users that the user would ordinarily be able to see anyway.
|
||||
|
@ -133,6 +136,7 @@ type membershipStatements struct {
|
|||
updateMembershipForgetRoomStmt *sql.Stmt
|
||||
selectLocalServerInRoomStmt *sql.Stmt
|
||||
selectServerInRoomStmt *sql.Stmt
|
||||
selectRoomsForUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createMembershipTable(db *sql.DB) error {
|
||||
|
@ -159,6 +163,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
|
|||
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
|
||||
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
|
||||
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
|
||||
{&s.selectRoomsForUserStmt, selectRoomsForUserSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
|
@ -263,8 +268,19 @@ func (s *membershipStatements) UpdateMembership(
|
|||
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
||||
var (
|
||||
rows *sql.Rows
|
||||
err error
|
||||
)
|
||||
|
||||
if membershipState == tables.MemberShipStateAll {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsForUserStmt)
|
||||
rows, err = stmt.QueryContext(ctx, userID)
|
||||
} else {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipsFromRoomStmt)
|
||||
rows, err = stmt.QueryContext(ctx, membershipState, userID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -113,6 +113,7 @@ type Invites interface {
|
|||
type MembershipState int64
|
||||
|
||||
const (
|
||||
MemberShipStateAll MembershipState = 0
|
||||
MembershipStateLeaveOrBan MembershipState = 1
|
||||
MembershipStateInvite MembershipState = 2
|
||||
MembershipStateJoin MembershipState = 3
|
||||
|
|
Loading…
Reference in a new issue