From 2f0415e0df60582b91c651331db5b42007ff1789 Mon Sep 17 00:00:00 2001 From: Tak Wai Wong Date: Tue, 1 Nov 2022 16:04:06 -0700 Subject: [PATCH] implement storage_sync query for membership --- syncapi/storage/interface.go | 1 + .../postgres/current_room_state_table.go | 31 +++++++++++++++++++ syncapi/storage/shared/storage_sync.go | 13 ++++++++ .../sqlite3/current_room_state_table.go | 30 ++++++++++++++++++ syncapi/storage/tables/interface.go | 2 ++ syncapi/streams/stream_invite.go | 4 +-- 6 files changed, 79 insertions(+), 2 deletions(-) diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index af4fce44e..98e85a349 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -109,6 +109,7 @@ type DatabaseTransaction interface { GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) + IsMemberOfRoom(ctx context.Context, roomID string, userID string) (bool, error) } type Database interface { diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 2ccf0be1a..2918a932a 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -100,6 +100,9 @@ const selectJoinedUsersSQL = "" + const selectJoinedUsersInRoomSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)" +const selectRoomMembershipOfUserSQL = "" + + "SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2" + const selectStateEventSQL = "" + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" @@ -123,6 +126,7 @@ type currentRoomStateStatements struct { selectEventsWithEventIDsStmt *sql.Stmt selectStateEventStmt *sql.Stmt selectSharedUsersStmt *sql.Stmt + selectRoomMembershipOfUserStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -166,6 +170,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { return nil, err } + if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil { + return nil, err + } if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { return nil, err } @@ -449,3 +456,27 @@ func (s *currentRoomStateStatements) SelectSharedUsers( } return result, rows.Err() } + +func (s *currentRoomStateStatements) SelectRoomMembershipOfUser( + ctx context.Context, + txn *sql.Tx, + roomID string, + userID string, +) (string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt) + rows, err := stmt.QueryContext(ctx, roomID, userID) + if err != nil { + return "", err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed") + + for rows.Next() { + var membership string + if err := rows.Scan(&membership); err != nil { + return "", err + } + return membership, nil + } + + return "", nil +} diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 1f66ccc0e..222644007 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -688,3 +688,16 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, return events, prevBatch, nextBatch, nil } + +func (d *DatabaseTransaction) IsMemberOfRoom( + ctx context.Context, + roomID string, + userID string, +) (bool, error) { + membership, err := d.CurrentRoomState.SelectRoomMembershipOfUser(ctx, d.txn, roomID, userID) + if err != nil { + return false, err + } + + return membership == gomatrixserverlib.Join, nil +} diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index ff45e786e..2c961f5aa 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -84,6 +84,9 @@ const selectJoinedUsersSQL = "" + const selectJoinedUsersInRoomSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)" +const selectRoomMembershipOfUserSQL = "" + + "SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2" + const selectStateEventSQL = "" + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" @@ -107,6 +110,7 @@ type currentRoomStateStatements struct { //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic selectStateEventStmt *sql.Stmt //selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic + selectRoomMembershipOfUserStmt *sql.Stmt } func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { @@ -147,6 +151,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { return nil, err } + if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil { + return nil, err + } //if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { // return nil, err //} @@ -484,3 +491,26 @@ func (s *currentRoomStateStatements) SelectSharedUsers( return result, err } + +func (s *currentRoomStateStatements) SelectRoomMembershipOfUser( + ctx context.Context, txn *sql.Tx, roomID string, userID string, +) (string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt) + rows, err := stmt.QueryContext(ctx, roomID, userID) + if err != nil { + return "", err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed") + + membership := "" + for rows.Next() { + if err := rows.Scan(&membership); err != nil { + return "", err + } + // Found the membership info + if membership != "" { + return membership, rows.Err() + } + } + return membership, rows.Err() +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 2c4f04ec2..1a0aa9ee7 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -115,6 +115,8 @@ type CurrentRoomState interface { SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error) // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) + // SelectRoomMembershipOfUser returns the membership of the user in the room. + SelectRoomMembershipOfUser(ctx context.Context, txn *sql.Tx, roomID string, userID string) (string, error) } // BackwardsExtremities keeps track of backwards extremities for a room. diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 569e6df48..c7272f253 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -81,12 +81,12 @@ func (p *InviteStreamProvider) IncrementalSync( continue } - joinedUsers, err := snapshot.AllJoinedUsersInRoom(ctx, []string{roomID}) + isMember, err := snapshot.IsMemberOfRoom(ctx, roomID, req.Device.UserID) if err != nil { continue } - if !contains(joinedUsers[roomID], req.Device.UserID) { + if !isMember { lr := types.NewLeaveResponse() h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{