mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-22 14:21:55 -06:00
Use TransactionWriter in other component SQLite (#1209)
* Use TransactionWriter on other component SQLites * Fix sync API tests * Fix panic in media API * Fix a couple of transactions * Fix wrong query, add some logging output * Add debug logging into StoreEvent * Adjust InsertRoomNID * Update logging
This commit is contained in:
parent
1d72ce8b7a
commit
b6bc132485
|
@ -21,6 +21,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
@ -65,6 +66,8 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type eventsStatements struct {
|
type eventsStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
selectEventsByApplicationServiceIDStmt *sql.Stmt
|
selectEventsByApplicationServiceIDStmt *sql.Stmt
|
||||||
countEventsByApplicationServiceIDStmt *sql.Stmt
|
countEventsByApplicationServiceIDStmt *sql.Stmt
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
|
@ -73,6 +76,8 @@ type eventsStatements struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventsStatements) prepare(db *sql.DB) (err error) {
|
func (s *eventsStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
|
s.writer = sqlutil.NewTransactionWriter()
|
||||||
_, err = db.Exec(appserviceEventsSchema)
|
_, err = db.Exec(appserviceEventsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -217,13 +222,15 @@ func (s *eventsStatements) insertEvent(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.insertEventStmt.ExecContext(
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
ctx,
|
_, err := s.insertEventStmt.ExecContext(
|
||||||
appServiceID,
|
ctx,
|
||||||
eventJSON,
|
appServiceID,
|
||||||
-1, // No transaction ID yet
|
eventJSON,
|
||||||
)
|
-1, // No transaction ID yet
|
||||||
return
|
)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
|
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
|
||||||
|
@ -234,8 +241,10 @@ func (s *eventsStatements) updateTxnIDForEvents(
|
||||||
appserviceID string,
|
appserviceID string,
|
||||||
maxID, txnID int,
|
maxID, txnID int,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
return
|
_, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
|
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
|
||||||
|
@ -244,6 +253,8 @@ func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
|
||||||
appserviceID string,
|
appserviceID string,
|
||||||
eventTableID int,
|
eventTableID int,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
return
|
_, err := s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,8 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const txnIDSchema = `
|
const txnIDSchema = `
|
||||||
|
@ -35,10 +37,14 @@ const selectTxnIDSQL = `
|
||||||
`
|
`
|
||||||
|
|
||||||
type txnStatements struct {
|
type txnStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
selectTxnIDStmt *sql.Stmt
|
selectTxnIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *txnStatements) prepare(db *sql.DB) (err error) {
|
func (s *txnStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
|
s.writer = sqlutil.NewTransactionWriter()
|
||||||
_, err = db.Exec(txnIDSchema)
|
_, err = db.Exec(txnIDSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -55,6 +61,9 @@ func (s *txnStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *txnStatements) selectTxnID(
|
func (s *txnStatements) selectTxnID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
) (txnID int, err error) {
|
) (txnID int, err error) {
|
||||||
err = s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
|
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
|
err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,6 +68,7 @@ const selectBulkStateContentWildSQL = "" +
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||||
|
@ -76,7 +77,8 @@ type currentRoomStateStatements struct {
|
||||||
|
|
||||||
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
||||||
s := ¤tRoomStateStatements{
|
s := ¤tRoomStateStatements{
|
||||||
db: db,
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
}
|
}
|
||||||
_, err := db.Exec(currentRoomStateSchema)
|
_, err := db.Exec(currentRoomStateSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -125,9 +127,11 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
|
||||||
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
|
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
|
||||||
ctx context.Context, txn *sql.Tx, eventID string,
|
ctx context.Context, txn *sql.Tx, eventID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err := stmt.ExecContext(ctx, eventID)
|
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
|
||||||
return err
|
_, err := stmt.ExecContext(ctx, eventID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) UpsertRoomState(
|
func (s *currentRoomStateStatements) UpsertRoomState(
|
||||||
|
@ -140,18 +144,20 @@ func (s *currentRoomStateStatements) UpsertRoomState(
|
||||||
}
|
}
|
||||||
|
|
||||||
// upsert state event
|
// upsert state event
|
||||||
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err = stmt.ExecContext(
|
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
|
||||||
ctx,
|
_, err = stmt.ExecContext(
|
||||||
event.RoomID(),
|
ctx,
|
||||||
event.EventID(),
|
event.RoomID(),
|
||||||
event.Type(),
|
event.EventID(),
|
||||||
event.Sender(),
|
event.Type(),
|
||||||
*event.StateKey(),
|
event.Sender(),
|
||||||
headeredJSON,
|
*event.StateKey(),
|
||||||
contentVal,
|
headeredJSON,
|
||||||
)
|
contentVal,
|
||||||
return err
|
)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
|
func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
@ -60,11 +61,16 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user
|
||||||
`
|
`
|
||||||
|
|
||||||
type mediaStatements struct {
|
type mediaStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertMediaStmt *sql.Stmt
|
insertMediaStmt *sql.Stmt
|
||||||
selectMediaStmt *sql.Stmt
|
selectMediaStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
|
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
|
s.writer = sqlutil.NewTransactionWriter()
|
||||||
|
|
||||||
_, err = db.Exec(mediaSchema)
|
_, err = db.Exec(mediaSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -80,18 +86,21 @@ func (s *mediaStatements) insertMedia(
|
||||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||||
) error {
|
) error {
|
||||||
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||||
_, err := s.insertMediaStmt.ExecContext(
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
ctx,
|
stmt := sqlutil.TxStmt(txn, s.insertMediaStmt)
|
||||||
mediaMetadata.MediaID,
|
_, err := stmt.ExecContext(
|
||||||
mediaMetadata.Origin,
|
ctx,
|
||||||
mediaMetadata.ContentType,
|
mediaMetadata.MediaID,
|
||||||
mediaMetadata.FileSizeBytes,
|
mediaMetadata.Origin,
|
||||||
mediaMetadata.CreationTimestamp,
|
mediaMetadata.ContentType,
|
||||||
mediaMetadata.UploadName,
|
mediaMetadata.FileSizeBytes,
|
||||||
mediaMetadata.Base64Hash,
|
mediaMetadata.CreationTimestamp,
|
||||||
mediaMetadata.UserID,
|
mediaMetadata.UploadName,
|
||||||
)
|
mediaMetadata.Base64Hash,
|
||||||
return err
|
mediaMetadata.UserID,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mediaStatements) selectMedia(
|
func (s *mediaStatements) selectMedia(
|
||||||
|
|
|
@ -18,6 +18,7 @@ package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -65,13 +66,13 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
// Store the event.
|
// Store the event.
|
||||||
roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
|
roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
|
||||||
}
|
}
|
||||||
// if storing this event results in it being redacted then do so.
|
// if storing this event results in it being redacted then do so.
|
||||||
if redactedEventID == event.EventID() {
|
if redactedEventID == event.EventID() {
|
||||||
r, rerr := eventutil.RedactEvent(redactionEvent, &event)
|
r, rerr := eventutil.RedactEvent(redactionEvent, &event)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
return "", rerr
|
return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr)
|
||||||
}
|
}
|
||||||
event = *r
|
event = *r
|
||||||
}
|
}
|
||||||
|
@ -93,7 +94,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
// Lets calculate one.
|
// Lets calculate one.
|
||||||
err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event)
|
err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return "", fmt.Errorf("r.calculateAndSetState: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,7 +106,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
input.SendAsServer, // send as server
|
input.SendAsServer, // send as server
|
||||||
input.TransactionID, // transaction ID
|
input.TransactionID, // transaction ID
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return
|
return "", fmt.Errorf("r.updateLatestEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processing this event resulted in an event (which may not be the one we're processing)
|
// processing this event resulted in an event (which may not be the one we're processing)
|
||||||
|
@ -123,7 +124,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return "", fmt.Errorf("r.WriteOutputEvents: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -362,7 +363,7 @@ func (d *Database) StoreEvent(
|
||||||
ctx, txn, txnAndSessionID.TransactionID,
|
ctx, txn, txnAndSessionID.TransactionID,
|
||||||
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
|
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -377,15 +378,15 @@ func (d *Database) StoreEvent(
|
||||||
// room.
|
// room.
|
||||||
var roomVersion gomatrixserverlib.RoomVersion
|
var roomVersion gomatrixserverlib.RoomVersion
|
||||||
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
|
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
|
||||||
return err
|
return fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil {
|
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil {
|
||||||
return err
|
return fmt.Errorf("d.assignRoomNID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
|
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
|
||||||
return err
|
return fmt.Errorf("d.assignEventTypeNID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
eventStateKey := event.StateKey()
|
eventStateKey := event.StateKey()
|
||||||
|
@ -393,7 +394,7 @@ func (d *Database) StoreEvent(
|
||||||
// Otherwise set the numeric ID for the state_key to 0.
|
// Otherwise set the numeric ID for the state_key to 0.
|
||||||
if eventStateKey != nil {
|
if eventStateKey != nil {
|
||||||
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
|
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
|
||||||
return err
|
return fmt.Errorf("d.assignStateKeyNID: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -411,17 +412,20 @@ func (d *Database) StoreEvent(
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// We've already inserted the event so select the numeric event ID
|
// We've already inserted the event so select the numeric event ID
|
||||||
eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID())
|
eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("d.EventsTable.SelectEvent: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("d.EventsTable.InsertEvent: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
|
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
|
||||||
return err
|
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
|
||||||
}
|
}
|
||||||
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event)
|
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event)
|
||||||
return err
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, types.StateAtEvent{}, nil, "", err
|
return 0, types.StateAtEvent{}, nil, "", err
|
||||||
|
|
|
@ -287,7 +287,8 @@ func (s *eventStatements) UpdateEventState(
|
||||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
|
stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,7 +71,8 @@ func (s *publishedStatements) UpsertRoomPublished(
|
||||||
ctx context.Context, roomID string, published bool,
|
ctx context.Context, roomID string, published bool,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
_, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
|
stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, roomID, published)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,7 +87,8 @@ func (s *roomAliasesStatements) InsertRoomAlias(
|
||||||
ctx context.Context, alias string, roomID string, creatorUserID string,
|
ctx context.Context, alias string, roomID string, creatorUserID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
_, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
|
stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -139,7 +140,8 @@ func (s *roomAliasesStatements) DeleteRoomAlias(
|
||||||
ctx context.Context, alias string,
|
ctx context.Context, alias string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
_, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias)
|
stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, alias)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
|
@ -98,17 +99,23 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
func (s *roomStatements) InsertRoomNID(
|
func (s *roomStatements) InsertRoomNID(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
) (types.RoomNID, error) {
|
) (roomNID types.RoomNID, err error) {
|
||||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
|
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
|
||||||
_, err := insertStmt.ExecContext(ctx, roomID, roomVersion)
|
_, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
|
||||||
return err
|
if err != nil {
|
||||||
|
return fmt.Errorf("insertStmt.ExecContext: %w", err)
|
||||||
|
}
|
||||||
|
roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("s.SelectRoomNID: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if err == nil {
|
if err != nil {
|
||||||
return s.SelectRoomNID(ctx, txn, roomID)
|
|
||||||
} else {
|
|
||||||
return types.RoomNID(0), err
|
return types.RoomNID(0), err
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomNID(
|
func (s *roomStatements) SelectRoomNID(
|
||||||
|
|
|
@ -63,12 +63,14 @@ const upsertServerKeysSQL = "" +
|
||||||
|
|
||||||
type serverKeyStatements struct {
|
type serverKeyStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
bulkSelectServerKeysStmt *sql.Stmt
|
bulkSelectServerKeysStmt *sql.Stmt
|
||||||
upsertServerKeysStmt *sql.Stmt
|
upsertServerKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
|
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
|
s.writer = sqlutil.NewTransactionWriter()
|
||||||
_, err = db.Exec(serverKeysSchema)
|
_, err = db.Exec(serverKeysSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -136,16 +138,19 @@ func (s *serverKeyStatements) upsertServerKeys(
|
||||||
request gomatrixserverlib.PublicKeyLookupRequest,
|
request gomatrixserverlib.PublicKeyLookupRequest,
|
||||||
key gomatrixserverlib.PublicKeyLookupResult,
|
key gomatrixserverlib.PublicKeyLookupResult,
|
||||||
) error {
|
) error {
|
||||||
_, err := s.upsertServerKeysStmt.ExecContext(
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
ctx,
|
stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
|
||||||
string(request.ServerName),
|
_, err := stmt.ExecContext(
|
||||||
string(request.KeyID),
|
ctx,
|
||||||
nameAndKeyID(request),
|
string(request.ServerName),
|
||||||
key.ValidUntilTS,
|
string(request.KeyID),
|
||||||
key.ExpiredTS,
|
nameAndKeyID(request),
|
||||||
key.Key.Encode(),
|
key.ValidUntilTS,
|
||||||
)
|
key.ExpiredTS,
|
||||||
return err
|
key.Key.Encode(),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
|
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
|
||||||
|
|
|
@ -281,16 +281,16 @@ func (d *Database) WriteEvent(
|
||||||
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
|
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
|
||||||
}
|
}
|
||||||
pduPosition = pos
|
pduPosition = pos
|
||||||
|
|
||||||
if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
|
if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
|
||||||
return err
|
return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
|
if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
|
||||||
return err
|
return fmt.Errorf("d.handleBackwardExtremities: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
|
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
|
||||||
|
@ -313,7 +313,7 @@ func (d *Database) updateRoomState(
|
||||||
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
|
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
|
||||||
for _, eventID := range removedEventIDs {
|
for _, eventID := range removedEventIDs {
|
||||||
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
|
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
|
||||||
return err
|
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -326,13 +326,13 @@ func (d *Database) updateRoomState(
|
||||||
if event.Type() == "m.room.member" {
|
if event.Type() == "m.room.member" {
|
||||||
value, err := event.Membership()
|
value, err := event.Membership()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("event.Membership: %w", err)
|
||||||
}
|
}
|
||||||
membership = &value
|
membership = &value
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
|
if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
|
||||||
return err
|
return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -49,6 +50,8 @@ const selectMaxAccountDataIDSQL = "" +
|
||||||
"SELECT MAX(id) FROM syncapi_account_data_type"
|
"SELECT MAX(id) FROM syncapi_account_data_type"
|
||||||
|
|
||||||
type accountDataStatements struct {
|
type accountDataStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
insertAccountDataStmt *sql.Stmt
|
insertAccountDataStmt *sql.Stmt
|
||||||
selectMaxAccountDataIDStmt *sql.Stmt
|
selectMaxAccountDataIDStmt *sql.Stmt
|
||||||
|
@ -57,6 +60,8 @@ type accountDataStatements struct {
|
||||||
|
|
||||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
|
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
|
||||||
s := &accountDataStatements{
|
s := &accountDataStatements{
|
||||||
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(accountDataSchema)
|
_, err := db.Exec(accountDataSchema)
|
||||||
|
@ -79,12 +84,15 @@ func (s *accountDataStatements) InsertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
userID, roomID, dataType string,
|
userID, roomID, dataType string,
|
||||||
) (pos types.StreamPosition, err error) {
|
) (pos types.StreamPosition, err error) {
|
||||||
pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
if err != nil {
|
var err error
|
||||||
return
|
pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||||
}
|
if err != nil {
|
||||||
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
|
return err
|
||||||
return
|
}
|
||||||
|
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) SelectAccountDataInRange(
|
func (s *accountDataStatements) SelectAccountDataInRange(
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -47,13 +48,18 @@ const deleteBackwardExtremitySQL = "" +
|
||||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
|
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
|
||||||
|
|
||||||
type backwardExtremitiesStatements struct {
|
type backwardExtremitiesStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertBackwardExtremityStmt *sql.Stmt
|
insertBackwardExtremityStmt *sql.Stmt
|
||||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||||
deleteBackwardExtremityStmt *sql.Stmt
|
deleteBackwardExtremityStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
||||||
s := &backwardExtremitiesStatements{}
|
s := &backwardExtremitiesStatements{
|
||||||
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
|
}
|
||||||
_, err := db.Exec(backwardExtremitiesSchema)
|
_, err := db.Exec(backwardExtremitiesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -73,8 +79,10 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
|
||||||
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
return
|
_, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
||||||
|
@ -102,6 +110,8 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
||||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
return
|
_, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -84,6 +84,8 @@ const selectEventsWithEventIDsSQL = "" +
|
||||||
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
|
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
|
@ -95,6 +97,8 @@ type currentRoomStateStatements struct {
|
||||||
|
|
||||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
|
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
|
||||||
s := ¤tRoomStateStatements{
|
s := ¤tRoomStateStatements{
|
||||||
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(currentRoomStateSchema)
|
_, err := db.Exec(currentRoomStateSchema)
|
||||||
|
@ -196,9 +200,11 @@ func (s *currentRoomStateStatements) SelectCurrentState(
|
||||||
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
|
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
|
||||||
ctx context.Context, txn *sql.Tx, eventID string,
|
ctx context.Context, txn *sql.Tx, eventID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err := stmt.ExecContext(ctx, eventID)
|
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
|
||||||
return err
|
_, err := stmt.ExecContext(ctx, eventID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) UpsertRoomState(
|
func (s *currentRoomStateStatements) UpsertRoomState(
|
||||||
|
@ -219,20 +225,22 @@ func (s *currentRoomStateStatements) UpsertRoomState(
|
||||||
}
|
}
|
||||||
|
|
||||||
// upsert state event
|
// upsert state event
|
||||||
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err = stmt.ExecContext(
|
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
|
||||||
ctx,
|
_, err := stmt.ExecContext(
|
||||||
event.RoomID(),
|
ctx,
|
||||||
event.EventID(),
|
event.RoomID(),
|
||||||
event.Type(),
|
event.EventID(),
|
||||||
event.Sender(),
|
event.Type(),
|
||||||
containsURL,
|
event.Sender(),
|
||||||
*event.StateKey(),
|
containsURL,
|
||||||
headeredJSON,
|
*event.StateKey(),
|
||||||
membership,
|
headeredJSON,
|
||||||
addedAt,
|
membership,
|
||||||
)
|
addedAt,
|
||||||
return err
|
)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
|
func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
@ -50,6 +51,8 @@ const insertFilterSQL = "" +
|
||||||
"INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
|
"INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
|
||||||
|
|
||||||
type filterStatements struct {
|
type filterStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
selectFilterStmt *sql.Stmt
|
selectFilterStmt *sql.Stmt
|
||||||
selectFilterIDByContentStmt *sql.Stmt
|
selectFilterIDByContentStmt *sql.Stmt
|
||||||
insertFilterStmt *sql.Stmt
|
insertFilterStmt *sql.Stmt
|
||||||
|
@ -60,7 +63,10 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s := &filterStatements{}
|
s := &filterStatements{
|
||||||
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
|
}
|
||||||
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -108,30 +114,33 @@ func (s *filterStatements) InsertFilter(
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if filter already exists in the database using its localpart and content
|
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
//
|
// Check if filter already exists in the database using its localpart and content
|
||||||
// This can result in a race condition when two clients try to insert the
|
//
|
||||||
// same filter and localpart at the same time, however this is not a
|
// This can result in a race condition when two clients try to insert the
|
||||||
// problem as both calls will result in the same filterID
|
// same filter and localpart at the same time, however this is not a
|
||||||
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
// problem as both calls will result in the same filterID
|
||||||
localpart, filterJSON).Scan(&existingFilterID)
|
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
||||||
if err != nil && err != sql.ErrNoRows {
|
localpart, filterJSON).Scan(&existingFilterID)
|
||||||
return "", err
|
if err != nil && err != sql.ErrNoRows {
|
||||||
}
|
return err
|
||||||
// If it does, return the existing ID
|
}
|
||||||
if existingFilterID != "" {
|
// If it does, return the existing ID
|
||||||
return existingFilterID, err
|
if existingFilterID != "" {
|
||||||
}
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Otherwise insert the filter and return the new ID
|
// Otherwise insert the filter and return the new ID
|
||||||
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
|
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
rowid, err := res.LastInsertId()
|
rowid, err := res.LastInsertId()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
filterID = fmt.Sprintf("%d", rowid)
|
filterID = fmt.Sprintf("%d", rowid)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,6 +58,8 @@ const selectMaxInviteIDSQL = "" +
|
||||||
"SELECT MAX(id) FROM syncapi_invite_events"
|
"SELECT MAX(id) FROM syncapi_invite_events"
|
||||||
|
|
||||||
type inviteEventsStatements struct {
|
type inviteEventsStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
insertInviteEventStmt *sql.Stmt
|
insertInviteEventStmt *sql.Stmt
|
||||||
selectInviteEventsInRangeStmt *sql.Stmt
|
selectInviteEventsInRangeStmt *sql.Stmt
|
||||||
|
@ -67,6 +69,8 @@ type inviteEventsStatements struct {
|
||||||
|
|
||||||
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
|
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
|
||||||
s := &inviteEventsStatements{
|
s := &inviteEventsStatements{
|
||||||
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(inviteEventsSchema)
|
_, err := db.Exec(inviteEventsSchema)
|
||||||
|
@ -91,36 +95,45 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
|
||||||
func (s *inviteEventsStatements) InsertInviteEvent(
|
func (s *inviteEventsStatements) InsertInviteEvent(
|
||||||
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
|
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
|
||||||
) (streamPos types.StreamPosition, err error) {
|
) (streamPos types.StreamPosition, err error) {
|
||||||
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
if err != nil {
|
var err error
|
||||||
return
|
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||||
}
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
var headeredJSON []byte
|
var headeredJSON []byte
|
||||||
headeredJSON, err = json.Marshal(inviteEvent)
|
headeredJSON, err = json.Marshal(inviteEvent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
|
_, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
streamPos,
|
streamPos,
|
||||||
inviteEvent.RoomID(),
|
inviteEvent.RoomID(),
|
||||||
inviteEvent.EventID(),
|
inviteEvent.EventID(),
|
||||||
*inviteEvent.StateKey(),
|
*inviteEvent.StateKey(),
|
||||||
headeredJSON,
|
headeredJSON,
|
||||||
)
|
)
|
||||||
|
return err
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteEventsStatements) DeleteInviteEvent(
|
func (s *inviteEventsStatements) DeleteInviteEvent(
|
||||||
ctx context.Context, inviteEventID string,
|
ctx context.Context, inviteEventID string,
|
||||||
) (types.StreamPosition, error) {
|
) (types.StreamPosition, error) {
|
||||||
streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil)
|
var streamPos types.StreamPosition
|
||||||
if err != nil {
|
err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
return streamPos, err
|
var err error
|
||||||
}
|
streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil)
|
||||||
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
return streamPos, err
|
return streamPos, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -104,6 +104,8 @@ const selectStateInRangeSQL = "" +
|
||||||
" LIMIT $8" // limit
|
" LIMIT $8" // limit
|
||||||
|
|
||||||
type outputRoomEventsStatements struct {
|
type outputRoomEventsStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventsStmt *sql.Stmt
|
selectEventsStmt *sql.Stmt
|
||||||
|
@ -117,6 +119,8 @@ type outputRoomEventsStatements struct {
|
||||||
|
|
||||||
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
|
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
|
||||||
s := &outputRoomEventsStatements{
|
s := &outputRoomEventsStatements{
|
||||||
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(outputRoomEventsSchema)
|
_, err := db.Exec(outputRoomEventsSchema)
|
||||||
|
@ -155,8 +159,10 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
return err
|
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
|
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
|
||||||
|
@ -267,7 +273,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
|
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
|
||||||
transactionID *api.TransactionID, excludeFromSync bool,
|
transactionID *api.TransactionID, excludeFromSync bool,
|
||||||
) (streamPos types.StreamPosition, err error) {
|
) (types.StreamPosition, error) {
|
||||||
var txnID *string
|
var txnID *string
|
||||||
var sessionID *int64
|
var sessionID *int64
|
||||||
if transactionID != nil {
|
if transactionID != nil {
|
||||||
|
@ -284,43 +290,47 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
var headeredJSON []byte
|
var headeredJSON []byte
|
||||||
headeredJSON, err = json.Marshal(event)
|
headeredJSON, err := json.Marshal(event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return 0, err
|
||||||
}
|
|
||||||
|
|
||||||
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
addStateJSON, err := json.Marshal(addState)
|
addStateJSON, err := json.Marshal(addState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return 0, err
|
||||||
}
|
}
|
||||||
removeStateJSON, err := json.Marshal(removeState)
|
removeStateJSON, err := json.Marshal(removeState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
var streamPos types.StreamPosition
|
||||||
_, err = insertStmt.ExecContext(
|
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
ctx,
|
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||||
streamPos,
|
if err != nil {
|
||||||
event.RoomID(),
|
return err
|
||||||
event.EventID(),
|
}
|
||||||
headeredJSON,
|
|
||||||
event.Type(),
|
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||||
event.Sender(),
|
_, ierr := insertStmt.ExecContext(
|
||||||
containsURL,
|
ctx,
|
||||||
string(addStateJSON),
|
streamPos,
|
||||||
string(removeStateJSON),
|
event.RoomID(),
|
||||||
sessionID,
|
event.EventID(),
|
||||||
txnID,
|
headeredJSON,
|
||||||
excludeFromSync,
|
event.Type(),
|
||||||
excludeFromSync,
|
event.Sender(),
|
||||||
)
|
containsURL,
|
||||||
return
|
string(addStateJSON),
|
||||||
|
string(removeStateJSON),
|
||||||
|
sessionID,
|
||||||
|
txnID,
|
||||||
|
excludeFromSync,
|
||||||
|
excludeFromSync,
|
||||||
|
)
|
||||||
|
return ierr
|
||||||
|
})
|
||||||
|
return streamPos, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *outputRoomEventsStatements) SelectRecentEvents(
|
func (s *outputRoomEventsStatements) SelectRecentEvents(
|
||||||
|
|
|
@ -66,6 +66,8 @@ const selectMaxPositionInTopologySQL = "" +
|
||||||
" WHERE room_id = $1 ORDER BY stream_position DESC"
|
" WHERE room_id = $1 ORDER BY stream_position DESC"
|
||||||
|
|
||||||
type outputRoomEventsTopologyStatements struct {
|
type outputRoomEventsTopologyStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertEventInTopologyStmt *sql.Stmt
|
insertEventInTopologyStmt *sql.Stmt
|
||||||
selectEventIDsInRangeASCStmt *sql.Stmt
|
selectEventIDsInRangeASCStmt *sql.Stmt
|
||||||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||||
|
@ -74,7 +76,10 @@ type outputRoomEventsTopologyStatements struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
|
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
|
||||||
s := &outputRoomEventsTopologyStatements{}
|
s := &outputRoomEventsTopologyStatements{
|
||||||
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
|
}
|
||||||
_, err := db.Exec(outputRoomEventsTopologySchema)
|
_, err := db.Exec(outputRoomEventsTopologySchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -102,11 +107,13 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
|
||||||
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
|
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
|
||||||
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
|
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err = stmt.ExecContext(
|
stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
|
||||||
ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
|
_, err := stmt.ExecContext(
|
||||||
)
|
ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
|
||||||
return
|
)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
||||||
|
|
|
@ -72,13 +72,18 @@ const deleteSendToDeviceMessagesSQL = `
|
||||||
`
|
`
|
||||||
|
|
||||||
type sendToDeviceStatements struct {
|
type sendToDeviceStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertSendToDeviceMessageStmt *sql.Stmt
|
insertSendToDeviceMessageStmt *sql.Stmt
|
||||||
selectSendToDeviceMessagesStmt *sql.Stmt
|
selectSendToDeviceMessagesStmt *sql.Stmt
|
||||||
countSendToDeviceMessagesStmt *sql.Stmt
|
countSendToDeviceMessagesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
s := &sendToDeviceStatements{}
|
s := &sendToDeviceStatements{
|
||||||
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
|
}
|
||||||
_, err := db.Exec(sendToDeviceSchema)
|
_, err := db.Exec(sendToDeviceSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -98,8 +103,10 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
|
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
return
|
_, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
||||||
|
@ -156,8 +163,10 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
||||||
for k, v := range nids {
|
for k, v := range nids {
|
||||||
params[k+1] = v
|
params[k+1] = v
|
||||||
}
|
}
|
||||||
_, err = txn.ExecContext(ctx, query, params...)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
return
|
_, err := txn.ExecContext(ctx, query, params...)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||||
|
@ -168,6 +177,8 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||||
for k, v := range nids {
|
for k, v := range nids {
|
||||||
params[k] = v
|
params[k] = v
|
||||||
}
|
}
|
||||||
_, err = txn.ExecContext(ctx, query, params...)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
return
|
_, err := txn.ExecContext(ctx, query, params...)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,11 +27,15 @@ const selectStreamIDStmt = "" +
|
||||||
"SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
|
"SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
|
||||||
|
|
||||||
type streamIDStatements struct {
|
type streamIDStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
increaseStreamIDStmt *sql.Stmt
|
increaseStreamIDStmt *sql.Stmt
|
||||||
selectStreamIDStmt *sql.Stmt
|
selectStreamIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
|
s.writer = sqlutil.NewTransactionWriter()
|
||||||
_, err = db.Exec(streamIDTableSchema)
|
_, err = db.Exec(streamIDTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -48,11 +52,14 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||||
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
|
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
|
||||||
if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
|
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
return
|
if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil {
|
||||||
}
|
return ierr
|
||||||
if err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
|
}
|
||||||
return
|
if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
|
||||||
}
|
return serr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -52,7 +53,13 @@ func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.Head
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustCreateDatabase(t *testing.T) storage.Database {
|
func MustCreateDatabase(t *testing.T) storage.Database {
|
||||||
db, err := sqlite3.NewDatabase("file::memory:")
|
dbname := fmt.Sprintf("test_%s.db", t.Name())
|
||||||
|
if _, err := os.Stat(dbname); err == nil {
|
||||||
|
if err = os.Remove(dbname); err != nil {
|
||||||
|
t.Fatalf("tried to delete stale test database but failed: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
db, err := sqlite3.NewDatabase(fmt.Sprintf("file:%s", dbname))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewSyncServerDatasource returned %s", err)
|
t.Fatalf("NewSyncServerDatasource returned %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountDataSchema = `
|
const accountDataSchema = `
|
||||||
|
@ -48,12 +50,16 @@ const selectAccountDataByTypeSQL = "" +
|
||||||
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||||
|
|
||||||
type accountDataStatements struct {
|
type accountDataStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertAccountDataStmt *sql.Stmt
|
insertAccountDataStmt *sql.Stmt
|
||||||
selectAccountDataStmt *sql.Stmt
|
selectAccountDataStmt *sql.Stmt
|
||||||
selectAccountDataByTypeStmt *sql.Stmt
|
selectAccountDataByTypeStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
|
s.writer = sqlutil.NewTransactionWriter()
|
||||||
_, err = db.Exec(accountDataSchema)
|
_, err = db.Exec(accountDataSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -73,8 +79,10 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *accountDataStatements) insertAccountData(
|
func (s *accountDataStatements) insertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
return
|
_, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) selectAccountData(
|
func (s *accountDataStatements) selectAccountData(
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
@ -57,6 +58,8 @@ const selectNewNumericLocalpartSQL = "" +
|
||||||
// TODO: Update password
|
// TODO: Update password
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
selectPasswordHashStmt *sql.Stmt
|
selectPasswordHashStmt *sql.Stmt
|
||||||
|
@ -65,6 +68,8 @@ type accountsStatements struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
s.db = db
|
||||||
|
s.writer = sqlutil.NewTransactionWriter()
|
||||||
_, err = db.Exec(accountsSchema)
|
_, err = db.Exec(accountsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -94,12 +99,15 @@ func (s *accountsStatements) insertAccount(
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := s.insertAccountStmt
|
stmt := s.insertAccountStmt
|
||||||
|
|
||||||
var err error
|
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
if appserviceID == "" {
|
var err error
|
||||||
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
if appserviceID == "" {
|
||||||
} else {
|
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||||
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
} else {
|
||||||
}
|
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
const profilesSchema = `
|
||||||
|
@ -46,6 +47,8 @@ const setDisplayNameSQL = "" +
|
||||||
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
type profilesStatements struct {
|
type profilesStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertProfileStmt *sql.Stmt
|
insertProfileStmt *sql.Stmt
|
||||||
selectProfileByLocalpartStmt *sql.Stmt
|
selectProfileByLocalpartStmt *sql.Stmt
|
||||||
setAvatarURLStmt *sql.Stmt
|
setAvatarURLStmt *sql.Stmt
|
||||||
|
@ -53,6 +56,8 @@ type profilesStatements struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
|
s.writer = sqlutil.NewTransactionWriter()
|
||||||
_, err = db.Exec(profilesSchema)
|
_, err = db.Exec(profilesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -75,8 +80,10 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *profilesStatements) insertProfile(
|
func (s *profilesStatements) insertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
return
|
_, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) selectProfileByLocalpart(
|
func (s *profilesStatements) selectProfileByLocalpart(
|
||||||
|
|
|
@ -53,6 +53,8 @@ const deleteThreePIDSQL = "" +
|
||||||
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||||
|
|
||||||
type threepidStatements struct {
|
type threepidStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
selectLocalpartForThreePIDStmt *sql.Stmt
|
selectLocalpartForThreePIDStmt *sql.Stmt
|
||||||
selectThreePIDsForLocalpartStmt *sql.Stmt
|
selectThreePIDsForLocalpartStmt *sql.Stmt
|
||||||
insertThreePIDStmt *sql.Stmt
|
insertThreePIDStmt *sql.Stmt
|
||||||
|
@ -60,6 +62,8 @@ type threepidStatements struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
|
s.writer = sqlutil.NewTransactionWriter()
|
||||||
_, err = db.Exec(threepidSchema)
|
_, err = db.Exec(threepidSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -118,13 +122,18 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
|
||||||
func (s *threepidStatements) insertThreePID(
|
func (s *threepidStatements) insertThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||||
return
|
_, err := stmt.ExecContext(ctx, threepid, medium, localpart)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) deleteThreePID(
|
func (s *threepidStatements) deleteThreePID(
|
||||||
ctx context.Context, threepid string, medium string) (err error) {
|
ctx context.Context, threepid string, medium string) (err error) {
|
||||||
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
return
|
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, threepid, medium)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,6 +74,7 @@ const deleteDevicesSQL = "" +
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
selectDevicesCountStmt *sql.Stmt
|
selectDevicesCountStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
|
@ -87,6 +88,7 @@ type devicesStatements struct {
|
||||||
|
|
||||||
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
|
s.writer = sqlutil.NewTransactionWriter()
|
||||||
_, err = db.Exec(devicesSchema)
|
_, err = db.Exec(devicesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -128,13 +130,19 @@ func (s *devicesStatements) insertDevice(
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
var sessionID int64
|
var sessionID int64
|
||||||
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
|
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
|
||||||
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
|
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
||||||
return nil, err
|
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
|
||||||
}
|
return err
|
||||||
sessionID++
|
}
|
||||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
|
sessionID++
|
||||||
|
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &api.Device{
|
return &api.Device{
|
||||||
|
@ -148,9 +156,11 @@ func (s *devicesStatements) insertDevice(
|
||||||
func (s *devicesStatements) deleteDevice(
|
func (s *devicesStatements) deleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||||
return err
|
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevices(
|
func (s *devicesStatements) deleteDevices(
|
||||||
|
@ -161,31 +171,37 @@ func (s *devicesStatements) deleteDevices(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, prep)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
params := make([]interface{}, len(devices)+1)
|
stmt := sqlutil.TxStmt(txn, prep)
|
||||||
params[0] = localpart
|
params := make([]interface{}, len(devices)+1)
|
||||||
for i, v := range devices {
|
params[0] = localpart
|
||||||
params[i+1] = v
|
for i, v := range devices {
|
||||||
}
|
params[i+1] = v
|
||||||
params = append(params, params...)
|
}
|
||||||
_, err = stmt.ExecContext(ctx, params...)
|
params = append(params, params...)
|
||||||
return err
|
_, err = stmt.ExecContext(ctx, params...)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err := stmt.ExecContext(ctx, localpart)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
return err
|
_, err := stmt.ExecContext(ctx, localpart)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) updateDeviceName(
|
func (s *devicesStatements) updateDeviceName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
return err
|
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDeviceByToken(
|
func (s *devicesStatements) selectDeviceByToken(
|
||||||
|
|
Loading…
Reference in a new issue