/sync bugfix: Check transitions to 'leave' do not leak events afterwards (#105)

This commit is contained in:
Kegsay 2017-05-17 16:21:27 +01:00 committed by GitHub
parent d5a44fd3e8
commit ccd0eb2851
3 changed files with 76 additions and 45 deletions

View file

@ -547,14 +547,15 @@ func main() {
// $ curl -XPUT -d '{"membership":"join"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.member/@charlie:localhost?access_token=@charlie:localhost"
// $ curl -XPUT -d '{"msgtype":"m.text","body":"not charlie..."}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/send/m.room.message/3?access_token=@alice:localhost"
// $ curl -XPUT -d '{"membership":"leave"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.member/@charlie:localhost?access_token=@alice:localhost"
writeToRoomServerLog(i14StateCharlieJoin, i15AliceMsg, i16StateAliceKickCharlie)
// $ curl -XPUT -d '{"msgtype":"m.text","body":"why did you kick charlie"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/send/m.room.message/3?access_token=@bob:localhost"
writeToRoomServerLog(i14StateCharlieJoin, i15AliceMsg, i16StateAliceKickCharlie, i17BobMsg)
// Check transitions to leave work
testSyncServer(syncServerCmdChan, "@charlie:localhost", "15", `{
"account_data": {
"events": []
},
"next_batch": "17",
"next_batch": "18",
"presence": {
"events": []
},
@ -586,7 +587,7 @@ func main() {
"account_data": {
"events": []
},
"next_batch": "17",
"next_batch": "18",
"presence": {
"events": []
},
@ -611,9 +612,8 @@ func main() {
}
}`)
// $ curl -XPUT -d '{"msgtype":"m.text","body":"why did you kick charlie"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/send/m.room.message/3?access_token=@bob:localhost"
// $ curl -XPUT -d '{"name":"No Charlies"}' "http://localhost:8009/_matrix/client/r0/rooms/%21PjrbIMW2cIiaYF4t:localhost/state/m.room.name?access_token=@alice:localhost"
writeToRoomServerLog(i17BobMsg, i18StateAliceRoomName)
writeToRoomServerLog(i18StateAliceRoomName)
// Check that users don't see state changes in rooms after they have left
testSyncServer(syncServerCmdChan, "@charlie:localhost", "17", `{

View file

@ -50,20 +50,17 @@ const insertEventSQL = "" +
"INSERT INTO output_room_events (room_id, event_id, event_json, add_state_ids, remove_state_ids) VALUES ($1, $2, $3, $4, $5) RETURNING id"
const selectEventsSQL = "" +
"SELECT event_json FROM output_room_events WHERE event_id = ANY($1)"
const selectEventsInRangeSQL = "" +
"SELECT event_json FROM output_room_events WHERE id > $1 AND id <= $2"
"SELECT id, event_json FROM output_room_events WHERE event_id = ANY($1)"
const selectRecentEventsSQL = "" +
"SELECT event_json FROM output_room_events WHERE room_id = $1 AND id > $2 AND id <= $3 ORDER BY id DESC LIMIT $4"
"SELECT id, event_json FROM output_room_events WHERE room_id = $1 AND id > $2 AND id <= $3 ORDER BY id DESC LIMIT $4"
const selectMaxIDSQL = "" +
"SELECT MAX(id) FROM output_room_events"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" +
"SELECT event_json, add_state_ids, remove_state_ids FROM output_room_events" +
"SELECT id, event_json, add_state_ids, remove_state_ids FROM output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" ORDER BY id ASC"
@ -71,7 +68,6 @@ type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectMaxIDStmt *sql.Stmt
selectEventsInRangeStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt
}
@ -90,9 +86,6 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
if s.selectMaxIDStmt, err = db.Prepare(selectMaxIDSQL); err != nil {
return
}
if s.selectEventsInRangeStmt, err = db.Prepare(selectEventsInRangeSQL); err != nil {
return
}
if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil {
return
}
@ -105,7 +98,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
// StateBetween returns the state events between the two given stream positions, exclusive of oldPos, inclusive of newPos.
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos types.StreamPosition) (map[string][]gomatrixserverlib.Event, error) {
func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos types.StreamPosition) (map[string][]streamEvent, error) {
rows, err := txn.Stmt(s.selectStateInRangeStmt).Query(oldPos, newPos)
if err != nil {
return nil, err
@ -115,18 +108,19 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes
// For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID
// if they aren't in the event ID cache. We don't handle state deletion yet.
eventIDToEvent := make(map[string]gomatrixserverlib.Event)
eventIDToEvent := make(map[string]streamEvent)
// RoomID => A set (map[string]bool) of state event IDs which are between the two positions
stateNeeded := make(map[string]map[string]bool)
for rows.Next() {
var (
streamPos int64
eventBytes []byte
addIDs pq.StringArray
delIDs pq.StringArray
)
if err := rows.Scan(&eventBytes, &addIDs, &delIDs); err != nil {
if err := rows.Scan(&streamPos, &eventBytes, &addIDs, &delIDs); err != nil {
return nil, err
}
// Sanity check for deleted state and whine if we see it. We don't need to do anything
@ -157,7 +151,7 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty
}
stateNeeded[ev.RoomID()] = needSet
eventIDToEvent[ev.EventID()] = ev
eventIDToEvent[ev.EventID()] = streamEvent{ev, types.StreamPosition(streamPos)}
}
return s.fetchStateEvents(txn, stateNeeded, eventIDToEvent)
@ -165,8 +159,8 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
// Returns a map of room ID to list of events.
func (s *outputRoomEventsStatements) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]gomatrixserverlib.Event) (map[string][]gomatrixserverlib.Event, error) {
stateBetween := make(map[string][]gomatrixserverlib.Event)
func (s *outputRoomEventsStatements) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]streamEvent) (map[string][]streamEvent, error) {
stateBetween := make(map[string][]streamEvent)
missingEvents := make(map[string][]string)
for roomID, ids := range roomIDToEventIDSet {
events := stateBetween[roomID]
@ -232,7 +226,7 @@ func (s *outputRoomEventsStatements) InsertEvent(txn *sql.Tx, event *gomatrixser
}
// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'.
func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int) ([]gomatrixserverlib.Event, error) {
func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int) ([]streamEvent, error) {
rows, err := s.selectRecentEventsStmt.Query(roomID, fromPos, toPos, limit)
if err != nil {
return nil, err
@ -249,7 +243,7 @@ func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID stri
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
// from the database.
func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]gomatrixserverlib.Event, error) {
func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
rows, err := txn.Stmt(s.selectEventsStmt).Query(pq.StringArray(eventIDs))
if err != nil {
return nil, err
@ -266,11 +260,14 @@ func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]g
return result, nil
}
func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) {
var result []gomatrixserverlib.Event
func rowsToEvents(rows *sql.Rows) ([]streamEvent, error) {
var result []streamEvent
for rows.Next() {
var eventBytes []byte
if err := rows.Scan(&eventBytes); err != nil {
var (
streamPos int64
eventBytes []byte
)
if err := rows.Scan(&streamPos, &eventBytes); err != nil {
return nil, err
}
// TODO: Handle redacted events
@ -278,12 +275,12 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) {
if err != nil {
return nil, err
}
result = append(result, ev)
result = append(result, streamEvent{ev, types.StreamPosition(streamPos)})
}
return result, nil
}
func reverseEvents(input []gomatrixserverlib.Event) (output []gomatrixserverlib.Event) {
func reverseEvents(input []streamEvent) (output []streamEvent) {
for i := len(input) - 1; i >= 0; i-- {
output = append(output, input[i])
}

View file

@ -29,6 +29,15 @@ type stateDelta struct {
roomID string
stateEvents []gomatrixserverlib.Event
membership string
// The stream position of the latest membership event for this user, if applicable.
// Can be 0 if there is no membership event in this delta.
membershipPos types.StreamPosition
}
// Same as gomatrixserverlib.Event but also has the stream position for this event.
type streamEvent struct {
gomatrixserverlib.Event
streamPosition types.StreamPosition
}
// SyncServerDatabase represents a sync server database
@ -99,7 +108,7 @@ func (d *SyncServerDatabase) WriteEvent(ev *gomatrixserverlib.Event, addStateEve
if err != nil {
return err
}
return d.roomstate.UpdateRoomState(txn, added, removeStateEventIDs)
return d.roomstate.UpdateRoomState(txn, streamEventsToEvents(added), removeStateEventIDs)
})
return
}
@ -137,10 +146,21 @@ func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types
res = types.NewResponse(toPos)
for _, delta := range deltas {
recentEvents, err := d.events.RecentEventsInRoom(txn, delta.roomID, fromPos, toPos, numRecentEventsPerRoom)
endPos := toPos
if delta.membershipPos > 0 && delta.membership == "leave" {
// make sure we don't leak recent events after the leave event.
// TODO: History visibility makes this somewhat complex to handle correctly. For example:
// TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join).
// TODO: This will fail on join -> leave -> sensitive msg -> join -> leave
// in a single /sync request
// This is all "okay" assuming history_visibility == "shared" which it is by default.
endPos = delta.membershipPos
}
recentStreamEvents, err := d.events.RecentEventsInRoom(txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom)
if err != nil {
return err
}
recentEvents := streamEventsToEvents(recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
switch delta.membership {
@ -198,10 +218,11 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom
}
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
recentEvents, err := d.events.RecentEventsInRoom(txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom)
recentStreamEvents, err := d.events.RecentEventsInRoom(txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom)
if err != nil {
return err
}
recentEvents := streamEventsToEvents(recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse()
@ -246,14 +267,14 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
if err != nil {
return nil, err
}
for roomID, stateEvents := range state {
for _, ev := range stateEvents {
for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents {
// TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event.
// We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this,
// dupe join events will result in the entire room state coming down to the client again. This is added in
// the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
// the timeline.
if membership := getMembershipFromEvent(&ev, userID); membership != "" {
if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
if membership == "join" {
// send full room state down instead of a delta
var allState []gomatrixserverlib.Event
@ -261,13 +282,18 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
if err != nil {
return nil, err
}
state[roomID] = allState
s := make([]streamEvent, len(allState))
for i := 0; i < len(s); i++ {
s[i] = streamEvent{allState[i], types.StreamPosition(0)}
}
state[roomID] = s
continue // we'll add this room in when we do joined rooms
}
deltas = append(deltas, stateDelta{
membership: membership,
stateEvents: stateEvents,
membershipPos: ev.streamPosition,
stateEvents: streamEventsToEvents(stateStreamEvents),
roomID: roomID,
})
break
@ -283,7 +309,7 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, stateDelta{
membership: "join",
stateEvents: state[joinedRoomID],
stateEvents: streamEventsToEvents(state[joinedRoomID]),
roomID: joinedRoomID,
})
}
@ -291,6 +317,14 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
return deltas, nil
}
func streamEventsToEvents(in []streamEvent) []gomatrixserverlib.Event {
out := make([]gomatrixserverlib.Event, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].Event
}
return out
}
// There may be some overlap where events in stateEvents are already in recentEvents, so filter
// them out so we don't include them twice in the /sync response. They should be in recentEvents
// only, so clients get to the correct state once they have rolled forward.