Add core SQL/structs for calculating a complete /sync response

This commit is contained in:
Kegan Dougal 2017-04-12 16:49:26 +01:00
parent 203e706b99
commit d935d6db25
4 changed files with 147 additions and 27 deletions

View file

@ -39,9 +39,17 @@ const upsertRoomStateSQL = "" +
const deleteRoomStateByEventIDSQL = "" + const deleteRoomStateByEventIDSQL = "" +
"DELETE FROM current_room_state WHERE event_id = $1" "DELETE FROM current_room_state WHERE event_id = $1"
const selectRoomIDsWithMembershipSQL = "" +
"SELECT room_id FROM current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectCurrentStateSQL = "" +
"SELECT event_json FROM current_room_state WHERE room_id = $1"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt
selectCurrentStateStmt *sql.Stmt
} }
func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) { func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) {
@ -55,8 +63,56 @@ func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) {
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
return return
} }
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
return return
} }
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
return
}
return
}
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) {
rows, err := txn.Stmt(s.selectRoomIDsWithMembershipStmt).Query(userID, membership)
if err != nil {
return nil, err
}
defer rows.Close()
var result []string
for rows.Next() {
var roomID string
if err := rows.Scan(&roomID); err != nil {
return nil, err
}
result = append(result, roomID)
}
return result, nil
}
// CurrentState returns all the current state events for the given room.
func (s *currentRoomStateStatements) CurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) {
rows, err := txn.Stmt(s.selectCurrentStateStmt).Query(roomID)
if err != nil {
return nil, err
}
defer rows.Close()
var result []gomatrixserverlib.Event
for rows.Next() {
var eventBytes []byte
if err := rows.Scan(&eventBytes); err != nil {
return nil, err
}
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(eventBytes, false)
if err != nil {
return nil, err
}
result = append(result, ev)
}
return result, nil
}
func (s *currentRoomStateStatements) UpdateRoomState(txn *sql.Tx, added []gomatrixserverlib.Event, removedEventIDs []string) error { func (s *currentRoomStateStatements) UpdateRoomState(txn *sql.Tx, added []gomatrixserverlib.Event, removedEventIDs []string) error {
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.

View file

@ -39,6 +39,9 @@ const selectEventsSQL = "" +
const selectEventsInRangeSQL = "" + const selectEventsInRangeSQL = "" +
"SELECT event_json FROM output_room_events WHERE id > $1 AND id <= $2" "SELECT event_json FROM output_room_events WHERE id > $1 AND id <= $2"
const selectRecentEventsSQL = "" +
"SELECT event_json FROM output_room_events WHERE room_id = $1 ORDER BY id DESC LIMIT $2"
const selectMaxIDSQL = "" + const selectMaxIDSQL = "" +
"SELECT MAX(id) FROM output_room_events" "SELECT MAX(id) FROM output_room_events"
@ -47,6 +50,7 @@ type outputRoomEventsStatements struct {
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
selectMaxIDStmt *sql.Stmt selectMaxIDStmt *sql.Stmt
selectEventsInRangeStmt *sql.Stmt selectEventsInRangeStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt
} }
func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
@ -66,14 +70,22 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
if s.selectEventsInRangeStmt, err = db.Prepare(selectEventsInRangeSQL); err != nil { if s.selectEventsInRangeStmt, err = db.Prepare(selectEventsInRangeSQL); err != nil {
return return
} }
if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil {
return
}
return return
} }
// MaxID returns the ID of the last inserted event in this table. This should only ever be used at startup, as it will // MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied,
// race with inserting events if it is done afterwards. If there are no inserted events, 0 is returned. // then this function should only ever be used at startup, as it will race with inserting events if it is
func (s *outputRoomEventsStatements) MaxID() (id int64, err error) { // done afterwards. If there are no inserted events, 0 is returned.
func (s *outputRoomEventsStatements) MaxID(txn *sql.Tx) (id int64, err error) {
stmt := s.selectMaxIDStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
var nullableID sql.NullInt64 var nullableID sql.NullInt64
err = s.selectMaxIDStmt.QueryRow().Scan(&nullableID) err = stmt.QueryRow().Scan(&nullableID)
if nullableID.Valid { if nullableID.Valid {
id = nullableID.Int64 id = nullableID.Int64
} }
@ -89,23 +101,14 @@ func (s *outputRoomEventsStatements) InRange(oldPos, newPos int64) ([]gomatrixse
} }
defer rows.Close() defer rows.Close()
var result []gomatrixserverlib.Event result, err := rowsToEvents(rows)
var i int64
for ; rows.Next(); i++ {
var eventBytes []byte
if err := rows.Scan(&eventBytes); err != nil {
return nil, err
}
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(eventBytes, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
result = append(result, ev)
}
// Expect one event per position, exclusive of old. eg old=3, new=5, expect 4,5 so 2 events. // Expect one event per position, exclusive of old. eg old=3, new=5, expect 4,5 so 2 events.
wantNum := (newPos - oldPos) wantNum := int(newPos - oldPos)
if i != wantNum { if len(result) != wantNum {
return nil, fmt.Errorf("failed to map all positions to events: (got %d, wanted, %d)", i, wantNum) return nil, fmt.Errorf("failed to map all positions to events: (got %d, wanted, %d)", len(result), wantNum)
} }
return result, nil return result, nil
} }
@ -119,6 +122,16 @@ func (s *outputRoomEventsStatements) InsertEvent(txn *sql.Tx, event *gomatrixser
return return
} }
// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'.
func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID string, limit int) ([]gomatrixserverlib.Event, error) {
rows, err := s.selectRecentEventsStmt.Query(roomID, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return rowsToEvents(rows)
}
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing // Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
// from the database. // from the database.
func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]gomatrixserverlib.Event, error) { func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]gomatrixserverlib.Event, error) {
@ -127,10 +140,20 @@ func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]g
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
result, err := rowsToEvents(rows)
if err != nil {
return nil, err
}
if len(result) != len(eventIDs) {
return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(result), len(eventIDs))
}
return result, nil
}
func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) {
var result []gomatrixserverlib.Event var result []gomatrixserverlib.Event
i := 0 for rows.Next() {
for ; rows.Next(); i++ {
var eventBytes []byte var eventBytes []byte
if err := rows.Scan(&eventBytes); err != nil { if err := rows.Scan(&eventBytes); err != nil {
return nil, err return nil, err
@ -141,8 +164,5 @@ func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]g
} }
result = append(result, ev) result = append(result, ev)
} }
if i != len(eventIDs) {
return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", i, len(eventIDs))
}
return result, nil return result, nil
} }

View file

@ -89,13 +89,51 @@ func (d *SyncServerDatabase) SetPartitionOffset(topic string, partition int32, o
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error) { func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error) {
id, err := d.events.MaxID() id, err := d.events.MaxID(nil)
if err != nil { if err != nil {
return types.StreamPosition(0), err return types.StreamPosition(0), err
} }
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
// CompleteSync returns a map of room ID to RoomData.
func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom int) (pos types.StreamPosition, data map[string]types.RoomData, returnErr error) {
// This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have
// a consistent view of the database throughout. This includes extracting the sync stream position.
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
// Get the current stream position which we will base the sync response on.
id, err := d.events.MaxID(txn)
if err != nil {
return err
}
pos = types.StreamPosition(id)
// Extract room state and recent events for all rooms the user is joined to.
roomIDs, err := d.roomstate.SelectRoomIDsWithMembership(txn, userID, "join")
if err != nil {
return err
}
for _, roomID := range roomIDs {
stateEvents, err := d.roomstate.CurrentState(txn, roomID)
if err != nil {
return err
}
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
recentEvents, err := d.events.RecentEventsInRoom(txn, roomID, numRecentEventsPerRoom)
if err != nil {
return err
}
data[roomID] = types.RoomData{
State: stateEvents,
RecentEvents: recentEvents,
}
}
return nil
})
return
}
// EventsInRange returns all events in the given range, exclusive of oldPos, inclusive of newPos. // EventsInRange returns all events in the given range, exclusive of oldPos, inclusive of newPos.
func (d *SyncServerDatabase) EventsInRange(oldPos, newPos types.StreamPosition) ([]gomatrixserverlib.Event, error) { func (d *SyncServerDatabase) EventsInRange(oldPos, newPos types.StreamPosition) ([]gomatrixserverlib.Event, error) {
return d.events.InRange(int64(oldPos), int64(newPos)) return d.events.InRange(int64(oldPos), int64(newPos))

View file

@ -14,6 +14,12 @@ func (sp StreamPosition) String() string {
return strconv.FormatInt(int64(sp), 10) return strconv.FormatInt(int64(sp), 10)
} }
// RoomData represents the data for a room suitable for building a sync response from.
type RoomData struct {
State []gomatrixserverlib.Event
RecentEvents []gomatrixserverlib.Event
}
// Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync // Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync
type Response struct { type Response struct {
NextBatch string `json:"next_batch"` NextBatch string `json:"next_batch"`