Revert changes. Use the SelectMembershipForUser directly to check for membership

This commit is contained in:
Tak Wai Wong 2022-11-01 17:05:57 -07:00
parent 35df23edd6
commit 888d68fc38
No known key found for this signature in database
GPG key ID: 222E4AF2AA1F467D
6 changed files with 19 additions and 104 deletions

View file

@ -109,7 +109,6 @@ type DatabaseTransaction interface {
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
IsMemberOfRoom(ctx context.Context, roomID string, userID string) (bool, error)
} }
type Database interface { type Database interface {

View file

@ -100,9 +100,6 @@ const selectJoinedUsersSQL = "" +
const selectJoinedUsersInRoomSQL = "" + const selectJoinedUsersInRoomSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)" "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)"
const selectRoomMembershipOfUserSQL = "" +
"SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2"
const selectStateEventSQL = "" + const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
@ -126,7 +123,6 @@ type currentRoomStateStatements struct {
selectEventsWithEventIDsStmt *sql.Stmt selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
selectSharedUsersStmt *sql.Stmt selectSharedUsersStmt *sql.Stmt
selectRoomMembershipOfUserStmt *sql.Stmt
} }
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -170,9 +166,6 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
return nil, err return nil, err
} }
if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil {
return nil, err
}
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
return nil, err return nil, err
} }
@ -456,30 +449,3 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *currentRoomStateStatements) SelectRoomMembershipOfUser(
ctx context.Context,
txn *sql.Tx,
roomID string,
userID string,
) (string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt)
rows, err := stmt.QueryContext(ctx, roomID, userID)
if err != nil {
return "", err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed")
for rows.Next() {
var membership string
if err := rows.Scan(&membership); err != nil {
return "", err
}
// Found the membership info
if membership != "" {
return membership, rows.Err()
}
}
return "", nil
}

View file

@ -688,16 +688,3 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID,
return events, prevBatch, nextBatch, nil return events, prevBatch, nextBatch, nil
} }
func (d *DatabaseTransaction) IsMemberOfRoom(
ctx context.Context,
roomID string,
userID string,
) (bool, error) {
membership, err := d.CurrentRoomState.SelectRoomMembershipOfUser(ctx, d.txn, roomID, userID)
if err != nil {
return false, err
}
return membership == gomatrixserverlib.Join, nil
}

View file

@ -84,9 +84,6 @@ const selectJoinedUsersSQL = "" +
const selectJoinedUsersInRoomSQL = "" + const selectJoinedUsersInRoomSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)" "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)"
const selectRoomMembershipOfUserSQL = "" +
"SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2"
const selectStateEventSQL = "" + const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
@ -110,7 +107,6 @@ type currentRoomStateStatements struct {
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic //selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
selectRoomMembershipOfUserStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
@ -151,9 +147,6 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err return nil, err
} }
if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil {
return nil, err
}
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { //if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
// return nil, err // return nil, err
//} //}
@ -491,26 +484,3 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
return result, err return result, err
} }
func (s *currentRoomStateStatements) SelectRoomMembershipOfUser(
ctx context.Context, txn *sql.Tx, roomID string, userID string,
) (string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt)
rows, err := stmt.QueryContext(ctx, roomID, userID)
if err != nil {
return "", err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed")
membership := ""
for rows.Next() {
if err := rows.Scan(&membership); err != nil {
return "", err
}
// Found the membership info
if membership != "" {
return membership, rows.Err()
}
}
return membership, rows.Err()
}

View file

@ -115,8 +115,6 @@ type CurrentRoomState interface {
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error) SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
// SelectRoomMembershipOfUser returns the membership of the user in the room.
SelectRoomMembershipOfUser(ctx context.Context, txn *sql.Tx, roomID string, userID string) (string, error)
} }
// BackwardsExtremities keeps track of backwards extremities for a room. // BackwardsExtremities keeps track of backwards extremities for a room.

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"math"
"strconv" "strconv"
"time" "time"
@ -74,33 +75,27 @@ func (p *InviteStreamProvider) IncrementalSync(
return to return to
} }
for roomID := range retiredInvites { for roomID := range retiredInvites {
if req.Response.Rooms.Invite[roomID] != nil { membership, _, err := snapshot.SelectMembershipForUser(ctx, roomID, req.Device.UserID, math.MaxInt64)
continue // Skip if the user is an existing member of the room.
} // Otherwise, the NewLeaveResponse will eject the user from the room unintentionally
if req.Response.Rooms.Join[roomID] != nil { if membership == gomatrixserverlib.Join ||
err != nil {
continue continue
} }
isMember, err := snapshot.IsMemberOfRoom(ctx, roomID, req.Device.UserID) lr := types.NewLeaveResponse()
if err != nil { h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
continue lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
} // fake event ID which muxes in the to position
EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]),
if !isMember { OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
lr := types.NewLeaveResponse() RoomID: roomID,
h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) Sender: req.Device.UserID,
lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{ StateKey: &req.Device.UserID,
// fake event ID which muxes in the to position Type: "m.room.member",
EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]), Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`),
OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), })
RoomID: roomID, req.Response.Rooms.Leave[roomID] = lr
Sender: req.Device.UserID,
StateKey: &req.Device.UserID,
Type: "m.room.member",
Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`),
})
req.Response.Rooms.Leave[roomID] = lr
}
} }
return maxID return maxID