Update SQLite InsertPreviousEvent properly

This commit is contained in:
Neil Alexander 2020-09-21 13:24:29 +01:00
parent c73c239142
commit 2604b0b8ec
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -18,6 +18,8 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
@ -43,8 +45,11 @@ const insertPreviousEventSQL = `
INSERT OR REPLACE INTO roomserver_previous_events INSERT OR REPLACE INTO roomserver_previous_events
(previous_event_id, previous_reference_sha256, event_nids) (previous_event_id, previous_reference_sha256, event_nids)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT DO UPDATE `
SET event_nids = event_nids || ',' || $3
const selectPreviousEventNIDsSQL = `
SELECT event_nids FROM roomserver_previous_events
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
` `
// Check if the event is referenced by another event in the table. // Check if the event is referenced by another event in the table.
@ -57,6 +62,7 @@ const selectPreviousEventExistsSQL = `
type previousEventStatements struct { type previousEventStatements struct {
db *sql.DB db *sql.DB
insertPreviousEventStmt *sql.Stmt insertPreviousEventStmt *sql.Stmt
selectPreviousEventNIDsStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt
} }
@ -71,6 +77,7 @@ func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
return s, shared.StatementList{ return s, shared.StatementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -82,9 +89,21 @@ func (s *previousEventStatements) InsertPreviousEvent(
previousEventReferenceSHA256 []byte, previousEventReferenceSHA256 []byte,
eventNID types.EventNID, eventNID types.EventNID,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) var eventNIDs string
_, err := stmt.ExecContext( selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs)
if err != sql.ErrNoRows {
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
}
for _, nid := range strings.Split(eventNIDs, ",") {
if nid == fmt.Sprintf("%d", eventNID) {
return nil
}
}
eventNIDs = fmt.Sprintf("%s,%d", eventNIDs, eventNID)
insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
_, err = insertStmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, eventNIDs,
) )
return err return err
} }