mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 18:43:10 -06:00
Revert changes. Use the SelectMembershipForUser directly to check for membership
This commit is contained in:
parent
35df23edd6
commit
888d68fc38
|
|
@ -109,7 +109,6 @@ type DatabaseTransaction interface {
|
||||||
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
|
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
|
||||||
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
|
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
|
||||||
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
|
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
|
||||||
IsMemberOfRoom(ctx context.Context, roomID string, userID string) (bool, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Database interface {
|
type Database interface {
|
||||||
|
|
|
||||||
|
|
@ -100,9 +100,6 @@ const selectJoinedUsersSQL = "" +
|
||||||
const selectJoinedUsersInRoomSQL = "" +
|
const selectJoinedUsersInRoomSQL = "" +
|
||||||
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)"
|
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)"
|
||||||
|
|
||||||
const selectRoomMembershipOfUserSQL = "" +
|
|
||||||
"SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2"
|
|
||||||
|
|
||||||
const selectStateEventSQL = "" +
|
const selectStateEventSQL = "" +
|
||||||
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
||||||
|
|
||||||
|
|
@ -126,7 +123,6 @@ type currentRoomStateStatements struct {
|
||||||
selectEventsWithEventIDsStmt *sql.Stmt
|
selectEventsWithEventIDsStmt *sql.Stmt
|
||||||
selectStateEventStmt *sql.Stmt
|
selectStateEventStmt *sql.Stmt
|
||||||
selectSharedUsersStmt *sql.Stmt
|
selectSharedUsersStmt *sql.Stmt
|
||||||
selectRoomMembershipOfUserStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
||||||
|
|
@ -170,9 +166,6 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
|
||||||
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
|
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -456,30 +449,3 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) SelectRoomMembershipOfUser(
|
|
||||||
ctx context.Context,
|
|
||||||
txn *sql.Tx,
|
|
||||||
roomID string,
|
|
||||||
userID string,
|
|
||||||
) (string, error) {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt)
|
|
||||||
rows, err := stmt.QueryContext(ctx, roomID, userID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed")
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
var membership string
|
|
||||||
if err := rows.Scan(&membership); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// Found the membership info
|
|
||||||
if membership != "" {
|
|
||||||
return membership, rows.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -688,16 +688,3 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID,
|
||||||
|
|
||||||
return events, prevBatch, nextBatch, nil
|
return events, prevBatch, nextBatch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DatabaseTransaction) IsMemberOfRoom(
|
|
||||||
ctx context.Context,
|
|
||||||
roomID string,
|
|
||||||
userID string,
|
|
||||||
) (bool, error) {
|
|
||||||
membership, err := d.CurrentRoomState.SelectRoomMembershipOfUser(ctx, d.txn, roomID, userID)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return membership == gomatrixserverlib.Join, nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -84,9 +84,6 @@ const selectJoinedUsersSQL = "" +
|
||||||
const selectJoinedUsersInRoomSQL = "" +
|
const selectJoinedUsersInRoomSQL = "" +
|
||||||
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)"
|
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)"
|
||||||
|
|
||||||
const selectRoomMembershipOfUserSQL = "" +
|
|
||||||
"SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2"
|
|
||||||
|
|
||||||
const selectStateEventSQL = "" +
|
const selectStateEventSQL = "" +
|
||||||
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
||||||
|
|
||||||
|
|
@ -110,7 +107,6 @@ type currentRoomStateStatements struct {
|
||||||
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
|
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
selectStateEventStmt *sql.Stmt
|
selectStateEventStmt *sql.Stmt
|
||||||
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
|
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
|
||||||
selectRoomMembershipOfUserStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
||||||
|
|
@ -151,9 +147,6 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
|
||||||
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
||||||
// return nil, err
|
// return nil, err
|
||||||
//}
|
//}
|
||||||
|
|
@ -491,26 +484,3 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||||
|
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) SelectRoomMembershipOfUser(
|
|
||||||
ctx context.Context, txn *sql.Tx, roomID string, userID string,
|
|
||||||
) (string, error) {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt)
|
|
||||||
rows, err := stmt.QueryContext(ctx, roomID, userID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed")
|
|
||||||
|
|
||||||
membership := ""
|
|
||||||
for rows.Next() {
|
|
||||||
if err := rows.Scan(&membership); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// Found the membership info
|
|
||||||
if membership != "" {
|
|
||||||
return membership, rows.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return membership, rows.Err()
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -115,8 +115,6 @@ type CurrentRoomState interface {
|
||||||
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
|
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
|
||||||
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
|
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
|
||||||
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
|
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
|
||||||
// SelectRoomMembershipOfUser returns the membership of the user in the room.
|
|
||||||
SelectRoomMembershipOfUser(ctx context.Context, txn *sql.Tx, roomID string, userID string) (string, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackwardsExtremities keeps track of backwards extremities for a room.
|
// BackwardsExtremities keeps track of backwards extremities for a room.
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -74,19 +75,14 @@ func (p *InviteStreamProvider) IncrementalSync(
|
||||||
return to
|
return to
|
||||||
}
|
}
|
||||||
for roomID := range retiredInvites {
|
for roomID := range retiredInvites {
|
||||||
if req.Response.Rooms.Invite[roomID] != nil {
|
membership, _, err := snapshot.SelectMembershipForUser(ctx, roomID, req.Device.UserID, math.MaxInt64)
|
||||||
continue
|
// Skip if the user is an existing member of the room.
|
||||||
}
|
// Otherwise, the NewLeaveResponse will eject the user from the room unintentionally
|
||||||
if req.Response.Rooms.Join[roomID] != nil {
|
if membership == gomatrixserverlib.Join ||
|
||||||
|
err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
isMember, err := snapshot.IsMemberOfRoom(ctx, roomID, req.Device.UserID)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isMember {
|
|
||||||
lr := types.NewLeaveResponse()
|
lr := types.NewLeaveResponse()
|
||||||
h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
|
h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
|
||||||
lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
|
lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
|
||||||
|
|
@ -101,7 +97,6 @@ func (p *InviteStreamProvider) IncrementalSync(
|
||||||
})
|
})
|
||||||
req.Response.Rooms.Leave[roomID] = lr
|
req.Response.Rooms.Leave[roomID] = lr
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return maxID
|
return maxID
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue