From 888d68fc3822fb53d68b011221c9ce6b10f7c155 Mon Sep 17 00:00:00 2001 From: Tak Wai Wong Date: Tue, 1 Nov 2022 17:05:57 -0700 Subject: [PATCH] Revert changes. Use the SelectMembershipForUser directly to check for membership --- syncapi/storage/interface.go | 1 - .../postgres/current_room_state_table.go | 34 --------------- syncapi/storage/shared/storage_sync.go | 13 ------ .../sqlite3/current_room_state_table.go | 30 ------------- syncapi/storage/tables/interface.go | 2 - syncapi/streams/stream_invite.go | 43 ++++++++----------- 6 files changed, 19 insertions(+), 104 deletions(-) diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 98e85a349..af4fce44e 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -109,7 +109,6 @@ type DatabaseTransaction interface { GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) - IsMemberOfRoom(ctx context.Context, roomID string, userID string) (bool, error) } type Database interface { diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index c63f5e9d6..2ccf0be1a 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -100,9 +100,6 @@ const selectJoinedUsersSQL = "" + const selectJoinedUsersInRoomSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)" -const selectRoomMembershipOfUserSQL = "" + - "SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2" - const selectStateEventSQL = "" + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" @@ -126,7 +123,6 @@ type currentRoomStateStatements struct { selectEventsWithEventIDsStmt *sql.Stmt selectStateEventStmt *sql.Stmt selectSharedUsersStmt *sql.Stmt - selectRoomMembershipOfUserStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -170,9 +166,6 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { return nil, err } - if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil { - return nil, err - } if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { return nil, err } @@ -456,30 +449,3 @@ func (s *currentRoomStateStatements) SelectSharedUsers( } return result, rows.Err() } - -func (s *currentRoomStateStatements) SelectRoomMembershipOfUser( - ctx context.Context, - txn *sql.Tx, - roomID string, - userID string, -) (string, error) { - stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt) - rows, err := stmt.QueryContext(ctx, roomID, userID) - if err != nil { - return "", err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed") - - for rows.Next() { - var membership string - if err := rows.Scan(&membership); err != nil { - return "", err - } - // Found the membership info - if membership != "" { - return membership, rows.Err() - } - } - - return "", nil -} diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 222644007..1f66ccc0e 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -688,16 +688,3 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, return events, prevBatch, nextBatch, nil } - -func (d *DatabaseTransaction) IsMemberOfRoom( - ctx context.Context, - roomID string, - userID string, -) (bool, error) { - membership, err := d.CurrentRoomState.SelectRoomMembershipOfUser(ctx, d.txn, roomID, userID) - if err != nil { - return false, err - } - - return membership == gomatrixserverlib.Join, nil -} diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 2c961f5aa..ff45e786e 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -84,9 +84,6 @@ const selectJoinedUsersSQL = "" + const selectJoinedUsersInRoomSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)" -const selectRoomMembershipOfUserSQL = "" + - "SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2" - const selectStateEventSQL = "" + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" @@ -110,7 +107,6 @@ type currentRoomStateStatements struct { //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic selectStateEventStmt *sql.Stmt //selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic - selectRoomMembershipOfUserStmt *sql.Stmt } func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { @@ -151,9 +147,6 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { return nil, err } - if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil { - return nil, err - } //if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { // return nil, err //} @@ -491,26 +484,3 @@ func (s *currentRoomStateStatements) SelectSharedUsers( return result, err } - -func (s *currentRoomStateStatements) SelectRoomMembershipOfUser( - ctx context.Context, txn *sql.Tx, roomID string, userID string, -) (string, error) { - stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt) - rows, err := stmt.QueryContext(ctx, roomID, userID) - if err != nil { - return "", err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed") - - membership := "" - for rows.Next() { - if err := rows.Scan(&membership); err != nil { - return "", err - } - // Found the membership info - if membership != "" { - return membership, rows.Err() - } - } - return membership, rows.Err() -} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 1a0aa9ee7..2c4f04ec2 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -115,8 +115,6 @@ type CurrentRoomState interface { SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error) // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) - // SelectRoomMembershipOfUser returns the membership of the user in the room. - SelectRoomMembershipOfUser(ctx context.Context, txn *sql.Tx, roomID string, userID string) (string, error) } // BackwardsExtremities keeps track of backwards extremities for a room. diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 02da42a73..e4de30e1c 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/base64" + "math" "strconv" "time" @@ -74,33 +75,27 @@ func (p *InviteStreamProvider) IncrementalSync( return to } for roomID := range retiredInvites { - if req.Response.Rooms.Invite[roomID] != nil { - continue - } - if req.Response.Rooms.Join[roomID] != nil { + membership, _, err := snapshot.SelectMembershipForUser(ctx, roomID, req.Device.UserID, math.MaxInt64) + // Skip if the user is an existing member of the room. + // Otherwise, the NewLeaveResponse will eject the user from the room unintentionally + if membership == gomatrixserverlib.Join || + err != nil { continue } - isMember, err := snapshot.IsMemberOfRoom(ctx, roomID, req.Device.UserID) - if err != nil { - continue - } - - if !isMember { - lr := types.NewLeaveResponse() - h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) - lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{ - // fake event ID which muxes in the to position - EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]), - OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), - RoomID: roomID, - Sender: req.Device.UserID, - StateKey: &req.Device.UserID, - Type: "m.room.member", - Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`), - }) - req.Response.Rooms.Leave[roomID] = lr - } + lr := types.NewLeaveResponse() + h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) + lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{ + // fake event ID which muxes in the to position + EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]), + OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + RoomID: roomID, + Sender: req.Device.UserID, + StateKey: &req.Device.UserID, + Type: "m.room.member", + Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`), + }) + req.Response.Rooms.Leave[roomID] = lr } return maxID