Not sure if roomserver is better or worse now

This commit is contained in:
Neil Alexander 2020-01-30 16:46:16 +00:00
parent fca23c356a
commit 0360f07110
14 changed files with 270 additions and 98 deletions

View file

@ -114,15 +114,6 @@ func Setup(
return SendMembership(req, accountDB, device, vars["roomID"], vars["membership"], cfg, queryAPI, asAPI, producer)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}",
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, queryAPI, producer, nil)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))

View file

@ -124,8 +124,10 @@ func processRoomEvent(
if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet.
// Lets calculate one.
fmt.Println("We don't have a state snapshot NID yet")
err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event)
if err != nil {
fmt.Println("Failed to calculateAndSetState:", err)
return
}
}
@ -151,6 +153,7 @@ func calculateAndSetState(
) error {
var err error
if input.HasState {
fmt.Println("We have state")
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.
var entries []types.StateEntry
@ -162,11 +165,13 @@ func calculateAndSetState(
return err
}
} else {
fmt.Println("We don't have state")
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil {
return err
}
}
fmt.Println("Then set state", stateAtEvent)
return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
}

View file

@ -556,11 +556,15 @@ func CalculateAndStoreStateBeforeEvent(
prevEventIDs[i] = prevEventRefs[i].EventID
}
fmt.Println("Previous event IDs:", prevEventIDs)
prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs)
if err != nil {
return 0, err
}
fmt.Println("Previous states:", prevStates)
// The state before this event will be the state after the events that came before it.
return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates)
}
@ -574,7 +578,6 @@ func CalculateAndStoreStateAfterEvents(
prevStates []types.StateAtEvent,
) (types.StateSnapshotNID, error) {
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
if len(prevStates) == 0 {
// 2) There weren't any prev_events for this event so the state is
// empty.
@ -592,6 +595,7 @@ func CalculateAndStoreStateAfterEvents(
metrics.algorithm = "no_change"
return metrics.stop(prevState.BeforeStateSnapshotNID, nil)
}
// The previous event was a state event so we need to store a copy
// of the previous state updated with that event.
stateBlockNIDLists, err := db.StateBlockNIDs(
@ -614,6 +618,7 @@ func CalculateAndStoreStateAfterEvents(
// So fall through to calculateAndStoreStateAfterManyEvents
}
fmt.Println("Falling through to calculateAndStoreStateAfterManyEvents")
return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics)
}

View file

@ -262,7 +262,9 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
func (s *eventStatements) updateEventState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
fmt.Println("updateEventState eventNID", eventNID, "stateNID", stateNID)
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID))
fmt.Println("Errors?", err)
return err
}

View file

@ -19,6 +19,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
@ -46,11 +47,13 @@ const bulkSelectEventJSONSQL = `
`
type eventJSONStatements struct {
db *sql.DB
insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt *sql.Stmt
}
func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(eventJSONSchema)
if err != nil {
return
@ -76,7 +79,19 @@ type eventJSONPair struct {
func (s *eventJSONStatements) bulkSelectEventJSON(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]eventJSONPair, error) {
rows, err := common.TxStmt(txn, s.bulkSelectEventJSONStmt).QueryContext(ctx, eventNIDsAsArray(eventNIDs))
///////////////
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
if err != nil {
fmt.Println("bulkSelectEventJSON s.bulkSelectEventJSONStmt.QueryContext:", err)
return nil, err

View file

@ -19,8 +19,8 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
@ -66,6 +66,7 @@ const bulkSelectEventStateKeySQL = `
`
type eventStateKeyStatements struct {
db *sql.DB
insertEventStateKeyNIDStmt *sql.Stmt
insertEventStateKeyNIDResultStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt
@ -74,6 +75,7 @@ type eventStateKeyStatements struct {
}
func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(eventStateKeysSchema)
if err != nil {
return
@ -110,18 +112,32 @@ func (s *eventStateKeyStatements) selectEventStateKeyNID(
) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt)
fmt.Println("selectEventStateKeyNID for", eventStateKey)
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
if err != nil {
fmt.Println("selectEventStateKeyNID stmt.QueryRowContext:", err)
}
fmt.Println("selectEventStateKeyNID returns", eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err
}
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
rows, err := common.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt).QueryContext(
ctx, sqliteInStr(pq.StringArray(eventStateKeys)),
///////////////
iEventStateKeys := make([]interface{}, len(eventStateKeys))
for k, v := range eventStateKeys {
iEventStateKeys[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
rows, err := common.TxStmt(txn, selectPrep).QueryContext(
ctx, iEventStateKeys...,
)
if err != nil {
fmt.Println("bulkSelectEventStateKeyNID s.bulkSelectEventStateKeyNIDStmt.QueryContext:", err)
@ -144,11 +160,23 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
///////////////
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
for k, v := range eventStateKeyNIDs {
iEventStateKeyNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
/*nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i])
}
rows, err := common.TxStmt(txn, s.bulkSelectEventStateKeyStmt).QueryContext(ctx, nIDs)
}*/
rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventStateKeyNIDs...)
if err != nil {
fmt.Println("bulkSelectEventStateKey s.bulkSelectEventStateKeyStmt.QueryContext:", err)
return nil, err

View file

@ -18,8 +18,8 @@ package sqlite3
import (
"context"
"database/sql"
"strings"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
@ -74,6 +74,7 @@ const bulkSelectEventTypeNIDSQL = `
`
type eventTypeStatements struct {
db *sql.DB
insertEventTypeNIDStmt *sql.Stmt
insertEventTypeNIDResultStmt *sql.Stmt
selectEventTypeNIDStmt *sql.Stmt
@ -81,6 +82,7 @@ type eventTypeStatements struct {
}
func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(eventTypesSchema)
if err != nil {
return
@ -119,8 +121,20 @@ func (s *eventTypeStatements) selectEventTypeNID(
func (s *eventTypeStatements) bulkSelectEventTypeNID(
ctx context.Context, tx *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
selectStmt := common.TxStmt(tx, s.bulkSelectEventTypeNIDStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventTypes)))
///////////////
iEventTypes := make([]interface{}, len(eventTypes))
for k, v := range eventTypes {
iEventTypes[k] = v
}
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", queryVariadic(len(iEventTypes)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
selectStmt := common.TxStmt(tx, selectPrep)
rows, err := selectStmt.QueryContext(ctx, iEventTypes...)
if err != nil {
return nil, err
}

View file

@ -19,6 +19,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
@ -96,6 +97,7 @@ const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
type eventStatements struct {
db *sql.DB
insertEventStmt *sql.Stmt
insertEventResultStmt *sql.Stmt
selectEventStmt *sql.Stmt
@ -113,6 +115,7 @@ type eventStatements struct {
}
func (s *eventStatements) prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(eventsSchema)
if err != nil {
return
@ -157,7 +160,14 @@ func (s *eventStatements) insertEvent(
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
); err == nil {
err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
if err != nil {
fmt.Println("insertEvent HAS FAILED!", err)
}
} else {
fmt.Println("insertEvent HAS GONE WRONG!", err)
}
fmt.Println("Event NID:", eventNID)
fmt.Println("State snapshot NID:", stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
}
@ -176,8 +186,20 @@ func (s *eventStatements) selectEvent(
func (s *eventStatements) bulkSelectStateEventByID(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs)))
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs {
iEventIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
selectStmt := common.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil {
return nil, err
}
@ -217,8 +239,20 @@ func (s *eventStatements) bulkSelectStateEventByID(
func (s *eventStatements) bulkSelectStateAtEventByID(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs)))
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs {
iEventIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
selectStmt := common.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil {
return nil, err
}
@ -296,8 +330,20 @@ func (s *eventStatements) selectEventID(
func (s *eventStatements) bulkSelectStateAtEventAndReference(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.StateAtEventAndReference, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs)))
///////////////
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
selectStmt := common.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
if err != nil {
fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err)
return nil, err
@ -337,8 +383,20 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
func (s *eventStatements) bulkSelectEventReference(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.EventReference, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectEventReferenceStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs)))
///////////////
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
selectStmt := common.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
if err != nil {
fmt.Println("bulkSelectEventReference s.bulkSelectEventReferenceStmt.QueryContext:", err)
return nil, err
@ -361,8 +419,20 @@ func (s *eventStatements) bulkSelectEventReference(
// bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectEventIDStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs)))
///////////////
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
selectStmt := common.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
if err != nil {
fmt.Println("bulkSelectEventID s.bulkSelectEventIDStmt.QueryContext:", err)
return nil, err
@ -388,8 +458,20 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
selectStmt := common.TxStmt(txn, s.bulkSelectEventNIDStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs)))
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs {
iEventIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
selectStmt := common.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil {
fmt.Println("bulkSelectEventNID s.bulkSelectEventNIDStmt.QueryContext:", err)
return nil, err

View file

@ -41,11 +41,6 @@ const insertRoomNIDSQL = `
ON CONFLICT DO NOTHING;
`
const insertRoomNIDResultSQL = `
SELECT room_nid FROM roomserver_rooms
WHERE rowid = last_insert_rowid();
`
const selectRoomNIDSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
@ -60,7 +55,6 @@ const updateLatestEventNIDsSQL = "" +
type roomStatements struct {
insertRoomNIDStmt *sql.Stmt
insertRoomNIDResultStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt
@ -74,7 +68,6 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
}
return statementList{
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
{&s.insertRoomNIDResultStmt, insertRoomNIDResultSQL},
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
@ -85,19 +78,14 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
func (s *roomStatements) insertRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
var roomNID int64
var err error
insertStmt := common.TxStmt(txn, s.insertRoomNIDStmt)
resultStmt := common.TxStmt(txn, s.insertRoomNIDResultStmt)
if _, err = insertStmt.ExecContext(ctx, roomID); err == nil {
err = resultStmt.QueryRowContext(ctx).Scan(&roomNID)
if err != nil {
fmt.Println("insertRoomNID resultStmt.QueryRowContext:", err)
}
return s.selectRoomNID(ctx, txn, roomID)
} else {
fmt.Println("insertRoomNID insertStmt.ExecContext:", err)
return types.RoomNID(0), err
}
return types.RoomNID(roomNID), err
}
func (s *roomStatements) selectRoomNID(
@ -106,9 +94,6 @@ func (s *roomStatements) selectRoomNID(
var roomNID int64
stmt := common.TxStmt(txn, s.selectRoomNIDStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
if err != nil {
fmt.Println("selectRoomNID stmt.QueryRowContext:", err)
}
return types.RoomNID(roomNID), err
}

View file

@ -17,6 +17,7 @@ package sqlite3
import (
"database/sql"
"fmt"
)
type statements struct {
@ -58,3 +59,16 @@ func (s *statements) prepare(db *sql.DB) error {
return nil
}
// Hack of the century
func queryVariadic(count int) string {
str := "("
for i := 1; i <= count; i++ {
str += fmt.Sprintf("$%d", i)
if i < count {
str += ", "
}
}
str += ")"
return str
}

View file

@ -19,7 +19,9 @@ import (
"context"
"database/sql"
"fmt"
"runtime/debug"
"sort"
"strings"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
@ -44,7 +46,7 @@ const insertStateDataSQL = "" +
const selectNextStateBlockNIDSQL = `
SELECT COALESCE((
SELECT seq+1 AS state_block_nid FROM sqlite_sequence
WHERE name = 'roomserver_state_block'), 0
WHERE name = 'roomserver_state_block'), 1
) AS state_block_nid
`
@ -73,6 +75,7 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
type stateBlockStatements struct {
db *sql.DB
insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt
bulkSelectStateBlockEntriesStmt *sql.Stmt
@ -80,6 +83,7 @@ type stateBlockStatements struct {
}
func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(stateDataSchema)
if err != nil {
return
@ -108,6 +112,7 @@ func (s *stateBlockStatements) bulkInsertStateData(
)
if err != nil {
fmt.Println("bulkInsertStateData s.insertStateDataStmt.ExecContext:", err)
debug.PrintStack()
return err
}
}
@ -127,12 +132,25 @@ func (s *stateBlockStatements) selectNextStateBlockNID(
func (s *stateBlockStatements) bulkSelectStateBlockEntries(
ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
///////////////
nids := make([]interface{}, len(stateBlockNIDs))
for k, v := range stateBlockNIDs {
nids[k] = v
}
selectStmt := common.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids)))
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", queryVariadic(len(nids)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
/*
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
*/
selectStmt := common.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {
fmt.Println("bulkSelectStateBlockEntries s.bulkSelectStateBlockEntriesStmt.QueryContext:", err)
return nil, err

View file

@ -19,6 +19,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
@ -51,12 +52,14 @@ const bulkSelectStateBlockNIDsSQL = "" +
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
type stateSnapshotStatements struct {
db *sql.DB
insertStateStmt *sql.Stmt
insertStateResultStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(stateSnapshotSchema)
if err != nil {
return
@ -92,12 +95,25 @@ func (s *stateSnapshotStatements) insertState(
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs))
for i := range stateNIDs {
nids[i] = int64(stateNIDs[i])
///////////////
nids := make([]interface{}, len(stateNIDs))
for k, v := range stateNIDs {
nids[k] = v
}
selectStmt := common.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
rows, err := selectStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids)))
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
/*
nids := make([]int64, len(stateNIDs))
for i := range stateNIDs {
nids[i] = int64(stateNIDs[i])
}
*/
selectStmt := common.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {
fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err)
return nil, err

View file

@ -44,18 +44,18 @@ func Open(dataSourceName string) (*Database, error) {
}
var cs string
if uri.Opaque != "" { // file:filename.db
cs = fmt.Sprintf("%s?cache=shared&_busy_timeout=9999999", uri.Opaque)
cs = fmt.Sprintf("%s", uri.Opaque)
} else if uri.Path != "" { // file:///path/to/filename.db
cs = fmt.Sprintf("%s?cache=shared&_busy_timeout=9999999", uri.Path)
cs = fmt.Sprintf("%s", uri.Path)
} else {
return nil, errors.New("no filename or path in connect string")
}
if d.db, err = sql.Open("sqlite3", cs); err != nil {
return nil, err
}
d.db.Exec("PRAGMA journal_mode=WAL;")
//d.db.Exec("PRAGMA journal_mode=WAL;")
//d.db.Exec("PRAGMA parser_trace = true;")
//d.db.SetMaxOpenConns(1)
d.db.SetMaxOpenConns(1)
if err = d.statements.prepare(d.db); err != nil {
return nil, err
}
@ -76,32 +76,24 @@ func (d *Database) StoreEvent(
err error
)
if txnAndSessionID != nil {
if err = d.statements.insertTransaction(
ctx, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil {
return 0, types.StateAtEvent{}, err
}
}
err = common.WithTransaction(d.db, func(tx *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, tx, event.RoomID())
return err
})
if err != nil {
return 0, types.StateAtEvent{}, err
}
err = common.WithTransaction(d.db, func(tx *sql.Tx) error {
eventTypeNID, err = d.assignEventTypeNID(ctx, tx, event.Type())
return err
})
if err != nil {
return 0, types.StateAtEvent{}, err
}
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if txnAndSessionID != nil {
if err = d.statements.insertTransaction(
ctx, txn, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil {
return err
}
}
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID()); err != nil {
return err
}
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
return err
}
eventStateKey := event.StateKey()
// Assigned a numeric ID for the state_key if there is one present.
// Otherwise set the numeric ID for the state_key to 0.
@ -161,8 +153,8 @@ func (d *Database) assignRoomNID(
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
if err == nil {
// Now get the numeric ID back out of the database
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
}
}
@ -242,10 +234,11 @@ func (d *Database) Events(
) ([]types.Event, error) {
var eventJSONs []eventJSONPair
var err error
results := make([]types.Event, len(eventJSONs))
results := make([]types.Event, len(eventNIDs))
common.WithTransaction(d.db, func(txn *sql.Tx) error {
eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil {
if err != nil || len(eventJSONs) == 0 {
fmt.Println("d.statements.bulkSelectEventJSON:", err)
return nil
}
for i, eventJSON := range eventJSONs {
@ -372,7 +365,7 @@ func (d *Database) GetTransactionEventID(
ctx context.Context, transactionID string,
sessionID int64, userID string,
) (string, error) {
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID)
eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID)
if err == sql.ErrNoRows {
return "", nil
}

View file

@ -19,6 +19,8 @@ import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/common"
)
const transactionsSchema = `
@ -58,13 +60,14 @@ func (s *transactionStatements) prepare(db *sql.DB) (err error) {
}
func (s *transactionStatements) insertTransaction(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
transactionID string,
sessionID int64,
userID string,
eventID string,
) (err error) {
_, err = s.insertTransactionStmt.ExecContext(
stmt := common.TxStmt(txn, s.insertTransactionStmt)
_, err = stmt.ExecContext(
ctx, transactionID, sessionID, userID, eventID,
)
if err != nil {
@ -74,12 +77,13 @@ func (s *transactionStatements) insertTransaction(
}
func (s *transactionStatements) selectTransactionEventID(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
transactionID string,
sessionID int64,
userID string,
) (eventID string, err error) {
err = s.selectTransactionEventIDStmt.QueryRowContext(
stmt := common.TxStmt(txn, s.selectTransactionEventIDStmt)
err = stmt.QueryRowContext(
ctx, transactionID, sessionID, userID,
).Scan(&eventID)
if err != nil {