From 8d9bdda95c04c8a794e3a77788d4c6c3081c3ff7 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 28 Jan 2020 16:53:01 +0000 Subject: [PATCH] Mostly not-bad support for SQLite in syncapi (although there are problems where lots of events get classed incorrectly as backward extremities, probably because of IN/ANY clauses that are badly supported) --- syncapi/storage/sqlite3/account_data_table.go | 24 +++-- .../sqlite3/backward_extremities_table.go | 29 ++--- .../sqlite3/current_room_state_table.go | 30 ++---- syncapi/storage/sqlite3/invites_table.go | 44 ++++---- .../sqlite3/output_room_events_table.go | 102 ++++++++++-------- .../output_room_events_topology_table.go | 38 ++++--- syncapi/storage/sqlite3/stream_id_table.go | 58 ++++++++++ syncapi/storage/sqlite3/syncserver.go | 56 ++++++---- 8 files changed, 240 insertions(+), 141 deletions(-) create mode 100644 syncapi/storage/sqlite3/stream_id_table.go diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 7d9f882bb..8ebf79bdd 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -33,33 +33,32 @@ CREATE TABLE IF NOT EXISTS syncapi_account_data_type ( type TEXT NOT NULL, UNIQUE (user_id, room_id, type) ); - --- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_account_data_id_idx ON syncapi_account_data_type(id, type); ` const insertAccountDataSQL = "" + - "INSERT INTO syncapi_account_data_type (user_id, room_id, type) VALUES ($1, $2, $3)" + + "INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" + " ON CONFLICT (user_id, room_id, type) DO UPDATE" + - " SET id = EXCLUDED.id;" + - "SELECT id FROM syncapi_account_data_type WHERE rowid = last_insert_rowid()" + " SET id = EXCLUDED.id" const selectAccountDataInRangeSQL = "" + "SELECT room_id, type FROM syncapi_account_data_type" + " WHERE user_id = $1 AND id > $2 AND id <= $3" + - // " AND ( $4 IS NULL OR type LIKE ANY($4) )" + - // " AND ( $5 IS NULL OR NOT(type LIKE ANY($5)) )" + + " AND ( $4 IS NULL OR type IN ($4) )" + + " AND ( $5 IS NULL OR NOT(type IN ($5)) )" + " ORDER BY id ASC LIMIT $6" const selectMaxAccountDataIDSQL = "" + "SELECT MAX(id) FROM syncapi_account_data_type" type accountDataStatements struct { + streamIDStatements *streamIDStatements insertAccountDataStmt *sql.Stmt selectAccountDataInRangeStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { +func (s *accountDataStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { + s.streamIDStatements = streamID _, err = db.Exec(accountDataSchema) if err != nil { return @@ -77,10 +76,15 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, + ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { - err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos) + pos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } + insertStmt := common.TxStmt(txn, s.insertAccountDataStmt) + _, err = insertStmt.ExecContext(ctx, pos, userID, roomID, dataType) return } diff --git a/syncapi/storage/sqlite3/backward_extremities_table.go b/syncapi/storage/sqlite3/backward_extremities_table.go index 726abff14..fcf15da25 100644 --- a/syncapi/storage/sqlite3/backward_extremities_table.go +++ b/syncapi/storage/sqlite3/backward_extremities_table.go @@ -17,14 +17,14 @@ package sqlite3 import ( "context" "database/sql" + + "github.com/matrix-org/dendrite/common" ) const backwardExtremitiesSchema = ` -- Stores output room events received from the roomserver. CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( - -- The 'room_id' key for the event. room_id TEXT NOT NULL, - -- The event ID for the event. event_id TEXT NOT NULL, PRIMARY KEY(room_id, event_id) @@ -34,7 +34,7 @@ CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( const insertBackwardExtremitySQL = "" + "INSERT INTO syncapi_backward_extremities (room_id, event_id)" + " VALUES ($1, $2)" + - " ON CONFLICT DO NOTHING" + " ON CONFLICT (room_id, event_id) DO NOTHING" const selectBackwardExtremitiesForRoomSQL = "" + "SELECT event_id FROM syncapi_backward_extremities WHERE room_id = $1" @@ -46,7 +46,8 @@ const isBackwardExtremitySQL = "" + ")" const deleteBackwardExtremitySQL = "" + - "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND event_id = $2" + "DELETE FROM syncapi_backward_extremities" + + " WHERE room_id = $1 AND event_id = $2" type backwardExtremitiesStatements struct { insertBackwardExtremityStmt *sql.Stmt @@ -76,18 +77,20 @@ func (s *backwardExtremitiesStatements) prepare(db *sql.DB) (err error) { } func (s *backwardExtremitiesStatements) insertsBackwardExtremity( - ctx context.Context, roomID, eventID string, + ctx context.Context, txn *sql.Tx, roomID, eventID string, ) (err error) { - _, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID) + stmt := common.TxStmt(txn, s.insertBackwardExtremityStmt) + _, err = stmt.ExecContext(ctx, roomID, eventID) return } func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (eventIDs []string, err error) { eventIDs = make([]string, 0) - rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + stmt := common.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt) + rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return } @@ -105,15 +108,17 @@ func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom( } func (s *backwardExtremitiesStatements) isBackwardExtremity( - ctx context.Context, roomID, eventID string, + ctx context.Context, txn *sql.Tx, roomID, eventID string, ) (isBE bool, err error) { - err = s.isBackwardExtremityStmt.QueryRowContext(ctx, roomID, eventID).Scan(&isBE) + stmt := common.TxStmt(txn, s.isBackwardExtremityStmt) + err = stmt.QueryRowContext(ctx, roomID, eventID).Scan(&isBE) return } func (s *backwardExtremitiesStatements) deleteBackwardExtremity( - ctx context.Context, roomID, eventID string, + ctx context.Context, txn *sql.Tx, roomID, eventID string, ) (err error) { - _, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID) + stmt := common.TxStmt(txn, s.deleteBackwardExtremityStmt) + _, err = stmt.ExecContext(ctx, roomID, eventID) return } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 26f227738..2145dea29 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -30,31 +30,19 @@ import ( const currentRoomStateSchema = ` -- Stores the current room state for every room. CREATE TABLE IF NOT EXISTS syncapi_current_room_state ( - -- The 'room_id' key for the state event. room_id TEXT NOT NULL, - -- The state event ID event_id TEXT NOT NULL, - -- The state event type e.g 'm.room.member' type TEXT NOT NULL, - -- The 'sender' property of the event. sender TEXT NOT NULL, - -- true if the event content contains a url key - contains_url BOOL NOT NULL, - -- The state_key value for this state event e.g '' + contains_url BOOL NOT NULL DEFAULT false, state_key TEXT NOT NULL, - -- The JSON for the event. Stored as TEXT because this should be valid UTF-8. event_json TEXT NOT NULL, - -- The 'content.membership' value if this event is an m.room.member event. For other - -- events, this will be NULL. membership TEXT, - -- The serial ID of the output_room_events table when this event became - -- part of the current state of the room. added_at BIGINT, - -- Clobber based on 3-uple of room_id, type and state_key UNIQUE (room_id, type, state_key) ); -- for event deletion --- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); +CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); -- for querying membership states of users -- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; ` @@ -73,10 +61,10 @@ const selectRoomIDsWithMembershipSQL = "" + const selectCurrentStateSQL = "" + "SELECT event_json FROM syncapi_current_room_state WHERE room_id = $1" + - " AND ( $2 IS NULL OR sender = ANY($2) )" + - " AND ( $3 IS NULL OR NOT(sender = ANY($3)) )" + - " AND ( $4 IS NULL OR type LIKE ANY($4) )" + - " AND ( $5 IS NULL OR NOT(type LIKE ANY($5)) )" + + " AND ( $2 IS NULL OR sender IN ($2) )" + + " AND ( $3 IS NULL OR NOT(sender IN ($3)) )" + + " AND ( $4 IS NULL OR type IN ($4) )" + + " AND ( $5 IS NULL OR NOT(type IN ($5)) )" + " AND ( $6 IS NULL OR contains_url = $6 )" + " LIMIT $7" @@ -92,9 +80,10 @@ const selectEventsWithEventIDsSQL = "" + // figure out if these really need to be in the DB, and if so, we need a // better permanent fix for this. - neilalexander, 2 Jan 2020 "SELECT added_at, event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + - " FROM syncapi_current_room_state WHERE event_id = ANY($1)" + " FROM syncapi_current_room_state WHERE event_id IN ($1)" type currentRoomStateStatements struct { + streamIDStatements *streamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt @@ -104,7 +93,8 @@ type currentRoomStateStatements struct { selectStateEventStmt *sql.Stmt } -func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) { +func (s *currentRoomStateStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { + s.streamIDStatements = streamID _, err = db.Exec(currentRoomStateSchema) if err != nil { return diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index ee95aef01..74dba245b 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -26,26 +26,24 @@ import ( const inviteEventsSchema = ` CREATE TABLE IF NOT EXISTS syncapi_invite_events ( - id INTEGER PRIMARY KEY DEFAULT AUTOINCREMENT, -- nextval('syncapi_stream_id'), + id INTEGER PRIMARY KEY AUTOINCREMENT, event_id TEXT NOT NULL, room_id TEXT NOT NULL, target_user_id TEXT NOT NULL, event_json TEXT NOT NULL ); --- For looking up the invites for a given user. --- CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx --- ON syncapi_invite_events (target_user_id, id); - --- For deleting old invites --- CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx --- ON syncapi_invite_events (event_id); +CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id); +CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id); ` const insertInviteEventSQL = "" + - "INSERT INTO syncapi_invite_events (" + - " room_id, event_id, target_user_id, event_json" + - ") VALUES ($1, $2, $3, $4) RETURNING id" + "INSERT INTO syncapi_invite_events" + + " (room_id, event_id, target_user_id, event_json)" + + " VALUES ($1, $2, $3, $4)" + +const selectLastInsertedInviteEventSQL = "" + + "SELECT id FROM syncapi_invite_events WHERE rowid = last_insert_rowid()" const deleteInviteEventSQL = "" + "DELETE FROM syncapi_invite_events WHERE event_id = $1" @@ -59,13 +57,16 @@ const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" type inviteEventsStatements struct { - insertInviteEventStmt *sql.Stmt - selectInviteEventsInRangeStmt *sql.Stmt - deleteInviteEventStmt *sql.Stmt - selectMaxInviteIDStmt *sql.Stmt + streamIDStatements *streamIDStatements + insertInviteEventStmt *sql.Stmt + selectLastInsertedInviteEventStmt *sql.Stmt + selectInviteEventsInRangeStmt *sql.Stmt + deleteInviteEventStmt *sql.Stmt + selectMaxInviteIDStmt *sql.Stmt } -func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) { +func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { + s.streamIDStatements = streamID _, err = db.Exec(inviteEventsSchema) if err != nil { return @@ -73,6 +74,9 @@ func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) { if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { return } + if s.selectLastInsertedInviteEventStmt, err = db.Prepare(selectLastInsertedInviteEventSQL); err != nil { + return + } if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { return } @@ -88,13 +92,17 @@ func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) { func (s *inviteEventsStatements) insertInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.Event, ) (streamPos types.StreamPosition, err error) { - err = s.insertInviteEventStmt.QueryRowContext( + _, err = s.insertInviteEventStmt.ExecContext( ctx, inviteEvent.RoomID(), inviteEvent.EventID(), *inviteEvent.StateKey(), inviteEvent.JSON(), - ).Scan(&streamPos) + ) + if err != nil { + return + } + err = s.selectLastInsertedInviteEventStmt.QueryRowContext(ctx).Scan(&streamPos) return } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index f30e176e0..d78ffb198 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -32,52 +32,34 @@ import ( ) const outputRoomEventsSchema = ` --- This sequence is shared between all the tables generated from kafka logs. -CREATE SEQUENCE IF NOT EXISTS syncapi_stream_id; - -- Stores output room events received from the roomserver. CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( - -- An incrementing ID which denotes the position in the log that this event resides at. - -- NB: 'serial' makes no guarantees to increment by 1 every time, only that it increments. - -- This isn't a problem for us since we just want to order by this field. - id INTEGER PRIMARY KEY AUTOINCREMENT, -- DEFAULT nextval('syncapi_stream_id'), - -- The event ID for the event - event_id TEXT NOT NULL UNIQUE, -- CONSTRAINT syncapi_event_id_idx UNIQUE, - -- The 'room_id' key for the event. + id INTEGER PRIMARY KEY AUTOINCREMENT, + event_id TEXT NOT NULL UNIQUE, room_id TEXT NOT NULL, - -- The JSON for the event. Stored as TEXT because this should be valid UTF-8. event_json TEXT NOT NULL, - -- The event type e.g 'm.room.member'. type TEXT NOT NULL, - -- The 'sender' property of the event. sender TEXT NOT NULL, - -- true if the event content contains a url key. contains_url BOOL NOT NULL, - -- A list of event IDs which represent a delta of added/removed room state. This can be NULL - -- if there is no delta. add_state_ids TEXT[], remove_state_ids TEXT[], - -- The client session that sent the event, if any session_id BIGINT, - -- The transaction id used to send the event, if any transaction_id TEXT, - -- Should the event be excluded from responses to /sync requests. Useful for - -- events retrieved through backfilling that have a position in the stream - -- that relates to the moment these were retrieved rather than the moment these - -- were emitted. exclude_from_sync BOOL DEFAULT FALSE ); ` const insertEventSQL = "" + "INSERT INTO syncapi_output_room_events (" + - "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + - ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) " + - "ON CONFLICT DO UPDATE SET exclude_from_sync = $11 " + - "RETURNING id" + "id, room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + + "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = $11" + +const selectLastInsertedEventSQL = "" + + "SELECT id FROM syncapi_output_room_events WHERE rowid = last_insert_rowid()" const selectEventsSQL = "" + - "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" const selectRecentEventsSQL = "" + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + @@ -98,20 +80,33 @@ const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). +/* + $1 = oldPos, + $2 = newPos, + $3 = pq.StringArray(stateFilterPart.Senders), + $4 = pq.StringArray(stateFilterPart.NotSenders), + $5 = pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), + $6 = pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)), + $7 = stateFilterPart.ContainsURL, + $8 = stateFilterPart.Limit, +*/ const selectStateInRangeSQL = "" + "SELECT id, event_json, exclude_from_sync, add_state_ids, remove_state_ids" + " FROM syncapi_output_room_events" + - " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + - " AND ( $3::text[] IS NULL OR sender = ANY($3) )" + - " AND ( $4::text[] IS NULL OR NOT(sender = ANY($4)) )" + - " AND ( $5::text[] IS NULL OR type LIKE ANY($5) )" + - " AND ( $6::text[] IS NULL OR NOT(type LIKE ANY($6)) )" + - " AND ( $7::bool IS NULL OR contains_url = $7 )" + + " WHERE (id > $1 AND id <= $2)" + // old/new pos + " AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + + " AND ( $3 IS NULL OR sender IN ($3) )" + // sender + " AND ( $4 IS NULL OR NOT(sender IN ($4)) )" + // not sender + " AND ( $5 IS NULL OR type IN ($5) )" + // type + " AND ( $6 IS NULL OR NOT(type IN ($6)) )" + // not type + " AND ( $7 IS NULL OR contains_url = $7)" + // contains URL? " ORDER BY id ASC" + - " LIMIT $8" + " LIMIT $8" // limit type outputRoomEventsStatements struct { + streamIDStatements *streamIDStatements insertEventStmt *sql.Stmt + selectLastInsertedEventStmt *sql.Stmt selectEventsStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt @@ -120,7 +115,8 @@ type outputRoomEventsStatements struct { selectStateInRangeStmt *sql.Stmt } -func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { +func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { + s.streamIDStatements = streamID _, err = db.Exec(outputRoomEventsSchema) if err != nil { return @@ -128,6 +124,9 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { return } + if s.selectLastInsertedEventStmt, err = db.Prepare(selectLastInsertedEventSQL); err != nil { + return + } if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil { return } @@ -266,9 +265,16 @@ func (s *outputRoomEventsStatements) insertEvent( _, containsURL = content["url"] } - stmt := common.TxStmt(txn, s.insertEventStmt) - err = stmt.QueryRowContext( + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } + + insertStmt := common.TxStmt(txn, s.insertEventStmt) + selectStmt := common.TxStmt(txn, s.selectLastInsertedEventStmt) + _, err = insertStmt.ExecContext( ctx, + streamPos, event.RoomID(), event.EventID(), event.JSON(), @@ -280,7 +286,11 @@ func (s *outputRoomEventsStatements) insertEvent( sessionID, txnID, excludeFromSync, - ).Scan(&streamPos) + ) + if err != nil { + return + } + err = selectStmt.QueryRowContext(ctx).Scan(&streamPos) return } @@ -349,13 +359,19 @@ func (s *outputRoomEventsStatements) selectEarlyEvents( func (s *outputRoomEventsStatements) selectEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { + var returnEvents []types.StreamEvent stmt := common.TxStmt(txn, s.selectEventsStmt) - rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) - if err != nil { - return nil, err + for _, eventID := range eventIDs { + rows, err := stmt.QueryContext(ctx, eventID) + if err != nil { + return nil, err + } + if streamEvents, err := rowsToStreamEvents(rows); err == nil { + returnEvents = append(returnEvents, streamEvents...) + } + rows.Close() } - defer rows.Close() // nolint: errcheck - return rowsToStreamEvents(rows) + return returnEvents, nil } func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 1040466f0..f7075bd6f 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -25,13 +26,11 @@ import ( const outputRoomEventsTopologySchema = ` -- Stores output room events received from the roomserver. CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology ( - -- The event ID for the event. - event_id TEXT PRIMARY KEY, - -- The place of the event in the room's topology. This can usually be determined - -- from the event's depth. + event_id TEXT PRIMARY KEY, topological_position BIGINT NOT NULL, - -- The 'room_id' key for the event. - room_id TEXT NOT NULL + room_id TEXT NOT NULL, + + UNIQUE(topological_position, room_id) ); -- The topological order will be used in events selection and ordering -- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id); @@ -102,9 +101,10 @@ func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) { // insertEventInTopology inserts the given event in the room's topology, based // on the event's depth. func (s *outputRoomEventsTopologyStatements) insertEventInTopology( - ctx context.Context, event *gomatrixserverlib.Event, + ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.Event, ) (err error) { - _, err = s.insertEventInTopologyStmt.ExecContext( + stmt := common.TxStmt(txn, s.insertEventInTopologyStmt) + _, err = stmt.ExecContext( ctx, event.EventID(), event.Depth(), event.RoomID(), ) return @@ -114,16 +114,17 @@ func (s *outputRoomEventsTopologyStatements) insertEventInTopology( // given range in a given room's topological order. // Returns an empty slice if no events match the given range. func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( - ctx context.Context, roomID string, fromPos, toPos types.StreamPosition, + ctx context.Context, txn *sql.Tx, roomID string, + fromPos, toPos types.StreamPosition, limit int, chronologicalOrder bool, ) (eventIDs []string, err error) { // Decide on the selection's order according to whether chronological order // is requested or not. var stmt *sql.Stmt if chronologicalOrder { - stmt = s.selectEventIDsInRangeASCStmt + stmt = common.TxStmt(txn, s.selectEventIDsInRangeASCStmt) } else { - stmt = s.selectEventIDsInRangeDESCStmt + stmt = common.TxStmt(txn, s.selectEventIDsInRangeDESCStmt) } // Query the event IDs. @@ -150,26 +151,29 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( // selectPositionInTopology returns the position of a given event in the // topology of the room it belongs to. func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( - ctx context.Context, eventID string, + ctx context.Context, txn *sql.Tx, eventID string, ) (pos types.StreamPosition, err error) { - err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos) + stmt := common.TxStmt(txn, s.selectPositionInTopologyStmt) + err = stmt.QueryRowContext(ctx, eventID).Scan(&pos) return } func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (pos types.StreamPosition, err error) { - err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos) + stmt := common.TxStmt(txn, s.selectMaxPositionInTopologyStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&pos) return } // selectEventIDsFromPosition returns the IDs of all events that have a given // position in the topology of a given room. func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition( - ctx context.Context, roomID string, pos types.StreamPosition, + ctx context.Context, txn *sql.Tx, roomID string, pos types.StreamPosition, ) (eventIDs []string, err error) { // Query the event IDs. - rows, err := s.selectEventIDsFromPositionStmt.QueryContext(ctx, roomID, pos) + stmt := common.TxStmt(txn, s.selectEventIDsFromPositionStmt) + rows, err := stmt.QueryContext(ctx, roomID, pos) if err == sql.ErrNoRows { // If no event matched the request, return an empty slice. return []string{}, nil diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go new file mode 100644 index 000000000..260f7a95d --- /dev/null +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -0,0 +1,58 @@ +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/types" +) + +const streamIDTableSchema = ` +-- Global stream ID counter, used by other tables. +CREATE TABLE IF NOT EXISTS syncapi_stream_id ( + stream_name TEXT NOT NULL PRIMARY KEY, + stream_id INT DEFAULT 0, + + UNIQUE(stream_name) +); +INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0) + ON CONFLICT DO NOTHING; +` + +const increaseStreamIDStmt = "" + + "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" + +const selectStreamIDStmt = "" + + "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1" + +type streamIDStatements struct { + increaseStreamIDStmt *sql.Stmt + selectStreamIDStmt *sql.Stmt +} + +func (s *streamIDStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(streamIDTableSchema) + if err != nil { + return + } + if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil { + return + } + if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil { + return + } + return +} + +func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { + increaseStmt := common.TxStmt(txn, s.increaseStreamIDStmt) + selectStmt := common.TxStmt(txn, s.selectStreamIDStmt) + if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { + return + } + if err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil { + return + } + return +} diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index e0a38ac15..432b18a60 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -54,6 +54,7 @@ type stateDelta struct { type SyncServerDatasource struct { db *sql.DB common.PartitionOffsetStatements + streamID streamIDStatements accountData accountDataStatements events outputRoomEventsStatements roomstate currentRoomStateStatements @@ -86,16 +87,19 @@ func NewSyncServerDatasource(dataSourceName string) (*SyncServerDatasource, erro if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { return nil, err } - if err = d.accountData.prepare(d.db); err != nil { + if err = d.streamID.prepare(d.db); err != nil { return nil, err } - if err = d.events.prepare(d.db); err != nil { + if err = d.accountData.prepare(d.db, &d.streamID); err != nil { return nil, err } - if err := d.roomstate.prepare(d.db); err != nil { + if err = d.events.prepare(d.db, &d.streamID); err != nil { return nil, err } - if err := d.invites.prepare(d.db); err != nil { + if err := d.roomstate.prepare(d.db, &d.streamID); err != nil { + return nil, err + } + if err := d.invites.prepare(d.db, &d.streamID); err != nil { return nil, err } if err := d.topology.prepare(d.db); err != nil { @@ -129,22 +133,22 @@ func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([ return d.StreamEventsToEvents(nil, streamEvents), nil } -func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, ev *gomatrixserverlib.Event) error { +func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.Event) error { // If the event is already known as a backward extremity, don't consider // it as such anymore now that we have it. - isBackwardExtremity, err := d.backwardExtremities.isBackwardExtremity(ctx, ev.RoomID(), ev.EventID()) + isBackwardExtremity, err := d.backwardExtremities.isBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()) if err != nil { return err } if isBackwardExtremity { - if err = d.backwardExtremities.deleteBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil { + if err = d.backwardExtremities.deleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { return err } } // Check if we have all of the event's previous events. If an event is // missing, add it to the room's backward extremities. - prevEvents, err := d.events.selectEvents(ctx, nil, ev.PrevEventIDs()) + prevEvents, err := d.events.selectEvents(ctx, txn, ev.PrevEventIDs()) if err != nil { return err } @@ -159,7 +163,7 @@ func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, ev // If the event is missing, consider it a backward extremity. if !found { - if err = d.backwardExtremities.insertsBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil { + if err = d.backwardExtremities.insertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { return err } } @@ -184,20 +188,24 @@ func (d *SyncServerDatasource) WriteEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ) if err != nil { + fmt.Println("d.events.insertEvent:", err) return err } pduPosition = pos - if err = d.topology.insertEventInTopology(ctx, ev); err != nil { + if err = d.topology.insertEventInTopology(ctx, txn, ev); err != nil { + fmt.Println("d.topology.insertEventInTopology:", err) return err } - if err = d.handleBackwardExtremities(ctx, ev); err != nil { + if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { + fmt.Println("d.handleBackwardExtremities:", err) return err } if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { // Nothing to do, the event may have just been a message event. + fmt.Println("nothing to do") return nil } @@ -292,7 +300,7 @@ func (d *SyncServerDatasource) GetEventsInRange( // Select the event IDs from the defined range. var eIDs []string eIDs, err = d.topology.selectEventIDsInRange( - ctx, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering, + ctx, nil, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering, ) if err != nil { return @@ -336,7 +344,7 @@ func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.Paginati func (d *SyncServerDatasource) BackwardExtremitiesForRoom( ctx context.Context, roomID string, ) (backwardExtremities []string, err error) { - return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, roomID) + return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, nil, roomID) } // MaxTopologicalPosition returns the highest topological position for a given @@ -344,7 +352,7 @@ func (d *SyncServerDatasource) BackwardExtremitiesForRoom( func (d *SyncServerDatasource) MaxTopologicalPosition( ctx context.Context, roomID string, ) (types.StreamPosition, error) { - return d.topology.selectMaxPositionInTopology(ctx, roomID) + return d.topology.selectMaxPositionInTopology(ctx, nil, roomID) } // EventsAtTopologicalPosition returns all of the events matching a given @@ -352,7 +360,7 @@ func (d *SyncServerDatasource) MaxTopologicalPosition( func (d *SyncServerDatasource) EventsAtTopologicalPosition( ctx context.Context, roomID string, pos types.StreamPosition, ) ([]types.StreamEvent, error) { - eIDs, err := d.topology.selectEventIDsFromPosition(ctx, roomID, pos) + eIDs, err := d.topology.selectEventIDsFromPosition(ctx, nil, roomID, pos) if err != nil { return nil, err } @@ -363,7 +371,7 @@ func (d *SyncServerDatasource) EventsAtTopologicalPosition( func (d *SyncServerDatasource) EventPositionInTopology( ctx context.Context, eventID string, ) (types.StreamPosition, error) { - return d.topology.selectPositionInTopology(ctx, eventID) + return d.topology.selectPositionInTopology(ctx, nil, eventID) } // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. @@ -627,7 +635,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // Retrieve the backward topology position, i.e. the position of the // oldest event in the room's topology. var backwardTopologyPos types.StreamPosition - backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) + backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) if err != nil { return nil, types.PaginationToken{}, []string{}, err } @@ -712,7 +720,13 @@ func (d *SyncServerDatasource) GetAccountDataInRange( func (d *SyncServerDatasource) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, ) (types.StreamPosition, error) { - return d.accountData.insertAccountData(ctx, userID, roomID, dataType) + txn, err := d.db.BeginTx(ctx, nil) + if err != nil { + return types.StreamPosition(0), err + } + var succeeded bool + defer common.EndTransaction(txn, &succeeded) + return d.accountData.insertAccountData(ctx, txn, userID, roomID, dataType) } // AddInviteEvent stores a new invite event for a user. @@ -781,11 +795,11 @@ func (d *SyncServerDatasource) addInvitesToResponse( // Retrieve the backward topology position, i.e. the position of the // oldest event in the room's topology. func (d *SyncServerDatasource) getBackwardTopologyPos( - ctx context.Context, + ctx context.Context, txn *sql.Tx, events []types.StreamEvent, ) (pos types.StreamPosition) { if len(events) > 0 { - pos, _ = d.topology.selectPositionInTopology(ctx, events[0].EventID()) + pos, _ = d.topology.selectPositionInTopology(ctx, txn, events[0].EventID()) } if pos-1 <= 0 { pos = types.StreamPosition(1) @@ -824,7 +838,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( } recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back - backwardTopologyPos := d.getBackwardTopologyPos(ctx, recentStreamEvents) + backwardTopologyPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents) switch delta.membership { case gomatrixserverlib.Join: