/sync bugfix: Check transitions to 'leave' do not leak events afterwards (#105)

This commit is contained in:
Kegsay 2017-05-17 16:21:27 +01:00 committed by GitHub
parent d5a44fd3e8
commit ccd0eb2851
3 changed files with 76 additions and 45 deletions

View file

@ -547,14 +547,15 @@ func main() {
// $ curl -XPUT -d '{"membership":"join"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.member/@charlie:localhost?access_token=@charlie:localhost" // $ curl -XPUT -d '{"membership":"join"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.member/@charlie:localhost?access_token=@charlie:localhost"
// $ curl -XPUT -d '{"msgtype":"m.text","body":"not charlie..."}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/send/m.room.message/3?access_token=@alice:localhost" // $ curl -XPUT -d '{"msgtype":"m.text","body":"not charlie..."}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/send/m.room.message/3?access_token=@alice:localhost"
// $ curl -XPUT -d '{"membership":"leave"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.member/@charlie:localhost?access_token=@alice:localhost" // $ curl -XPUT -d '{"membership":"leave"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.member/@charlie:localhost?access_token=@alice:localhost"
writeToRoomServerLog(i14StateCharlieJoin, i15AliceMsg, i16StateAliceKickCharlie) // $ curl -XPUT -d '{"msgtype":"m.text","body":"why did you kick charlie"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/send/m.room.message/3?access_token=@bob:localhost"
writeToRoomServerLog(i14StateCharlieJoin, i15AliceMsg, i16StateAliceKickCharlie, i17BobMsg)
// Check transitions to leave work // Check transitions to leave work
testSyncServer(syncServerCmdChan, "@charlie:localhost", "15", `{ testSyncServer(syncServerCmdChan, "@charlie:localhost", "15", `{
"account_data": { "account_data": {
"events": [] "events": []
}, },
"next_batch": "17", "next_batch": "18",
"presence": { "presence": {
"events": [] "events": []
}, },
@ -586,7 +587,7 @@ func main() {
"account_data": { "account_data": {
"events": [] "events": []
}, },
"next_batch": "17", "next_batch": "18",
"presence": { "presence": {
"events": [] "events": []
}, },
@ -611,9 +612,8 @@ func main() {
} }
}`) }`)
// $ curl -XPUT -d '{"msgtype":"m.text","body":"why did you kick charlie"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/send/m.room.message/3?access_token=@bob:localhost"
// $ curl -XPUT -d '{"name":"No Charlies"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.name?access_token=@alice:localhost" // $ curl -XPUT -d '{"name":"No Charlies"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.name?access_token=@alice:localhost"
writeToRoomServerLog(i17BobMsg, i18StateAliceRoomName) writeToRoomServerLog(i18StateAliceRoomName)
// Check that users don't see state changes in rooms after they have left // Check that users don't see state changes in rooms after they have left
testSyncServer(syncServerCmdChan, "@charlie:localhost", "17", `{ testSyncServer(syncServerCmdChan, "@charlie:localhost", "17", `{

View file

@ -50,20 +50,17 @@ const insertEventSQL = "" +
"INSERT INTO output_room_events (room_id, event_id, event_json, add_state_ids, remove_state_ids) VALUES ($1, $2, $3, $4, $5) RETURNING id" "INSERT INTO output_room_events (room_id, event_id, event_json, add_state_ids, remove_state_ids) VALUES ($1, $2, $3, $4, $5) RETURNING id"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT event_json FROM output_room_events WHERE event_id = ANY($1)" "SELECT id, event_json FROM output_room_events WHERE event_id = ANY($1)"
const selectEventsInRangeSQL = "" +
"SELECT event_json FROM output_room_events WHERE id > $1 AND id <= $2"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_json FROM output_room_events WHERE room_id = $1 AND id > $2 AND id <= $3 ORDER BY id DESC LIMIT $4" "SELECT id, event_json FROM output_room_events WHERE room_id = $1 AND id > $2 AND id <= $3 ORDER BY id DESC LIMIT $4"
const selectMaxIDSQL = "" + const selectMaxIDSQL = "" +
"SELECT MAX(id) FROM output_room_events" "SELECT MAX(id) FROM output_room_events"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" + const selectStateInRangeSQL = "" +
"SELECT event_json, add_state_ids, remove_state_ids FROM output_room_events" + "SELECT id, event_json, add_state_ids, remove_state_ids FROM output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" ORDER BY id ASC" " ORDER BY id ASC"
@ -71,7 +68,6 @@ type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
selectMaxIDStmt *sql.Stmt selectMaxIDStmt *sql.Stmt
selectEventsInRangeStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt
} }
@ -90,9 +86,6 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
if s.selectMaxIDStmt, err = db.Prepare(selectMaxIDSQL); err != nil { if s.selectMaxIDStmt, err = db.Prepare(selectMaxIDSQL); err != nil {
return return
} }
if s.selectEventsInRangeStmt, err = db.Prepare(selectEventsInRangeSQL); err != nil {
return
}
if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil { if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil {
return return
} }
@ -105,7 +98,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
// StateBetween returns the state events between the two given stream positions, exclusive of oldPos, inclusive of newPos. // StateBetween returns the state events between the two given stream positions, exclusive of oldPos, inclusive of newPos.
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos types.StreamPosition) (map[string][]gomatrixserverlib.Event, error) { func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos types.StreamPosition) (map[string][]streamEvent, error) {
rows, err := txn.Stmt(s.selectStateInRangeStmt).Query(oldPos, newPos) rows, err := txn.Stmt(s.selectStateInRangeStmt).Query(oldPos, newPos)
if err != nil { if err != nil {
return nil, err return nil, err
@ -115,18 +108,19 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes // - For each room ID, build up an array of event IDs which represents cumulative adds/removes
// For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID // For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID
// if they aren't in the event ID cache. We don't handle state deletion yet. // if they aren't in the event ID cache. We don't handle state deletion yet.
eventIDToEvent := make(map[string]gomatrixserverlib.Event) eventIDToEvent := make(map[string]streamEvent)
// RoomID => A set (map[string]bool) of state event IDs which are between the two positions // RoomID => A set (map[string]bool) of state event IDs which are between the two positions
stateNeeded := make(map[string]map[string]bool) stateNeeded := make(map[string]map[string]bool)
for rows.Next() { for rows.Next() {
var ( var (
streamPos int64
eventBytes []byte eventBytes []byte
addIDs pq.StringArray addIDs pq.StringArray
delIDs pq.StringArray delIDs pq.StringArray
) )
if err := rows.Scan(&eventBytes, &addIDs, &delIDs); err != nil { if err := rows.Scan(&streamPos, &eventBytes, &addIDs, &delIDs); err != nil {
return nil, err return nil, err
} }
// Sanity check for deleted state and whine if we see it. We don't need to do anything // Sanity check for deleted state and whine if we see it. We don't need to do anything
@ -157,7 +151,7 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty
} }
stateNeeded[ev.RoomID()] = needSet stateNeeded[ev.RoomID()] = needSet
eventIDToEvent[ev.EventID()] = ev eventIDToEvent[ev.EventID()] = streamEvent{ev, types.StreamPosition(streamPos)}
} }
return s.fetchStateEvents(txn, stateNeeded, eventIDToEvent) return s.fetchStateEvents(txn, stateNeeded, eventIDToEvent)
@ -165,8 +159,8 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. // fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
// Returns a map of room ID to list of events. // Returns a map of room ID to list of events.
func (s *outputRoomEventsStatements) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]gomatrixserverlib.Event) (map[string][]gomatrixserverlib.Event, error) { func (s *outputRoomEventsStatements) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]streamEvent) (map[string][]streamEvent, error) {
stateBetween := make(map[string][]gomatrixserverlib.Event) stateBetween := make(map[string][]streamEvent)
missingEvents := make(map[string][]string) missingEvents := make(map[string][]string)
for roomID, ids := range roomIDToEventIDSet { for roomID, ids := range roomIDToEventIDSet {
events := stateBetween[roomID] events := stateBetween[roomID]
@ -232,7 +226,7 @@ func (s *outputRoomEventsStatements) InsertEvent(txn *sql.Tx, event *gomatrixser
} }
// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. // RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'.
func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int) ([]gomatrixserverlib.Event, error) { func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int) ([]streamEvent, error) {
rows, err := s.selectRecentEventsStmt.Query(roomID, fromPos, toPos, limit) rows, err := s.selectRecentEventsStmt.Query(roomID, fromPos, toPos, limit)
if err != nil { if err != nil {
return nil, err return nil, err
@ -249,7 +243,7 @@ func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID stri
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing // Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
// from the database. // from the database.
func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]gomatrixserverlib.Event, error) { func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
rows, err := txn.Stmt(s.selectEventsStmt).Query(pq.StringArray(eventIDs)) rows, err := txn.Stmt(s.selectEventsStmt).Query(pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
@ -266,11 +260,14 @@ func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]g
return result, nil return result, nil
} }
func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) { func rowsToEvents(rows *sql.Rows) ([]streamEvent, error) {
var result []gomatrixserverlib.Event var result []streamEvent
for rows.Next() { for rows.Next() {
var eventBytes []byte var (
if err := rows.Scan(&eventBytes); err != nil { streamPos int64
eventBytes []byte
)
if err := rows.Scan(&streamPos, &eventBytes); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
@ -278,12 +275,12 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
result = append(result, ev) result = append(result, streamEvent{ev, types.StreamPosition(streamPos)})
} }
return result, nil return result, nil
} }
func reverseEvents(input []gomatrixserverlib.Event) (output []gomatrixserverlib.Event) { func reverseEvents(input []streamEvent) (output []streamEvent) {
for i := len(input) - 1; i >= 0; i-- { for i := len(input) - 1; i >= 0; i-- {
output = append(output, input[i]) output = append(output, input[i])
} }

View file

@ -29,6 +29,15 @@ type stateDelta struct {
roomID string roomID string
stateEvents []gomatrixserverlib.Event stateEvents []gomatrixserverlib.Event
membership string membership string
// The stream position of the latest membership event for this user, if applicable.
// Can be 0 if there is no membership event in this delta.
membershipPos types.StreamPosition
}
// Same as gomatrixserverlib.Event but also has the stream position for this event.
type streamEvent struct {
gomatrixserverlib.Event
streamPosition types.StreamPosition
} }
// SyncServerDatabase represents a sync server database // SyncServerDatabase represents a sync server database
@ -99,7 +108,7 @@ func (d *SyncServerDatabase) WriteEvent(ev *gomatrixserverlib.Event, addStateEve
if err != nil { if err != nil {
return err return err
} }
return d.roomstate.UpdateRoomState(txn, added, removeStateEventIDs) return d.roomstate.UpdateRoomState(txn, streamEventsToEvents(added), removeStateEventIDs)
}) })
return return
} }
@ -137,10 +146,21 @@ func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types
res = types.NewResponse(toPos) res = types.NewResponse(toPos)
for _, delta := range deltas { for _, delta := range deltas {
recentEvents, err := d.events.RecentEventsInRoom(txn, delta.roomID, fromPos, toPos, numRecentEventsPerRoom) endPos := toPos
if delta.membershipPos > 0 && delta.membership == "leave" {
// make sure we don't leak recent events after the leave event.
// TODO: History visibility makes this somewhat complex to handle correctly. For example:
// TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join).
// TODO: This will fail on join -> leave -> sensitive msg -> join -> leave
// in a single /sync request
// This is all "okay" assuming history_visibility == "shared" which it is by default.
endPos = delta.membershipPos
}
recentStreamEvents, err := d.events.RecentEventsInRoom(txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom)
if err != nil { if err != nil {
return err return err
} }
recentEvents := streamEventsToEvents(recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
switch delta.membership { switch delta.membership {
@ -198,10 +218,11 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom
} }
// TODO: When filters are added, we may need to call this multiple times to get enough events. // TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
recentEvents, err := d.events.RecentEventsInRoom(txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom) recentStreamEvents, err := d.events.RecentEventsInRoom(txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom)
if err != nil { if err != nil {
return err return err
} }
recentEvents := streamEventsToEvents(recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
@ -246,14 +267,14 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
if err != nil { if err != nil {
return nil, err return nil, err
} }
for roomID, stateEvents := range state { for roomID, stateStreamEvents := range state {
for _, ev := range stateEvents { for _, ev := range stateStreamEvents {
// TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event. // TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event.
// We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this, // We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this,
// dupe join events will result in the entire room state coming down to the client again. This is added in // dupe join events will result in the entire room state coming down to the client again. This is added in
// the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
// the timeline. // the timeline.
if membership := getMembershipFromEvent(&ev, userID); membership != "" { if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
if membership == "join" { if membership == "join" {
// send full room state down instead of a delta // send full room state down instead of a delta
var allState []gomatrixserverlib.Event var allState []gomatrixserverlib.Event
@ -261,13 +282,18 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
if err != nil { if err != nil {
return nil, err return nil, err
} }
state[roomID] = allState s := make([]streamEvent, len(allState))
for i := 0; i < len(s); i++ {
s[i] = streamEvent{allState[i], types.StreamPosition(0)}
}
state[roomID] = s
continue // we'll add this room in when we do joined rooms continue // we'll add this room in when we do joined rooms
} }
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: membership, membership: membership,
stateEvents: stateEvents, membershipPos: ev.streamPosition,
stateEvents: streamEventsToEvents(stateStreamEvents),
roomID: roomID, roomID: roomID,
}) })
break break
@ -283,7 +309,7 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: "join", membership: "join",
stateEvents: state[joinedRoomID], stateEvents: streamEventsToEvents(state[joinedRoomID]),
roomID: joinedRoomID, roomID: joinedRoomID,
}) })
} }
@ -291,6 +317,14 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
return deltas, nil return deltas, nil
} }
func streamEventsToEvents(in []streamEvent) []gomatrixserverlib.Event {
out := make([]gomatrixserverlib.Event, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].Event
}
return out
}
// There may be some overlap where events in stateEvents are already in recentEvents, so filter // There may be some overlap where events in stateEvents are already in recentEvents, so filter
// them out so we don't include them twice in the /sync response. They should be in recentEvents // them out so we don't include them twice in the /sync response. They should be in recentEvents
// only, so clients get to the correct state once they have rolled forward. // only, so clients get to the correct state once they have rolled forward.