/sync bugfix: Check transitions to 'leave' do not leak events afterwards (#105)
This commit is contained in:
parent
d5a44fd3e8
commit
ccd0eb2851
|
@ -547,14 +547,15 @@ func main() {
|
|||
// $ curl -XPUT -d '{"membership":"join"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.member/@charlie:localhost?access_token=@charlie:localhost"
|
||||
// $ curl -XPUT -d '{"msgtype":"m.text","body":"not charlie..."}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/send/m.room.message/3?access_token=@alice:localhost"
|
||||
// $ curl -XPUT -d '{"membership":"leave"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.member/@charlie:localhost?access_token=@alice:localhost"
|
||||
writeToRoomServerLog(i14StateCharlieJoin, i15AliceMsg, i16StateAliceKickCharlie)
|
||||
// $ curl -XPUT -d '{"msgtype":"m.text","body":"why did you kick charlie"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/send/m.room.message/3?access_token=@bob:localhost"
|
||||
writeToRoomServerLog(i14StateCharlieJoin, i15AliceMsg, i16StateAliceKickCharlie, i17BobMsg)
|
||||
|
||||
// Check transitions to leave work
|
||||
testSyncServer(syncServerCmdChan, "@charlie:localhost", "15", `{
|
||||
"account_data": {
|
||||
"events": []
|
||||
},
|
||||
"next_batch": "17",
|
||||
"next_batch": "18",
|
||||
"presence": {
|
||||
"events": []
|
||||
},
|
||||
|
@ -586,7 +587,7 @@ func main() {
|
|||
"account_data": {
|
||||
"events": []
|
||||
},
|
||||
"next_batch": "17",
|
||||
"next_batch": "18",
|
||||
"presence": {
|
||||
"events": []
|
||||
},
|
||||
|
@ -611,9 +612,8 @@ func main() {
|
|||
}
|
||||
}`)
|
||||
|
||||
// $ curl -XPUT -d '{"msgtype":"m.text","body":"why did you kick charlie"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/send/m.room.message/3?access_token=@bob:localhost"
|
||||
// $ curl -XPUT -d '{"name":"No Charlies"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.name?access_token=@alice:localhost"
|
||||
writeToRoomServerLog(i17BobMsg, i18StateAliceRoomName)
|
||||
writeToRoomServerLog(i18StateAliceRoomName)
|
||||
|
||||
// Check that users don't see state changes in rooms after they have left
|
||||
testSyncServer(syncServerCmdChan, "@charlie:localhost", "17", `{
|
||||
|
|
|
@ -50,20 +50,17 @@ const insertEventSQL = "" +
|
|||
"INSERT INTO output_room_events (room_id, event_id, event_json, add_state_ids, remove_state_ids) VALUES ($1, $2, $3, $4, $5) RETURNING id"
|
||||
|
||||
const selectEventsSQL = "" +
|
||||
"SELECT event_json FROM output_room_events WHERE event_id = ANY($1)"
|
||||
|
||||
const selectEventsInRangeSQL = "" +
|
||||
"SELECT event_json FROM output_room_events WHERE id > $1 AND id <= $2"
|
||||
"SELECT id, event_json FROM output_room_events WHERE event_id = ANY($1)"
|
||||
|
||||
const selectRecentEventsSQL = "" +
|
||||
"SELECT event_json FROM output_room_events WHERE room_id = $1 AND id > $2 AND id <= $3 ORDER BY id DESC LIMIT $4"
|
||||
"SELECT id, event_json FROM output_room_events WHERE room_id = $1 AND id > $2 AND id <= $3 ORDER BY id DESC LIMIT $4"
|
||||
|
||||
const selectMaxIDSQL = "" +
|
||||
"SELECT MAX(id) FROM output_room_events"
|
||||
|
||||
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
|
||||
const selectStateInRangeSQL = "" +
|
||||
"SELECT event_json, add_state_ids, remove_state_ids FROM output_room_events" +
|
||||
"SELECT id, event_json, add_state_ids, remove_state_ids FROM output_room_events" +
|
||||
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
|
||||
" ORDER BY id ASC"
|
||||
|
||||
|
@ -71,7 +68,6 @@ type outputRoomEventsStatements struct {
|
|||
insertEventStmt *sql.Stmt
|
||||
selectEventsStmt *sql.Stmt
|
||||
selectMaxIDStmt *sql.Stmt
|
||||
selectEventsInRangeStmt *sql.Stmt
|
||||
selectRecentEventsStmt *sql.Stmt
|
||||
selectStateInRangeStmt *sql.Stmt
|
||||
}
|
||||
|
@ -90,9 +86,6 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
|
|||
if s.selectMaxIDStmt, err = db.Prepare(selectMaxIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectEventsInRangeStmt, err = db.Prepare(selectEventsInRangeSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -105,7 +98,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
|
|||
// StateBetween returns the state events between the two given stream positions, exclusive of oldPos, inclusive of newPos.
|
||||
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
|
||||
// two positions, only the most recent state is returned.
|
||||
func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos types.StreamPosition) (map[string][]gomatrixserverlib.Event, error) {
|
||||
func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos types.StreamPosition) (map[string][]streamEvent, error) {
|
||||
rows, err := txn.Stmt(s.selectStateInRangeStmt).Query(oldPos, newPos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -115,18 +108,19 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty
|
|||
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes
|
||||
// For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID
|
||||
// if they aren't in the event ID cache. We don't handle state deletion yet.
|
||||
eventIDToEvent := make(map[string]gomatrixserverlib.Event)
|
||||
eventIDToEvent := make(map[string]streamEvent)
|
||||
|
||||
// RoomID => A set (map[string]bool) of state event IDs which are between the two positions
|
||||
stateNeeded := make(map[string]map[string]bool)
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
streamPos int64
|
||||
eventBytes []byte
|
||||
addIDs pq.StringArray
|
||||
delIDs pq.StringArray
|
||||
)
|
||||
if err := rows.Scan(&eventBytes, &addIDs, &delIDs); err != nil {
|
||||
if err := rows.Scan(&streamPos, &eventBytes, &addIDs, &delIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Sanity check for deleted state and whine if we see it. We don't need to do anything
|
||||
|
@ -157,7 +151,7 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty
|
|||
}
|
||||
stateNeeded[ev.RoomID()] = needSet
|
||||
|
||||
eventIDToEvent[ev.EventID()] = ev
|
||||
eventIDToEvent[ev.EventID()] = streamEvent{ev, types.StreamPosition(streamPos)}
|
||||
}
|
||||
|
||||
return s.fetchStateEvents(txn, stateNeeded, eventIDToEvent)
|
||||
|
@ -165,8 +159,8 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty
|
|||
|
||||
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
|
||||
// Returns a map of room ID to list of events.
|
||||
func (s *outputRoomEventsStatements) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]gomatrixserverlib.Event) (map[string][]gomatrixserverlib.Event, error) {
|
||||
stateBetween := make(map[string][]gomatrixserverlib.Event)
|
||||
func (s *outputRoomEventsStatements) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]streamEvent) (map[string][]streamEvent, error) {
|
||||
stateBetween := make(map[string][]streamEvent)
|
||||
missingEvents := make(map[string][]string)
|
||||
for roomID, ids := range roomIDToEventIDSet {
|
||||
events := stateBetween[roomID]
|
||||
|
@ -232,7 +226,7 @@ func (s *outputRoomEventsStatements) InsertEvent(txn *sql.Tx, event *gomatrixser
|
|||
}
|
||||
|
||||
// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'.
|
||||
func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int) ([]gomatrixserverlib.Event, error) {
|
||||
func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int) ([]streamEvent, error) {
|
||||
rows, err := s.selectRecentEventsStmt.Query(roomID, fromPos, toPos, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -249,7 +243,7 @@ func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID stri
|
|||
|
||||
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
|
||||
// from the database.
|
||||
func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]gomatrixserverlib.Event, error) {
|
||||
func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
|
||||
rows, err := txn.Stmt(s.selectEventsStmt).Query(pq.StringArray(eventIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -266,11 +260,14 @@ func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]g
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) {
|
||||
var result []gomatrixserverlib.Event
|
||||
func rowsToEvents(rows *sql.Rows) ([]streamEvent, error) {
|
||||
var result []streamEvent
|
||||
for rows.Next() {
|
||||
var eventBytes []byte
|
||||
if err := rows.Scan(&eventBytes); err != nil {
|
||||
var (
|
||||
streamPos int64
|
||||
eventBytes []byte
|
||||
)
|
||||
if err := rows.Scan(&streamPos, &eventBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: Handle redacted events
|
||||
|
@ -278,12 +275,12 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, ev)
|
||||
result = append(result, streamEvent{ev, types.StreamPosition(streamPos)})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func reverseEvents(input []gomatrixserverlib.Event) (output []gomatrixserverlib.Event) {
|
||||
func reverseEvents(input []streamEvent) (output []streamEvent) {
|
||||
for i := len(input) - 1; i >= 0; i-- {
|
||||
output = append(output, input[i])
|
||||
}
|
||||
|
|
|
@ -29,6 +29,15 @@ type stateDelta struct {
|
|||
roomID string
|
||||
stateEvents []gomatrixserverlib.Event
|
||||
membership string
|
||||
// The stream position of the latest membership event for this user, if applicable.
|
||||
// Can be 0 if there is no membership event in this delta.
|
||||
membershipPos types.StreamPosition
|
||||
}
|
||||
|
||||
// Same as gomatrixserverlib.Event but also has the stream position for this event.
|
||||
type streamEvent struct {
|
||||
gomatrixserverlib.Event
|
||||
streamPosition types.StreamPosition
|
||||
}
|
||||
|
||||
// SyncServerDatabase represents a sync server database
|
||||
|
@ -99,7 +108,7 @@ func (d *SyncServerDatabase) WriteEvent(ev *gomatrixserverlib.Event, addStateEve
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.roomstate.UpdateRoomState(txn, added, removeStateEventIDs)
|
||||
return d.roomstate.UpdateRoomState(txn, streamEventsToEvents(added), removeStateEventIDs)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@ -137,10 +146,21 @@ func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types
|
|||
|
||||
res = types.NewResponse(toPos)
|
||||
for _, delta := range deltas {
|
||||
recentEvents, err := d.events.RecentEventsInRoom(txn, delta.roomID, fromPos, toPos, numRecentEventsPerRoom)
|
||||
endPos := toPos
|
||||
if delta.membershipPos > 0 && delta.membership == "leave" {
|
||||
// make sure we don't leak recent events after the leave event.
|
||||
// TODO: History visibility makes this somewhat complex to handle correctly. For example:
|
||||
// TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join).
|
||||
// TODO: This will fail on join -> leave -> sensitive msg -> join -> leave
|
||||
// in a single /sync request
|
||||
// This is all "okay" assuming history_visibility == "shared" which it is by default.
|
||||
endPos = delta.membershipPos
|
||||
}
|
||||
recentStreamEvents, err := d.events.RecentEventsInRoom(txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
recentEvents := streamEventsToEvents(recentStreamEvents)
|
||||
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
|
||||
|
||||
switch delta.membership {
|
||||
|
@ -198,10 +218,11 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom
|
|||
}
|
||||
// TODO: When filters are added, we may need to call this multiple times to get enough events.
|
||||
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
|
||||
recentEvents, err := d.events.RecentEventsInRoom(txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom)
|
||||
recentStreamEvents, err := d.events.RecentEventsInRoom(txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
recentEvents := streamEventsToEvents(recentStreamEvents)
|
||||
|
||||
stateEvents = removeDuplicates(stateEvents, recentEvents)
|
||||
jr := types.NewJoinResponse()
|
||||
|
@ -246,14 +267,14 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for roomID, stateEvents := range state {
|
||||
for _, ev := range stateEvents {
|
||||
for roomID, stateStreamEvents := range state {
|
||||
for _, ev := range stateStreamEvents {
|
||||
// TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event.
|
||||
// We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this,
|
||||
// dupe join events will result in the entire room state coming down to the client again. This is added in
|
||||
// the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
|
||||
// the timeline.
|
||||
if membership := getMembershipFromEvent(&ev, userID); membership != "" {
|
||||
if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
|
||||
if membership == "join" {
|
||||
// send full room state down instead of a delta
|
||||
var allState []gomatrixserverlib.Event
|
||||
|
@ -261,13 +282,18 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state[roomID] = allState
|
||||
s := make([]streamEvent, len(allState))
|
||||
for i := 0; i < len(s); i++ {
|
||||
s[i] = streamEvent{allState[i], types.StreamPosition(0)}
|
||||
}
|
||||
state[roomID] = s
|
||||
continue // we'll add this room in when we do joined rooms
|
||||
}
|
||||
|
||||
deltas = append(deltas, stateDelta{
|
||||
membership: membership,
|
||||
stateEvents: stateEvents,
|
||||
membershipPos: ev.streamPosition,
|
||||
stateEvents: streamEventsToEvents(stateStreamEvents),
|
||||
roomID: roomID,
|
||||
})
|
||||
break
|
||||
|
@ -283,7 +309,7 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
|
|||
for _, joinedRoomID := range joinedRoomIDs {
|
||||
deltas = append(deltas, stateDelta{
|
||||
membership: "join",
|
||||
stateEvents: state[joinedRoomID],
|
||||
stateEvents: streamEventsToEvents(state[joinedRoomID]),
|
||||
roomID: joinedRoomID,
|
||||
})
|
||||
}
|
||||
|
@ -291,6 +317,14 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
|
|||
return deltas, nil
|
||||
}
|
||||
|
||||
func streamEventsToEvents(in []streamEvent) []gomatrixserverlib.Event {
|
||||
out := make([]gomatrixserverlib.Event, len(in))
|
||||
for i := 0; i < len(in); i++ {
|
||||
out[i] = in[i].Event
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// There may be some overlap where events in stateEvents are already in recentEvents, so filter
|
||||
// them out so we don't include them twice in the /sync response. They should be in recentEvents
|
||||
// only, so clients get to the correct state once they have rolled forward.
|
||||
|
|
Loading…
Reference in a new issue