From 2604b0b8ecacf776b348a85a7221cb95ae60fef8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 21 Sep 2020 13:24:29 +0100 Subject: [PATCH] Update SQLite InsertPreviousEvent properly --- .../storage/sqlite3/previous_events_table.go | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 887c625c0..1b672912c 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -18,6 +18,8 @@ package sqlite3 import ( "context" "database/sql" + "fmt" + "strings" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -43,8 +45,11 @@ const insertPreviousEventSQL = ` INSERT OR REPLACE INTO roomserver_previous_events (previous_event_id, previous_reference_sha256, event_nids) VALUES ($1, $2, $3) - ON CONFLICT DO UPDATE - SET event_nids = event_nids || ',' || $3 +` + +const selectPreviousEventNIDsSQL = ` + SELECT event_nids FROM roomserver_previous_events + WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 ` // Check if the event is referenced by another event in the table. @@ -57,6 +62,7 @@ const selectPreviousEventExistsSQL = ` type previousEventStatements struct { db *sql.DB insertPreviousEventStmt *sql.Stmt + selectPreviousEventNIDsStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt } @@ -71,6 +77,7 @@ func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, + {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, }.Prepare(db) } @@ -82,9 +89,21 @@ func (s *previousEventStatements) InsertPreviousEvent( previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) - _, err := stmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + var eventNIDs string + selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) + err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs) + if err != sql.ErrNoRows { + return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err) + } + for _, nid := range strings.Split(eventNIDs, ",") { + if nid == fmt.Sprintf("%d", eventNID) { + return nil + } + } + eventNIDs = fmt.Sprintf("%s,%d", eventNIDs, eventNID) + insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) + _, err = insertStmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, eventNIDs, ) return err }