mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-23 14:53:10 -06:00
Update SQLite InsertPreviousEvent properly
This commit is contained in:
parent
c73c239142
commit
2604b0b8ec
|
|
@ -18,6 +18,8 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
|
|
@ -43,8 +45,11 @@ const insertPreviousEventSQL = `
|
|||
INSERT OR REPLACE INTO roomserver_previous_events
|
||||
(previous_event_id, previous_reference_sha256, event_nids)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT DO UPDATE
|
||||
SET event_nids = event_nids || ',' || $3
|
||||
`
|
||||
|
||||
const selectPreviousEventNIDsSQL = `
|
||||
SELECT event_nids FROM roomserver_previous_events
|
||||
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
`
|
||||
|
||||
// Check if the event is referenced by another event in the table.
|
||||
|
|
@ -57,6 +62,7 @@ const selectPreviousEventExistsSQL = `
|
|||
type previousEventStatements struct {
|
||||
db *sql.DB
|
||||
insertPreviousEventStmt *sql.Stmt
|
||||
selectPreviousEventNIDsStmt *sql.Stmt
|
||||
selectPreviousEventExistsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
|
|
@ -71,6 +77,7 @@ func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
|
|||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
|
||||
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
|
||||
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
|
@ -82,9 +89,21 @@ func (s *previousEventStatements) InsertPreviousEvent(
|
|||
previousEventReferenceSHA256 []byte,
|
||||
eventNID types.EventNID,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
|
||||
var eventNIDs string
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
|
||||
err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs)
|
||||
if err != sql.ErrNoRows {
|
||||
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
|
||||
}
|
||||
for _, nid := range strings.Split(eventNIDs, ",") {
|
||||
if nid == fmt.Sprintf("%d", eventNID) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
eventNIDs = fmt.Sprintf("%s,%d", eventNIDs, eventNID)
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
||||
_, err = insertStmt.ExecContext(
|
||||
ctx, previousEventID, previousEventReferenceSHA256, eventNIDs,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue