mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-18 04:13:10 -06:00
Not sure if roomserver is better or worse now
This commit is contained in:
parent
fca23c356a
commit
0360f07110
|
|
@ -114,15 +114,6 @@ func Setup(
|
||||||
return SendMembership(req, accountDB, device, vars["roomID"], vars["membership"], cfg, queryAPI, asAPI, producer)
|
return SendMembership(req, accountDB, device, vars["roomID"], vars["membership"], cfg, queryAPI, asAPI, producer)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
r0mux.Handle("/rooms/{roomID}/send/{eventType}",
|
|
||||||
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
|
||||||
vars, err := common.URLDecodeMapValues(mux.Vars(req))
|
|
||||||
if err != nil {
|
|
||||||
return util.ErrorResponse(err)
|
|
||||||
}
|
|
||||||
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, queryAPI, producer, nil)
|
|
||||||
}),
|
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
|
||||||
r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
|
r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
|
||||||
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
||||||
vars, err := common.URLDecodeMapValues(mux.Vars(req))
|
vars, err := common.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
|
|
||||||
|
|
@ -124,8 +124,10 @@ func processRoomEvent(
|
||||||
if stateAtEvent.BeforeStateSnapshotNID == 0 {
|
if stateAtEvent.BeforeStateSnapshotNID == 0 {
|
||||||
// We haven't calculated a state for this event yet.
|
// We haven't calculated a state for this event yet.
|
||||||
// Lets calculate one.
|
// Lets calculate one.
|
||||||
|
fmt.Println("We don't have a state snapshot NID yet")
|
||||||
err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event)
|
err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Println("Failed to calculateAndSetState:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -151,6 +153,7 @@ func calculateAndSetState(
|
||||||
) error {
|
) error {
|
||||||
var err error
|
var err error
|
||||||
if input.HasState {
|
if input.HasState {
|
||||||
|
fmt.Println("We have state")
|
||||||
// We've been told what the state at the event is so we don't need to calculate it.
|
// We've been told what the state at the event is so we don't need to calculate it.
|
||||||
// Check that those state events are in the database and store the state.
|
// Check that those state events are in the database and store the state.
|
||||||
var entries []types.StateEntry
|
var entries []types.StateEntry
|
||||||
|
|
@ -162,11 +165,13 @@ func calculateAndSetState(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
fmt.Println("We don't have state")
|
||||||
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
|
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
|
||||||
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil {
|
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fmt.Println("Then set state", stateAtEvent)
|
||||||
return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
|
return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -556,11 +556,15 @@ func CalculateAndStoreStateBeforeEvent(
|
||||||
prevEventIDs[i] = prevEventRefs[i].EventID
|
prevEventIDs[i] = prevEventRefs[i].EventID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("Previous event IDs:", prevEventIDs)
|
||||||
|
|
||||||
prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs)
|
prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("Previous states:", prevStates)
|
||||||
|
|
||||||
// The state before this event will be the state after the events that came before it.
|
// The state before this event will be the state after the events that came before it.
|
||||||
return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates)
|
return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates)
|
||||||
}
|
}
|
||||||
|
|
@ -574,7 +578,6 @@ func CalculateAndStoreStateAfterEvents(
|
||||||
prevStates []types.StateAtEvent,
|
prevStates []types.StateAtEvent,
|
||||||
) (types.StateSnapshotNID, error) {
|
) (types.StateSnapshotNID, error) {
|
||||||
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
|
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
|
||||||
|
|
||||||
if len(prevStates) == 0 {
|
if len(prevStates) == 0 {
|
||||||
// 2) There weren't any prev_events for this event so the state is
|
// 2) There weren't any prev_events for this event so the state is
|
||||||
// empty.
|
// empty.
|
||||||
|
|
@ -592,6 +595,7 @@ func CalculateAndStoreStateAfterEvents(
|
||||||
metrics.algorithm = "no_change"
|
metrics.algorithm = "no_change"
|
||||||
return metrics.stop(prevState.BeforeStateSnapshotNID, nil)
|
return metrics.stop(prevState.BeforeStateSnapshotNID, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The previous event was a state event so we need to store a copy
|
// The previous event was a state event so we need to store a copy
|
||||||
// of the previous state updated with that event.
|
// of the previous state updated with that event.
|
||||||
stateBlockNIDLists, err := db.StateBlockNIDs(
|
stateBlockNIDLists, err := db.StateBlockNIDs(
|
||||||
|
|
@ -614,6 +618,7 @@ func CalculateAndStoreStateAfterEvents(
|
||||||
// So fall through to calculateAndStoreStateAfterManyEvents
|
// So fall through to calculateAndStoreStateAfterManyEvents
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("Falling through to calculateAndStoreStateAfterManyEvents")
|
||||||
return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics)
|
return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -262,7 +262,9 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
|
||||||
func (s *eventStatements) updateEventState(
|
func (s *eventStatements) updateEventState(
|
||||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
|
fmt.Println("updateEventState eventNID", eventNID, "stateNID", stateNID)
|
||||||
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID))
|
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID))
|
||||||
|
fmt.Println("Errors?", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
|
@ -46,11 +47,13 @@ const bulkSelectEventJSONSQL = `
|
||||||
`
|
`
|
||||||
|
|
||||||
type eventJSONStatements struct {
|
type eventJSONStatements struct {
|
||||||
|
db *sql.DB
|
||||||
insertEventJSONStmt *sql.Stmt
|
insertEventJSONStmt *sql.Stmt
|
||||||
bulkSelectEventJSONStmt *sql.Stmt
|
bulkSelectEventJSONStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
|
func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
_, err = db.Exec(eventJSONSchema)
|
_, err = db.Exec(eventJSONSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -76,7 +79,19 @@ type eventJSONPair struct {
|
||||||
func (s *eventJSONStatements) bulkSelectEventJSON(
|
func (s *eventJSONStatements) bulkSelectEventJSON(
|
||||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
) ([]eventJSONPair, error) {
|
) ([]eventJSONPair, error) {
|
||||||
rows, err := common.TxStmt(txn, s.bulkSelectEventJSONStmt).QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
///////////////
|
||||||
|
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||||
|
for k, v := range eventNIDs {
|
||||||
|
iEventNIDs[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectEventJSON s.bulkSelectEventJSONStmt.QueryContext:", err)
|
fmt.Println("bulkSelectEventJSON s.bulkSelectEventJSONStmt.QueryContext:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
@ -66,6 +66,7 @@ const bulkSelectEventStateKeySQL = `
|
||||||
`
|
`
|
||||||
|
|
||||||
type eventStateKeyStatements struct {
|
type eventStateKeyStatements struct {
|
||||||
|
db *sql.DB
|
||||||
insertEventStateKeyNIDStmt *sql.Stmt
|
insertEventStateKeyNIDStmt *sql.Stmt
|
||||||
insertEventStateKeyNIDResultStmt *sql.Stmt
|
insertEventStateKeyNIDResultStmt *sql.Stmt
|
||||||
selectEventStateKeyNIDStmt *sql.Stmt
|
selectEventStateKeyNIDStmt *sql.Stmt
|
||||||
|
|
@ -74,6 +75,7 @@ type eventStateKeyStatements struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
|
func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
_, err = db.Exec(eventStateKeysSchema)
|
_, err = db.Exec(eventStateKeysSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -110,18 +112,32 @@ func (s *eventStateKeyStatements) selectEventStateKeyNID(
|
||||||
) (types.EventStateKeyNID, error) {
|
) (types.EventStateKeyNID, error) {
|
||||||
var eventStateKeyNID int64
|
var eventStateKeyNID int64
|
||||||
stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt)
|
stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt)
|
||||||
|
fmt.Println("selectEventStateKeyNID for", eventStateKey)
|
||||||
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
|
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("selectEventStateKeyNID stmt.QueryRowContext:", err)
|
fmt.Println("selectEventStateKeyNID stmt.QueryRowContext:", err)
|
||||||
}
|
}
|
||||||
|
fmt.Println("selectEventStateKeyNID returns", eventStateKeyNID)
|
||||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
||||||
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||||
) (map[string]types.EventStateKeyNID, error) {
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
rows, err := common.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt).QueryContext(
|
///////////////
|
||||||
ctx, sqliteInStr(pq.StringArray(eventStateKeys)),
|
iEventStateKeys := make([]interface{}, len(eventStateKeys))
|
||||||
|
for k, v := range eventStateKeys {
|
||||||
|
iEventStateKeys[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
rows, err := common.TxStmt(txn, selectPrep).QueryContext(
|
||||||
|
ctx, iEventStateKeys...,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectEventStateKeyNID s.bulkSelectEventStateKeyNIDStmt.QueryContext:", err)
|
fmt.Println("bulkSelectEventStateKeyNID s.bulkSelectEventStateKeyNIDStmt.QueryContext:", err)
|
||||||
|
|
@ -144,11 +160,23 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
||||||
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
||||||
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||||
) (map[types.EventStateKeyNID]string, error) {
|
) (map[types.EventStateKeyNID]string, error) {
|
||||||
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
|
///////////////
|
||||||
|
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
|
||||||
|
for k, v := range eventStateKeyNIDs {
|
||||||
|
iEventStateKeyNIDs[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
/*nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
|
||||||
for i := range eventStateKeyNIDs {
|
for i := range eventStateKeyNIDs {
|
||||||
nIDs[i] = int64(eventStateKeyNIDs[i])
|
nIDs[i] = int64(eventStateKeyNIDs[i])
|
||||||
}
|
}*/
|
||||||
rows, err := common.TxStmt(txn, s.bulkSelectEventStateKeyStmt).QueryContext(ctx, nIDs)
|
rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventStateKeyNIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectEventStateKey s.bulkSelectEventStateKeyStmt.QueryContext:", err)
|
fmt.Println("bulkSelectEventStateKey s.bulkSelectEventStateKeyStmt.QueryContext:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,8 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
@ -74,6 +74,7 @@ const bulkSelectEventTypeNIDSQL = `
|
||||||
`
|
`
|
||||||
|
|
||||||
type eventTypeStatements struct {
|
type eventTypeStatements struct {
|
||||||
|
db *sql.DB
|
||||||
insertEventTypeNIDStmt *sql.Stmt
|
insertEventTypeNIDStmt *sql.Stmt
|
||||||
insertEventTypeNIDResultStmt *sql.Stmt
|
insertEventTypeNIDResultStmt *sql.Stmt
|
||||||
selectEventTypeNIDStmt *sql.Stmt
|
selectEventTypeNIDStmt *sql.Stmt
|
||||||
|
|
@ -81,6 +82,7 @@ type eventTypeStatements struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
|
func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
_, err = db.Exec(eventTypesSchema)
|
_, err = db.Exec(eventTypesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -119,8 +121,20 @@ func (s *eventTypeStatements) selectEventTypeNID(
|
||||||
func (s *eventTypeStatements) bulkSelectEventTypeNID(
|
func (s *eventTypeStatements) bulkSelectEventTypeNID(
|
||||||
ctx context.Context, tx *sql.Tx, eventTypes []string,
|
ctx context.Context, tx *sql.Tx, eventTypes []string,
|
||||||
) (map[string]types.EventTypeNID, error) {
|
) (map[string]types.EventTypeNID, error) {
|
||||||
selectStmt := common.TxStmt(tx, s.bulkSelectEventTypeNIDStmt)
|
///////////////
|
||||||
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventTypes)))
|
iEventTypes := make([]interface{}, len(eventTypes))
|
||||||
|
for k, v := range eventTypes {
|
||||||
|
iEventTypes[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", queryVariadic(len(iEventTypes)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
selectStmt := common.TxStmt(tx, selectPrep)
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, iEventTypes...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
|
@ -96,6 +97,7 @@ const selectMaxEventDepthSQL = "" +
|
||||||
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
|
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
|
||||||
|
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
|
db *sql.DB
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
insertEventResultStmt *sql.Stmt
|
insertEventResultStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
|
|
@ -113,6 +115,7 @@ type eventStatements struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
_, err = db.Exec(eventsSchema)
|
_, err = db.Exec(eventsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -157,7 +160,14 @@ func (s *eventStatements) insertEvent(
|
||||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||||
); err == nil {
|
); err == nil {
|
||||||
err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
|
err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("insertEvent HAS FAILED!", err)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Println("insertEvent HAS GONE WRONG!", err)
|
||||||
|
}
|
||||||
|
fmt.Println("Event NID:", eventNID)
|
||||||
|
fmt.Println("State snapshot NID:", stateNID)
|
||||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -176,8 +186,20 @@ func (s *eventStatements) selectEvent(
|
||||||
func (s *eventStatements) bulkSelectStateEventByID(
|
func (s *eventStatements) bulkSelectStateEventByID(
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
selectStmt := common.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
|
///////////////
|
||||||
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs)))
|
iEventIDs := make([]interface{}, len(eventIDs))
|
||||||
|
for k, v := range eventIDs {
|
||||||
|
iEventIDs[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
selectStmt := common.TxStmt(txn, selectPrep)
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -217,8 +239,20 @@ func (s *eventStatements) bulkSelectStateEventByID(
|
||||||
func (s *eventStatements) bulkSelectStateAtEventByID(
|
func (s *eventStatements) bulkSelectStateAtEventByID(
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
) ([]types.StateAtEvent, error) {
|
) ([]types.StateAtEvent, error) {
|
||||||
selectStmt := common.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
|
///////////////
|
||||||
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs)))
|
iEventIDs := make([]interface{}, len(eventIDs))
|
||||||
|
for k, v := range eventIDs {
|
||||||
|
iEventIDs[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
selectStmt := common.TxStmt(txn, selectPrep)
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -296,8 +330,20 @@ func (s *eventStatements) selectEventID(
|
||||||
func (s *eventStatements) bulkSelectStateAtEventAndReference(
|
func (s *eventStatements) bulkSelectStateAtEventAndReference(
|
||||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
) ([]types.StateAtEventAndReference, error) {
|
) ([]types.StateAtEventAndReference, error) {
|
||||||
selectStmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt)
|
///////////////
|
||||||
rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs)))
|
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||||
|
for k, v := range eventNIDs {
|
||||||
|
iEventNIDs[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
selectStmt := common.TxStmt(txn, selectPrep)
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err)
|
fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -337,8 +383,20 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
|
||||||
func (s *eventStatements) bulkSelectEventReference(
|
func (s *eventStatements) bulkSelectEventReference(
|
||||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
) ([]gomatrixserverlib.EventReference, error) {
|
) ([]gomatrixserverlib.EventReference, error) {
|
||||||
selectStmt := common.TxStmt(txn, s.bulkSelectEventReferenceStmt)
|
///////////////
|
||||||
rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs)))
|
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||||
|
for k, v := range eventNIDs {
|
||||||
|
iEventNIDs[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
selectStmt := common.TxStmt(txn, selectPrep)
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectEventReference s.bulkSelectEventReferenceStmt.QueryContext:", err)
|
fmt.Println("bulkSelectEventReference s.bulkSelectEventReferenceStmt.QueryContext:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -361,8 +419,20 @@ func (s *eventStatements) bulkSelectEventReference(
|
||||||
|
|
||||||
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||||
func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||||
selectStmt := common.TxStmt(txn, s.bulkSelectEventIDStmt)
|
///////////////
|
||||||
rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs)))
|
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||||
|
for k, v := range eventNIDs {
|
||||||
|
iEventNIDs[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
selectStmt := common.TxStmt(txn, selectPrep)
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectEventID s.bulkSelectEventIDStmt.QueryContext:", err)
|
fmt.Println("bulkSelectEventID s.bulkSelectEventIDStmt.QueryContext:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -388,8 +458,20 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
|
||||||
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
||||||
// If an event ID is not in the database then it is omitted from the map.
|
// If an event ID is not in the database then it is omitted from the map.
|
||||||
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
|
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
|
||||||
selectStmt := common.TxStmt(txn, s.bulkSelectEventNIDStmt)
|
///////////////
|
||||||
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs)))
|
iEventIDs := make([]interface{}, len(eventIDs))
|
||||||
|
for k, v := range eventIDs {
|
||||||
|
iEventIDs[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
selectStmt := common.TxStmt(txn, selectPrep)
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectEventNID s.bulkSelectEventNIDStmt.QueryContext:", err)
|
fmt.Println("bulkSelectEventNID s.bulkSelectEventNIDStmt.QueryContext:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -41,11 +41,6 @@ const insertRoomNIDSQL = `
|
||||||
ON CONFLICT DO NOTHING;
|
ON CONFLICT DO NOTHING;
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertRoomNIDResultSQL = `
|
|
||||||
SELECT room_nid FROM roomserver_rooms
|
|
||||||
WHERE rowid = last_insert_rowid();
|
|
||||||
`
|
|
||||||
|
|
||||||
const selectRoomNIDSQL = "" +
|
const selectRoomNIDSQL = "" +
|
||||||
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
|
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
|
||||||
|
|
||||||
|
|
@ -60,7 +55,6 @@ const updateLatestEventNIDsSQL = "" +
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
insertRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
insertRoomNIDResultStmt *sql.Stmt
|
|
||||||
selectRoomNIDStmt *sql.Stmt
|
selectRoomNIDStmt *sql.Stmt
|
||||||
selectLatestEventNIDsStmt *sql.Stmt
|
selectLatestEventNIDsStmt *sql.Stmt
|
||||||
selectLatestEventNIDsForUpdateStmt *sql.Stmt
|
selectLatestEventNIDsForUpdateStmt *sql.Stmt
|
||||||
|
|
@ -74,7 +68,6 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||||
}
|
}
|
||||||
return statementList{
|
return statementList{
|
||||||
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
|
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
|
||||||
{&s.insertRoomNIDResultStmt, insertRoomNIDResultSQL},
|
|
||||||
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
|
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
|
||||||
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
|
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
|
||||||
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
|
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
|
||||||
|
|
@ -85,19 +78,14 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *roomStatements) insertRoomNID(
|
func (s *roomStatements) insertRoomNID(
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
) (types.RoomNID, error) {
|
) (types.RoomNID, error) {
|
||||||
var roomNID int64
|
|
||||||
var err error
|
var err error
|
||||||
insertStmt := common.TxStmt(txn, s.insertRoomNIDStmt)
|
insertStmt := common.TxStmt(txn, s.insertRoomNIDStmt)
|
||||||
resultStmt := common.TxStmt(txn, s.insertRoomNIDResultStmt)
|
|
||||||
if _, err = insertStmt.ExecContext(ctx, roomID); err == nil {
|
if _, err = insertStmt.ExecContext(ctx, roomID); err == nil {
|
||||||
err = resultStmt.QueryRowContext(ctx).Scan(&roomNID)
|
return s.selectRoomNID(ctx, txn, roomID)
|
||||||
if err != nil {
|
|
||||||
fmt.Println("insertRoomNID resultStmt.QueryRowContext:", err)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("insertRoomNID insertStmt.ExecContext:", err)
|
fmt.Println("insertRoomNID insertStmt.ExecContext:", err)
|
||||||
|
return types.RoomNID(0), err
|
||||||
}
|
}
|
||||||
return types.RoomNID(roomNID), err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) selectRoomNID(
|
func (s *roomStatements) selectRoomNID(
|
||||||
|
|
@ -106,9 +94,6 @@ func (s *roomStatements) selectRoomNID(
|
||||||
var roomNID int64
|
var roomNID int64
|
||||||
stmt := common.TxStmt(txn, s.selectRoomNIDStmt)
|
stmt := common.TxStmt(txn, s.selectRoomNIDStmt)
|
||||||
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
|
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
|
||||||
if err != nil {
|
|
||||||
fmt.Println("selectRoomNID stmt.QueryRowContext:", err)
|
|
||||||
}
|
|
||||||
return types.RoomNID(roomNID), err
|
return types.RoomNID(roomNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ package sqlite3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type statements struct {
|
type statements struct {
|
||||||
|
|
@ -58,3 +59,16 @@ func (s *statements) prepare(db *sql.DB) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hack of the century
|
||||||
|
func queryVariadic(count int) string {
|
||||||
|
str := "("
|
||||||
|
for i := 1; i <= count; i++ {
|
||||||
|
str += fmt.Sprintf("$%d", i)
|
||||||
|
if i < count {
|
||||||
|
str += ", "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
str += ")"
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"runtime/debug"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
|
@ -44,7 +46,7 @@ const insertStateDataSQL = "" +
|
||||||
const selectNextStateBlockNIDSQL = `
|
const selectNextStateBlockNIDSQL = `
|
||||||
SELECT COALESCE((
|
SELECT COALESCE((
|
||||||
SELECT seq+1 AS state_block_nid FROM sqlite_sequence
|
SELECT seq+1 AS state_block_nid FROM sqlite_sequence
|
||||||
WHERE name = 'roomserver_state_block'), 0
|
WHERE name = 'roomserver_state_block'), 1
|
||||||
) AS state_block_nid
|
) AS state_block_nid
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|
@ -73,6 +75,7 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
||||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||||
|
|
||||||
type stateBlockStatements struct {
|
type stateBlockStatements struct {
|
||||||
|
db *sql.DB
|
||||||
insertStateDataStmt *sql.Stmt
|
insertStateDataStmt *sql.Stmt
|
||||||
selectNextStateBlockNIDStmt *sql.Stmt
|
selectNextStateBlockNIDStmt *sql.Stmt
|
||||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||||
|
|
@ -80,6 +83,7 @@ type stateBlockStatements struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
|
func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
_, err = db.Exec(stateDataSchema)
|
_, err = db.Exec(stateDataSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -108,6 +112,7 @@ func (s *stateBlockStatements) bulkInsertStateData(
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkInsertStateData s.insertStateDataStmt.ExecContext:", err)
|
fmt.Println("bulkInsertStateData s.insertStateDataStmt.ExecContext:", err)
|
||||||
|
debug.PrintStack()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -127,12 +132,25 @@ func (s *stateBlockStatements) selectNextStateBlockNID(
|
||||||
func (s *stateBlockStatements) bulkSelectStateBlockEntries(
|
func (s *stateBlockStatements) bulkSelectStateBlockEntries(
|
||||||
ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
|
ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
|
||||||
) ([]types.StateEntryList, error) {
|
) ([]types.StateEntryList, error) {
|
||||||
|
///////////////
|
||||||
|
nids := make([]interface{}, len(stateBlockNIDs))
|
||||||
|
for k, v := range stateBlockNIDs {
|
||||||
|
nids[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", queryVariadic(len(nids)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
/*
|
||||||
nids := make([]int64, len(stateBlockNIDs))
|
nids := make([]int64, len(stateBlockNIDs))
|
||||||
for i := range stateBlockNIDs {
|
for i := range stateBlockNIDs {
|
||||||
nids[i] = int64(stateBlockNIDs[i])
|
nids[i] = int64(stateBlockNIDs[i])
|
||||||
}
|
}
|
||||||
selectStmt := common.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
|
*/
|
||||||
rows, err := selectStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids)))
|
selectStmt := common.TxStmt(txn, selectPrep)
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectStateBlockEntries s.bulkSelectStateBlockEntriesStmt.QueryContext:", err)
|
fmt.Println("bulkSelectStateBlockEntries s.bulkSelectStateBlockEntriesStmt.QueryContext:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
|
@ -51,12 +52,14 @@ const bulkSelectStateBlockNIDsSQL = "" +
|
||||||
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
|
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
|
||||||
|
|
||||||
type stateSnapshotStatements struct {
|
type stateSnapshotStatements struct {
|
||||||
|
db *sql.DB
|
||||||
insertStateStmt *sql.Stmt
|
insertStateStmt *sql.Stmt
|
||||||
insertStateResultStmt *sql.Stmt
|
insertStateResultStmt *sql.Stmt
|
||||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
|
func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
s.db = db
|
||||||
_, err = db.Exec(stateSnapshotSchema)
|
_, err = db.Exec(stateSnapshotSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -92,12 +95,25 @@ func (s *stateSnapshotStatements) insertState(
|
||||||
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
|
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
|
||||||
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
|
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
|
||||||
) ([]types.StateBlockNIDList, error) {
|
) ([]types.StateBlockNIDList, error) {
|
||||||
|
///////////////
|
||||||
|
nids := make([]interface{}, len(stateNIDs))
|
||||||
|
for k, v := range stateNIDs {
|
||||||
|
nids[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1)
|
||||||
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
/*
|
||||||
nids := make([]int64, len(stateNIDs))
|
nids := make([]int64, len(stateNIDs))
|
||||||
for i := range stateNIDs {
|
for i := range stateNIDs {
|
||||||
nids[i] = int64(stateNIDs[i])
|
nids[i] = int64(stateNIDs[i])
|
||||||
}
|
}
|
||||||
selectStmt := common.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
|
*/
|
||||||
rows, err := selectStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids)))
|
selectStmt := common.TxStmt(txn, selectPrep)
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err)
|
fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -44,18 +44,18 @@ func Open(dataSourceName string) (*Database, error) {
|
||||||
}
|
}
|
||||||
var cs string
|
var cs string
|
||||||
if uri.Opaque != "" { // file:filename.db
|
if uri.Opaque != "" { // file:filename.db
|
||||||
cs = fmt.Sprintf("%s?cache=shared&_busy_timeout=9999999", uri.Opaque)
|
cs = fmt.Sprintf("%s", uri.Opaque)
|
||||||
} else if uri.Path != "" { // file:///path/to/filename.db
|
} else if uri.Path != "" { // file:///path/to/filename.db
|
||||||
cs = fmt.Sprintf("%s?cache=shared&_busy_timeout=9999999", uri.Path)
|
cs = fmt.Sprintf("%s", uri.Path)
|
||||||
} else {
|
} else {
|
||||||
return nil, errors.New("no filename or path in connect string")
|
return nil, errors.New("no filename or path in connect string")
|
||||||
}
|
}
|
||||||
if d.db, err = sql.Open("sqlite3", cs); err != nil {
|
if d.db, err = sql.Open("sqlite3", cs); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
d.db.Exec("PRAGMA journal_mode=WAL;")
|
//d.db.Exec("PRAGMA journal_mode=WAL;")
|
||||||
//d.db.Exec("PRAGMA parser_trace = true;")
|
//d.db.Exec("PRAGMA parser_trace = true;")
|
||||||
//d.db.SetMaxOpenConns(1)
|
d.db.SetMaxOpenConns(1)
|
||||||
if err = d.statements.prepare(d.db); err != nil {
|
if err = d.statements.prepare(d.db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -76,32 +76,24 @@ func (d *Database) StoreEvent(
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
if txnAndSessionID != nil {
|
if txnAndSessionID != nil {
|
||||||
if err = d.statements.insertTransaction(
|
if err = d.statements.insertTransaction(
|
||||||
ctx, txnAndSessionID.TransactionID,
|
ctx, txn, txnAndSessionID.TransactionID,
|
||||||
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
|
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return 0, types.StateAtEvent{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = common.WithTransaction(d.db, func(tx *sql.Tx) error {
|
|
||||||
roomNID, err = d.assignRoomNID(ctx, tx, event.RoomID())
|
|
||||||
return err
|
return err
|
||||||
})
|
}
|
||||||
if err != nil {
|
|
||||||
return 0, types.StateAtEvent{}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = common.WithTransaction(d.db, func(tx *sql.Tx) error {
|
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID()); err != nil {
|
||||||
eventTypeNID, err = d.assignEventTypeNID(ctx, tx, event.Type())
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return 0, types.StateAtEvent{}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
eventStateKey := event.StateKey()
|
eventStateKey := event.StateKey()
|
||||||
// Assigned a numeric ID for the state_key if there is one present.
|
// Assigned a numeric ID for the state_key if there is one present.
|
||||||
// Otherwise set the numeric ID for the state_key to 0.
|
// Otherwise set the numeric ID for the state_key to 0.
|
||||||
|
|
@ -161,8 +153,8 @@ func (d *Database) assignRoomNID(
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// We don't have a numeric ID so insert one into the database.
|
// We don't have a numeric ID so insert one into the database.
|
||||||
roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
|
roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
|
||||||
if err == sql.ErrNoRows {
|
if err == nil {
|
||||||
// We raced with another insert so run the select again.
|
// Now get the numeric ID back out of the database
|
||||||
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
|
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -242,10 +234,11 @@ func (d *Database) Events(
|
||||||
) ([]types.Event, error) {
|
) ([]types.Event, error) {
|
||||||
var eventJSONs []eventJSONPair
|
var eventJSONs []eventJSONPair
|
||||||
var err error
|
var err error
|
||||||
results := make([]types.Event, len(eventJSONs))
|
results := make([]types.Event, len(eventNIDs))
|
||||||
common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs)
|
eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs)
|
||||||
if err != nil {
|
if err != nil || len(eventJSONs) == 0 {
|
||||||
|
fmt.Println("d.statements.bulkSelectEventJSON:", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for i, eventJSON := range eventJSONs {
|
for i, eventJSON := range eventJSONs {
|
||||||
|
|
@ -372,7 +365,7 @@ func (d *Database) GetTransactionEventID(
|
||||||
ctx context.Context, transactionID string,
|
ctx context.Context, transactionID string,
|
||||||
sessionID int64, userID string,
|
sessionID int64, userID string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID)
|
eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
const transactionsSchema = `
|
const transactionsSchema = `
|
||||||
|
|
@ -58,13 +60,14 @@ func (s *transactionStatements) prepare(db *sql.DB) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *transactionStatements) insertTransaction(
|
func (s *transactionStatements) insertTransaction(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
transactionID string,
|
transactionID string,
|
||||||
sessionID int64,
|
sessionID int64,
|
||||||
userID string,
|
userID string,
|
||||||
eventID string,
|
eventID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.insertTransactionStmt.ExecContext(
|
stmt := common.TxStmt(txn, s.insertTransactionStmt)
|
||||||
|
_, err = stmt.ExecContext(
|
||||||
ctx, transactionID, sessionID, userID, eventID,
|
ctx, transactionID, sessionID, userID, eventID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -74,12 +77,13 @@ func (s *transactionStatements) insertTransaction(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *transactionStatements) selectTransactionEventID(
|
func (s *transactionStatements) selectTransactionEventID(
|
||||||
ctx context.Context,
|
ctx context.Context, txn *sql.Tx,
|
||||||
transactionID string,
|
transactionID string,
|
||||||
sessionID int64,
|
sessionID int64,
|
||||||
userID string,
|
userID string,
|
||||||
) (eventID string, err error) {
|
) (eventID string, err error) {
|
||||||
err = s.selectTransactionEventIDStmt.QueryRowContext(
|
stmt := common.TxStmt(txn, s.selectTransactionEventIDStmt)
|
||||||
|
err = stmt.QueryRowContext(
|
||||||
ctx, transactionID, sessionID, userID,
|
ctx, transactionID, sessionID, userID,
|
||||||
).Scan(&eventID)
|
).Scan(&eventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue