From 0360f07110de5a232c99094daa5c9033abf5f148 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 30 Jan 2020 16:46:16 +0000 Subject: [PATCH] Not sure if roomserver is better or worse now --- clientapi/routing/routing.go | 9 -- roomserver/input/events.go | 5 + roomserver/state/state.go | 7 +- roomserver/storage/postgres/events_table.go | 2 + .../storage/sqlite3/event_json_table.go | 17 ++- .../storage/sqlite3/event_state_keys_table.go | 40 ++++++- .../storage/sqlite3/event_types_table.go | 20 +++- roomserver/storage/sqlite3/events_table.go | 106 ++++++++++++++++-- roomserver/storage/sqlite3/rooms_table.go | 19 +--- roomserver/storage/sqlite3/sql.go | 14 +++ .../storage/sqlite3/state_block_table.go | 30 ++++- .../storage/sqlite3/state_snapshot_table.go | 26 ++++- roomserver/storage/sqlite3/storage.go | 61 +++++----- .../storage/sqlite3/transactions_table.go | 12 +- 14 files changed, 270 insertions(+), 98 deletions(-) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 933fdf60c..f7b94914a 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -114,15 +114,6 @@ func Setup( return SendMembership(req, accountDB, device, vars["roomID"], vars["membership"], cfg, queryAPI, asAPI, producer) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/send/{eventType}", - common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars, err := common.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, queryAPI, producer, nil) - }), - ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { vars, err := common.URLDecodeMapValues(mux.Vars(req)) diff --git a/roomserver/input/events.go b/roomserver/input/events.go index b30c39928..11bd51749 100644 --- a/roomserver/input/events.go +++ b/roomserver/input/events.go @@ -124,8 +124,10 @@ func processRoomEvent( if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. + fmt.Println("We don't have a state snapshot NID yet") err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event) if err != nil { + fmt.Println("Failed to calculateAndSetState:", err) return } } @@ -151,6 +153,7 @@ func calculateAndSetState( ) error { var err error if input.HasState { + fmt.Println("We have state") // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry @@ -162,11 +165,13 @@ func calculateAndSetState( return err } } else { + fmt.Println("We don't have state") // We haven't been told what the state at the event is so we need to calculate it from the prev_events if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil { return err } } + fmt.Println("Then set state", stateAtEvent) return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 2a0b7f574..dbb04eb01 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -556,11 +556,15 @@ func CalculateAndStoreStateBeforeEvent( prevEventIDs[i] = prevEventRefs[i].EventID } + fmt.Println("Previous event IDs:", prevEventIDs) + prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs) if err != nil { return 0, err } + fmt.Println("Previous states:", prevStates) + // The state before this event will be the state after the events that came before it. return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates) } @@ -574,7 +578,6 @@ func CalculateAndStoreStateAfterEvents( prevStates []types.StateAtEvent, ) (types.StateSnapshotNID, error) { metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} - if len(prevStates) == 0 { // 2) There weren't any prev_events for this event so the state is // empty. @@ -592,6 +595,7 @@ func CalculateAndStoreStateAfterEvents( metrics.algorithm = "no_change" return metrics.stop(prevState.BeforeStateSnapshotNID, nil) } + // The previous event was a state event so we need to store a copy // of the previous state updated with that event. stateBlockNIDLists, err := db.StateBlockNIDs( @@ -614,6 +618,7 @@ func CalculateAndStoreStateAfterEvents( // So fall through to calculateAndStoreStateAfterManyEvents } + fmt.Println("Falling through to calculateAndStoreStateAfterManyEvents") return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics) } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 1e8a5665b..8c381396c 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -262,7 +262,9 @@ func (s *eventStatements) bulkSelectStateAtEventByID( func (s *eventStatements) updateEventState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { + fmt.Println("updateEventState eventNID", eventNID, "stateNID", stateNID) _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) + fmt.Println("Errors?", err) return err } diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index ffca7ff2e..f6738159d 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "fmt" + "strings" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" @@ -46,11 +47,13 @@ const bulkSelectEventJSONSQL = ` ` type eventJSONStatements struct { + db *sql.DB insertEventJSONStmt *sql.Stmt bulkSelectEventJSONStmt *sql.Stmt } func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { + s.db = db _, err = db.Exec(eventJSONSchema) if err != nil { return @@ -76,7 +79,19 @@ type eventJSONPair struct { func (s *eventJSONStatements) bulkSelectEventJSON( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]eventJSONPair, error) { - rows, err := common.TxStmt(txn, s.bulkSelectEventJSONStmt).QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) if err != nil { fmt.Println("bulkSelectEventJSON s.bulkSelectEventJSONStmt.QueryContext:", err) return nil, err diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 60d1cb30e..5ce72809a 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -19,8 +19,8 @@ import ( "context" "database/sql" "fmt" + "strings" - "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -66,6 +66,7 @@ const bulkSelectEventStateKeySQL = ` ` type eventStateKeyStatements struct { + db *sql.DB insertEventStateKeyNIDStmt *sql.Stmt insertEventStateKeyNIDResultStmt *sql.Stmt selectEventStateKeyNIDStmt *sql.Stmt @@ -74,6 +75,7 @@ type eventStateKeyStatements struct { } func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { + s.db = db _, err = db.Exec(eventStateKeysSchema) if err != nil { return @@ -110,18 +112,32 @@ func (s *eventStateKeyStatements) selectEventStateKeyNID( ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt) + fmt.Println("selectEventStateKeyNID for", eventStateKey) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) if err != nil { fmt.Println("selectEventStateKeyNID stmt.QueryRowContext:", err) } + fmt.Println("selectEventStateKeyNID returns", eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { - rows, err := common.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt).QueryContext( - ctx, sqliteInStr(pq.StringArray(eventStateKeys)), + /////////////// + iEventStateKeys := make([]interface{}, len(eventStateKeys)) + for k, v := range eventStateKeys { + iEventStateKeys[k] = v + } + selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + rows, err := common.TxStmt(txn, selectPrep).QueryContext( + ctx, iEventStateKeys..., ) if err != nil { fmt.Println("bulkSelectEventStateKeyNID s.bulkSelectEventStateKeyNIDStmt.QueryContext:", err) @@ -144,11 +160,23 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( func (s *eventStateKeyStatements) bulkSelectEventStateKey( ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { - nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) + /////////////// + iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) + for k, v := range eventStateKeyNIDs { + iEventStateKeyNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + /*nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) for i := range eventStateKeyNIDs { nIDs[i] = int64(eventStateKeyNIDs[i]) - } - rows, err := common.TxStmt(txn, s.bulkSelectEventStateKeyStmt).QueryContext(ctx, nIDs) + }*/ + rows, err := common.TxStmt(txn, selectPrep).QueryContext(ctx, iEventStateKeyNIDs...) if err != nil { fmt.Println("bulkSelectEventStateKey s.bulkSelectEventStateKeyStmt.QueryContext:", err) return nil, err diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 4b00f3fdf..edc759c01 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -18,8 +18,8 @@ package sqlite3 import ( "context" "database/sql" + "strings" - "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -74,6 +74,7 @@ const bulkSelectEventTypeNIDSQL = ` ` type eventTypeStatements struct { + db *sql.DB insertEventTypeNIDStmt *sql.Stmt insertEventTypeNIDResultStmt *sql.Stmt selectEventTypeNIDStmt *sql.Stmt @@ -81,6 +82,7 @@ type eventTypeStatements struct { } func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { + s.db = db _, err = db.Exec(eventTypesSchema) if err != nil { return @@ -119,8 +121,20 @@ func (s *eventTypeStatements) selectEventTypeNID( func (s *eventTypeStatements) bulkSelectEventTypeNID( ctx context.Context, tx *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { - selectStmt := common.TxStmt(tx, s.bulkSelectEventTypeNIDStmt) - rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventTypes))) + /////////////// + iEventTypes := make([]interface{}, len(eventTypes)) + for k, v := range eventTypes { + iEventTypes[k] = v + } + selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", queryVariadic(len(iEventTypes)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(tx, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventTypes...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 6b03a42f1..6fa5fdb08 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "fmt" + "strings" "github.com/lib/pq" "github.com/matrix-org/dendrite/common" @@ -96,6 +97,7 @@ const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" type eventStatements struct { + db *sql.DB insertEventStmt *sql.Stmt insertEventResultStmt *sql.Stmt selectEventStmt *sql.Stmt @@ -113,6 +115,7 @@ type eventStatements struct { } func (s *eventStatements) prepare(db *sql.DB) (err error) { + s.db = db _, err = db.Exec(eventsSchema) if err != nil { return @@ -157,7 +160,14 @@ func (s *eventStatements) insertEvent( eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, ); err == nil { err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID) + if err != nil { + fmt.Println("insertEvent HAS FAILED!", err) + } + } else { + fmt.Println("insertEvent HAS GONE WRONG!", err) } + fmt.Println("Event NID:", eventNID) + fmt.Println("State snapshot NID:", stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } @@ -176,8 +186,20 @@ func (s *eventStatements) selectEvent( func (s *eventStatements) bulkSelectStateEventByID( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateEntry, error) { - selectStmt := common.TxStmt(txn, s.bulkSelectStateEventByIDStmt) - rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs))) + /////////////// + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err } @@ -217,8 +239,20 @@ func (s *eventStatements) bulkSelectStateEventByID( func (s *eventStatements) bulkSelectStateAtEventByID( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateAtEvent, error) { - selectStmt := common.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt) - rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs))) + /////////////// + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err } @@ -296,8 +330,20 @@ func (s *eventStatements) selectEventID( func (s *eventStatements) bulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { - selectStmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) - rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))) + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) if err != nil { fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err) return nil, err @@ -337,8 +383,20 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( func (s *eventStatements) bulkSelectEventReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.EventReference, error) { - selectStmt := common.TxStmt(txn, s.bulkSelectEventReferenceStmt) - rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))) + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) if err != nil { fmt.Println("bulkSelectEventReference s.bulkSelectEventReferenceStmt.QueryContext:", err) return nil, err @@ -361,8 +419,20 @@ func (s *eventStatements) bulkSelectEventReference( // bulkSelectEventID returns a map from numeric event ID to string event ID. func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { - selectStmt := common.TxStmt(txn, s.bulkSelectEventIDStmt) - rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))) + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) if err != nil { fmt.Println("bulkSelectEventID s.bulkSelectEventIDStmt.QueryContext:", err) return nil, err @@ -388,8 +458,20 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { - selectStmt := common.TxStmt(txn, s.bulkSelectEventNIDStmt) - rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs))) + /////////////// + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { fmt.Println("bulkSelectEventNID s.bulkSelectEventNIDStmt.QueryContext:", err) return nil, err diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 491084dd5..84ee50585 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -41,11 +41,6 @@ const insertRoomNIDSQL = ` ON CONFLICT DO NOTHING; ` -const insertRoomNIDResultSQL = ` - SELECT room_nid FROM roomserver_rooms - WHERE rowid = last_insert_rowid(); -` - const selectRoomNIDSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" @@ -60,7 +55,6 @@ const updateLatestEventNIDsSQL = "" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt - insertRoomNIDResultStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt @@ -74,7 +68,6 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { } return statementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, - {&s.insertRoomNIDResultStmt, insertRoomNIDResultSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, @@ -85,19 +78,14 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { func (s *roomStatements) insertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, ) (types.RoomNID, error) { - var roomNID int64 var err error insertStmt := common.TxStmt(txn, s.insertRoomNIDStmt) - resultStmt := common.TxStmt(txn, s.insertRoomNIDResultStmt) if _, err = insertStmt.ExecContext(ctx, roomID); err == nil { - err = resultStmt.QueryRowContext(ctx).Scan(&roomNID) - if err != nil { - fmt.Println("insertRoomNID resultStmt.QueryRowContext:", err) - } + return s.selectRoomNID(ctx, txn, roomID) } else { fmt.Println("insertRoomNID insertStmt.ExecContext:", err) + return types.RoomNID(0), err } - return types.RoomNID(roomNID), err } func (s *roomStatements) selectRoomNID( @@ -106,9 +94,6 @@ func (s *roomStatements) selectRoomNID( var roomNID int64 stmt := common.TxStmt(txn, s.selectRoomNIDStmt) err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) - if err != nil { - fmt.Println("selectRoomNID stmt.QueryRowContext:", err) - } return types.RoomNID(roomNID), err } diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go index 0d49432b8..c424a2fe9 100644 --- a/roomserver/storage/sqlite3/sql.go +++ b/roomserver/storage/sqlite3/sql.go @@ -17,6 +17,7 @@ package sqlite3 import ( "database/sql" + "fmt" ) type statements struct { @@ -58,3 +59,16 @@ func (s *statements) prepare(db *sql.DB) error { return nil } + +// Hack of the century +func queryVariadic(count int) string { + str := "(" + for i := 1; i <= count; i++ { + str += fmt.Sprintf("$%d", i) + if i < count { + str += ", " + } + } + str += ")" + return str +} diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 7a84f3136..9fd497a2c 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -19,7 +19,9 @@ import ( "context" "database/sql" "fmt" + "runtime/debug" "sort" + "strings" "github.com/lib/pq" "github.com/matrix-org/dendrite/common" @@ -44,7 +46,7 @@ const insertStateDataSQL = "" + const selectNextStateBlockNIDSQL = ` SELECT COALESCE(( SELECT seq+1 AS state_block_nid FROM sqlite_sequence - WHERE name = 'roomserver_state_block'), 0 + WHERE name = 'roomserver_state_block'), 1 ) AS state_block_nid ` @@ -73,6 +75,7 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" type stateBlockStatements struct { + db *sql.DB insertStateDataStmt *sql.Stmt selectNextStateBlockNIDStmt *sql.Stmt bulkSelectStateBlockEntriesStmt *sql.Stmt @@ -80,6 +83,7 @@ type stateBlockStatements struct { } func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { + s.db = db _, err = db.Exec(stateDataSchema) if err != nil { return @@ -108,6 +112,7 @@ func (s *stateBlockStatements) bulkInsertStateData( ) if err != nil { fmt.Println("bulkInsertStateData s.insertStateDataStmt.ExecContext:", err) + debug.PrintStack() return err } } @@ -127,12 +132,25 @@ func (s *stateBlockStatements) selectNextStateBlockNID( func (s *stateBlockStatements) bulkSelectStateBlockEntries( ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { - nids := make([]int64, len(stateBlockNIDs)) - for i := range stateBlockNIDs { - nids[i] = int64(stateBlockNIDs[i]) + /////////////// + nids := make([]interface{}, len(stateBlockNIDs)) + for k, v := range stateBlockNIDs { + nids[k] = v } - selectStmt := common.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt) - rows, err := selectStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids))) + selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", queryVariadic(len(nids)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + /* + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + */ + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil { fmt.Println("bulkSelectStateBlockEntries s.bulkSelectStateBlockEntriesStmt.QueryContext:", err) return nil, err diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 71a40f03f..4b9ebf7bf 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "fmt" + "strings" "github.com/lib/pq" "github.com/matrix-org/dendrite/common" @@ -51,12 +52,14 @@ const bulkSelectStateBlockNIDsSQL = "" + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" type stateSnapshotStatements struct { + db *sql.DB insertStateStmt *sql.Stmt insertStateResultStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt } func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { + s.db = db _, err = db.Exec(stateSnapshotSchema) if err != nil { return @@ -92,12 +95,25 @@ func (s *stateSnapshotStatements) insertState( func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { - nids := make([]int64, len(stateNIDs)) - for i := range stateNIDs { - nids[i] = int64(stateNIDs[i]) + /////////////// + nids := make([]interface{}, len(stateNIDs)) + for k, v := range stateNIDs { + nids[k] = v } - selectStmt := common.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt) - rows, err := selectStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids))) + selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + /* + nids := make([]int64, len(stateNIDs)) + for i := range stateNIDs { + nids[i] = int64(stateNIDs[i]) + } + */ + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil { fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err) return nil, err diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 138d971f9..0b43defde 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -44,18 +44,18 @@ func Open(dataSourceName string) (*Database, error) { } var cs string if uri.Opaque != "" { // file:filename.db - cs = fmt.Sprintf("%s?cache=shared&_busy_timeout=9999999", uri.Opaque) + cs = fmt.Sprintf("%s", uri.Opaque) } else if uri.Path != "" { // file:///path/to/filename.db - cs = fmt.Sprintf("%s?cache=shared&_busy_timeout=9999999", uri.Path) + cs = fmt.Sprintf("%s", uri.Path) } else { return nil, errors.New("no filename or path in connect string") } if d.db, err = sql.Open("sqlite3", cs); err != nil { return nil, err } - d.db.Exec("PRAGMA journal_mode=WAL;") + //d.db.Exec("PRAGMA journal_mode=WAL;") //d.db.Exec("PRAGMA parser_trace = true;") - //d.db.SetMaxOpenConns(1) + d.db.SetMaxOpenConns(1) if err = d.statements.prepare(d.db); err != nil { return nil, err } @@ -76,32 +76,24 @@ func (d *Database) StoreEvent( err error ) - if txnAndSessionID != nil { - if err = d.statements.insertTransaction( - ctx, txnAndSessionID.TransactionID, - txnAndSessionID.SessionID, event.Sender(), event.EventID(), - ); err != nil { - return 0, types.StateAtEvent{}, err - } - } - - err = common.WithTransaction(d.db, func(tx *sql.Tx) error { - roomNID, err = d.assignRoomNID(ctx, tx, event.RoomID()) - return err - }) - if err != nil { - return 0, types.StateAtEvent{}, err - } - - err = common.WithTransaction(d.db, func(tx *sql.Tx) error { - eventTypeNID, err = d.assignEventTypeNID(ctx, tx, event.Type()) - return err - }) - if err != nil { - return 0, types.StateAtEvent{}, err - } - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + if txnAndSessionID != nil { + if err = d.statements.insertTransaction( + ctx, txn, txnAndSessionID.TransactionID, + txnAndSessionID.SessionID, event.Sender(), event.EventID(), + ); err != nil { + return err + } + } + + if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID()); err != nil { + return err + } + + if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { + return err + } + eventStateKey := event.StateKey() // Assigned a numeric ID for the state_key if there is one present. // Otherwise set the numeric ID for the state_key to 0. @@ -161,8 +153,8 @@ func (d *Database) assignRoomNID( if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. + if err == nil { + // Now get the numeric ID back out of the database roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) } } @@ -242,10 +234,11 @@ func (d *Database) Events( ) ([]types.Event, error) { var eventJSONs []eventJSONPair var err error - results := make([]types.Event, len(eventJSONs)) + results := make([]types.Event, len(eventNIDs)) common.WithTransaction(d.db, func(txn *sql.Tx) error { eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) - if err != nil { + if err != nil || len(eventJSONs) == 0 { + fmt.Println("d.statements.bulkSelectEventJSON:", err) return nil } for i, eventJSON := range eventJSONs { @@ -372,7 +365,7 @@ func (d *Database) GetTransactionEventID( ctx context.Context, transactionID string, sessionID int64, userID string, ) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID) + eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID) if err == sql.ErrNoRows { return "", nil } diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 260e21360..5ff6d215e 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -19,6 +19,8 @@ import ( "context" "database/sql" "fmt" + + "github.com/matrix-org/dendrite/common" ) const transactionsSchema = ` @@ -58,13 +60,14 @@ func (s *transactionStatements) prepare(db *sql.DB) (err error) { } func (s *transactionStatements) insertTransaction( - ctx context.Context, + ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, eventID string, ) (err error) { - _, err = s.insertTransactionStmt.ExecContext( + stmt := common.TxStmt(txn, s.insertTransactionStmt) + _, err = stmt.ExecContext( ctx, transactionID, sessionID, userID, eventID, ) if err != nil { @@ -74,12 +77,13 @@ func (s *transactionStatements) insertTransaction( } func (s *transactionStatements) selectTransactionEventID( - ctx context.Context, + ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, ) (eventID string, err error) { - err = s.selectTransactionEventIDStmt.QueryRowContext( + stmt := common.TxStmt(txn, s.selectTransactionEventIDStmt) + err = stmt.QueryRowContext( ctx, transactionID, sessionID, userID, ).Scan(&eventID) if err != nil {