Update SQLite InsertPreviousEvent properly

This commit is contained in:
Neil Alexander 2020-09-21 13:24:29 +01:00
parent c73c239142
commit 2604b0b8ec
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -18,6 +18,8 @@ package sqlite3
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
@ -43,8 +45,11 @@ const insertPreviousEventSQL = `
INSERT OR REPLACE INTO roomserver_previous_events
(previous_event_id, previous_reference_sha256, event_nids)
VALUES ($1, $2, $3)
ON CONFLICT DO UPDATE
SET event_nids = event_nids || ',' || $3
`
const selectPreviousEventNIDsSQL = `
SELECT event_nids FROM roomserver_previous_events
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
`
// Check if the event is referenced by another event in the table.
@ -57,6 +62,7 @@ const selectPreviousEventExistsSQL = `
type previousEventStatements struct {
db *sql.DB
insertPreviousEventStmt *sql.Stmt
selectPreviousEventNIDsStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt
}
@ -71,6 +77,7 @@ func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
return s, shared.StatementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
}.Prepare(db)
}
@ -82,9 +89,21 @@ func (s *previousEventStatements) InsertPreviousEvent(
previousEventReferenceSHA256 []byte,
eventNID types.EventNID,
) error {
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
_, err := stmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
var eventNIDs string
selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs)
if err != sql.ErrNoRows {
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
}
for _, nid := range strings.Split(eventNIDs, ",") {
if nid == fmt.Sprintf("%d", eventNID) {
return nil
}
}
eventNIDs = fmt.Sprintf("%s,%d", eventNIDs, eventNID)
insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
_, err = insertStmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, eventNIDs,
)
return err
}