From 677dc175d0f91e57d29e8d06aad67df8687a099f Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 18 Feb 2020 18:02:53 +0000 Subject: [PATCH] sqlite: Fix invite table by using the global stream pos rather than one specific to invites If we don't use the global, clients don't get notified about any invites because the position is too low. --- roomserver/storage/sqlite3/invite_table.go | 18 ++++++++++++----- syncapi/storage/sqlite3/invites_table.go | 23 +++++++--------------- syncapi/storage/sqlite3/syncserver.go | 23 ++++++++++------------ 3 files changed, 30 insertions(+), 34 deletions(-) diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index 5a0f0bf79..afa770e4a 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -52,16 +52,18 @@ const selectInviteActiveForUserInRoomSQL = "" + // However the matrix protocol doesn't give us a way to reliably identify the // invites that were retired, so we are forced to retire all of them. const updateInviteRetiredSQL = ` - UPDATE roomserver_invites SET retired = TRUE - WHERE room_nid = $1 AND target_nid = $2 AND NOT retired; - SELECT invite_event_id FROM roomserver_invites - WHERE rowid = last_insert_rowid(); + UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired +` + +const selectInvitesAboutToRetireSQL = ` +SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired ` type inviteStatements struct { insertInviteEventStmt *sql.Stmt selectInviteActiveForUserInRoomStmt *sql.Stmt updateInviteRetiredStmt *sql.Stmt + selectInvitesAboutToRetireStmt *sql.Stmt } func (s *inviteStatements) prepare(db *sql.DB) (err error) { @@ -74,6 +76,7 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) { {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, + {&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL}, }.prepare(db) } @@ -102,7 +105,8 @@ func (s *inviteStatements) updateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { - stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) + // gather all the event IDs we will retire + stmt := txn.Stmt(s.selectInvitesAboutToRetireStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) if err != nil { return nil, err @@ -115,6 +119,10 @@ func (s *inviteStatements) updateInviteRetired( } eventIDs = append(eventIDs, inviteEventID) } + + // now retire the invites + stmt = txn.Stmt(s.updateInviteRetiredStmt) + _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) return } diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index f420e4b27..baf8871bd 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -26,7 +26,7 @@ import ( const inviteEventsSchema = ` CREATE TABLE IF NOT EXISTS syncapi_invite_events ( - id INTEGER PRIMARY KEY AUTOINCREMENT, + id INTEGER PRIMARY KEY, event_id TEXT NOT NULL, room_id TEXT NOT NULL, target_user_id TEXT NOT NULL, @@ -39,8 +39,8 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events const insertInviteEventSQL = "" + "INSERT INTO syncapi_invite_events" + - " (room_id, event_id, target_user_id, event_json)" + - " VALUES ($1, $2, $3, $4)" + " (id, room_id, event_id, target_user_id, event_json)" + + " VALUES ($1, $2, $3, $4, $5)" const deleteInviteEventSQL = "" + "DELETE FROM syncapi_invite_events WHERE event_id = $1" @@ -83,25 +83,16 @@ func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatement } func (s *inviteEventsStatements) insertInviteEvent( - ctx context.Context, inviteEvent gomatrixserverlib.Event, -) (streamPos types.StreamPosition, err error) { - var res sql.Result - res, err = s.insertInviteEventStmt.ExecContext( + ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.Event, streamPos types.StreamPosition, +) (err error) { + _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( ctx, + streamPos, inviteEvent.RoomID(), inviteEvent.EventID(), *inviteEvent.StateKey(), inviteEvent.JSON(), ) - if err != nil { - return - } - var rowID int64 - rowID, err = res.LastInsertId() - if err != nil { - return - } - streamPos = types.StreamPosition(rowID) return } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 8cfc1884f..4714de764 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -193,24 +193,20 @@ func (d *SyncServerDatasource) WriteEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ) if err != nil { - fmt.Println("d.events.insertEvent:", err) return err } pduPosition = pos if err = d.topology.insertEventInTopology(ctx, txn, ev); err != nil { - fmt.Println("d.topology.insertEventInTopology:", err) return err } if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { - fmt.Println("d.handleBackwardExtremities:", err) return err } if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { // Nothing to do, the event may have just been a message event. - fmt.Println("nothing to do") return nil } @@ -625,18 +621,15 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( if err != nil { return } - fmt.Println("Joined rooms:", joinedRoomIDs) stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request // Build up a /sync response. Add joined rooms. for _, roomID := range joinedRoomIDs { - fmt.Println("WE'RE ON", roomID) var stateEvents []gomatrixserverlib.Event stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart) if err != nil { - fmt.Println("d.roomstate.selectCurrentState:", err) return } //fmt.Println("State events:", stateEvents) @@ -648,7 +641,6 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( numRecentEventsPerRoom, true, true, ) if err != nil { - fmt.Println("d.events.selectRecentEvents:", err) return } //fmt.Println("Recent stream events:", recentStreamEvents) @@ -658,10 +650,9 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( var backwardTopologyPos types.StreamPosition backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) if err != nil { - fmt.Println("d.topology.selectPositionInTopology:", err) return nil, types.PaginationToken{}, []string{}, err } - fmt.Println("Backward topology position:", backwardTopologyPos) + if backwardTopologyPos-1 <= 0 { backwardTopologyPos = types.StreamPosition(1) } else { @@ -683,7 +674,6 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( } if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { - fmt.Println("d.addInvitesToResponse:", err) return } @@ -764,8 +754,15 @@ func (d *SyncServerDatasource) UpsertAccountData( // Returns an error if there was a problem communicating with the database. func (d *SyncServerDatasource) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.Event, -) (types.StreamPosition, error) { - return d.invites.insertInviteEvent(ctx, inviteEvent) +) (streamPos types.StreamPosition, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + streamPos, err = d.streamID.nextStreamID(ctx, txn) + if err != nil { + return err + } + return d.invites.insertInviteEvent(ctx, txn, inviteEvent, streamPos) + }) + return } // RetireInviteEvent removes an old invite event from the database.