From c617d7c335ac507ddf94d20a4169754c7c6a01e0 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 20 Apr 2021 14:56:16 +0100 Subject: [PATCH] Hash-deduplicated state storage (and migrations) for PostgreSQL and SQLite --- go.mod | 2 +- go.sum | 4 +- .../20201028212440_add_forgotten_column.go | 1 + .../20210416150927_state_blocks_refactor.go | 197 ++++++++++++++++ roomserver/storage/postgres/events_table.go | 46 ++++ .../storage/postgres/state_block_table.go | 195 +++------------- .../storage/postgres/state_snapshot_table.go | 32 +-- roomserver/storage/postgres/storage.go | 10 +- roomserver/storage/shared/storage.go | 58 ++++- .../20201028212440_add_forgotten_column.go | 1 + .../20210416150927_state_blocks_refactor.go | 152 ++++++++++++ roomserver/storage/sqlite3/events_table.go | 61 +++++ .../storage/sqlite3/state_block_table.go | 220 ++++-------------- .../storage/sqlite3/state_snapshot_table.go | 21 +- roomserver/storage/sqlite3/storage.go | 1 + roomserver/storage/tables/interface.go | 9 +- roomserver/types/types.go | 40 ++++ 17 files changed, 678 insertions(+), 372 deletions(-) create mode 100644 roomserver/storage/postgres/deltas/20210416150927_state_blocks_refactor.go create mode 100644 roomserver/storage/sqlite3/deltas/20210416150927_state_blocks_refactor.go diff --git a/go.mod b/go.mod index a3d80f1b1..16c25ead1 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/matrix-org/gomatrixserverlib v0.0.0-20210302161955-6142fe3f8c2c github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 - github.com/mattn/go-sqlite3 v1.14.6 + github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 github.com/opentracing/opentracing-go v1.2.0 diff --git a/go.sum b/go.sum index 90b5527c8..f245a270a 100644 --- a/go.sum +++ b/go.sum @@ -684,8 +684,8 @@ github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcME github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-sqlite3 v1.14.2/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= -github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= -github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb h1:ax2vG2unlxsjwS7PMRo4FECIfAdQLowd6ejWYwPQhBo= +github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= diff --git a/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go index 733f0fa14..f3bd8632f 100644 --- a/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go +++ b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go @@ -24,6 +24,7 @@ import ( func LoadFromGoose() { goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) + goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor) } func LoadAddForgottenColumn(m *sqlutil.Migrations) { diff --git a/roomserver/storage/postgres/deltas/20210416150927_state_blocks_refactor.go b/roomserver/storage/postgres/deltas/20210416150927_state_blocks_refactor.go new file mode 100644 index 000000000..c4dcd639b --- /dev/null +++ b/roomserver/storage/postgres/deltas/20210416150927_state_blocks_refactor.go @@ -0,0 +1,197 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type stateSnapshotData struct { + StateSnapshotNID types.StateSnapshotNID + RoomNID types.RoomNID +} + +type stateBlockData struct { + stateSnapshotData + StateBlockNID types.StateBlockNID + EventNIDs types.EventNIDs +} + +func LoadStateBlocksRefactor(m *sqlutil.Migrations) { + m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor) +} + +func UpStateBlocksRefactor(tx *sql.Tx) error { + logrus.Warn("Performing state block refactor upgrade. Please wait, this may take some time!") + + if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + _, err := tx.Exec(` + CREATE TABLE IF NOT EXISTS roomserver_state_block ( + state_block_nid bigserial PRIMARY KEY, + state_block_hash BYTEA UNIQUE, + event_nids bigint[] NOT NULL + ); + `) + if err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + _, err = tx.Exec(` + CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( + state_snapshot_nid bigserial PRIMARY KEY, + state_snapshot_hash BYTEA UNIQUE, + room_nid bigint NOT NULL, + state_block_nids bigint[] NOT NULL + ); + `) + if err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + logrus.Warn("New tables created...") + + var snapshotcount int + err = tx.QueryRow(` + SELECT COUNT(DISTINCT state_snapshot_nid) FROM _roomserver_state_snapshots; + `).Scan(&snapshotcount) + if err != nil { + return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) + } + + batchsize := 100 + for batchoffset := 0; batchoffset < snapshotcount; batchoffset += batchsize { + var snapshotrows *sql.Rows + snapshotrows, err = tx.Query(` + SELECT + state_snapshot_nid, + room_nid, + state_block_nid, + ARRAY_AGG(event_nid) AS event_nids + FROM ( + SELECT + _roomserver_state_snapshots.state_snapshot_nid, + _roomserver_state_snapshots.room_nid, + _roomserver_state_block.state_block_nid, + _roomserver_state_block.event_nid + FROM + _roomserver_state_snapshots + JOIN _roomserver_state_block ON _roomserver_state_block.state_block_nid = ANY (_roomserver_state_snapshots.state_block_nids) + WHERE + _roomserver_state_snapshots.state_snapshot_nid = ANY ( SELECT DISTINCT + _roomserver_state_snapshots.state_snapshot_nid + FROM + _roomserver_state_snapshots + LIMIT $1 OFFSET $2)) AS _roomserver_state_block + GROUP BY + state_snapshot_nid, + room_nid, + state_block_nid; + `, batchsize, batchoffset) + if err != nil { + return fmt.Errorf("tx.Query: %w", err) + } + + logrus.Warnf("Rewriting snapshots %d-%d of %d...", batchoffset, batchoffset+batchsize, snapshotcount) + var snapshots []stateBlockData + + for snapshotrows.Next() { + var snapshot stateBlockData + var eventsarray pq.Int64Array + if err = snapshotrows.Scan(&snapshot.StateSnapshotNID, &snapshot.RoomNID, &snapshot.StateBlockNID, &eventsarray); err != nil { + return fmt.Errorf("rows.Scan: %w", err) + } + for _, e := range eventsarray { + snapshot.EventNIDs = append(snapshot.EventNIDs, types.EventNID(e)) + } + snapshot.EventNIDs = snapshot.EventNIDs[:util.SortAndUnique(snapshot.EventNIDs)] + snapshots = append(snapshots, snapshot) + } + + if err = snapshotrows.Close(); err != nil { + return fmt.Errorf("snapshots.Close: %w", err) + } + + newsnapshots := map[stateSnapshotData]types.StateBlockNIDs{} + + for _, snapshot := range snapshots { + var eventsarray pq.Int64Array + for _, e := range snapshot.EventNIDs { + eventsarray = append(eventsarray, int64(e)) + } + + var blocknid types.StateBlockNID + err = tx.QueryRow(` + INSERT INTO roomserver_state_block (state_block_hash, event_nids) + VALUES ($1, $2) + ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2 + RETURNING state_block_nid + `, snapshot.EventNIDs.Hash(), eventsarray).Scan(&blocknid) + if err != nil { + return fmt.Errorf("tx.QueryRow.Scan (insert new block with %d events): %w", len(eventsarray), err) + } + index := stateSnapshotData{snapshot.StateSnapshotNID, snapshot.RoomNID} + newsnapshots[index] = append(newsnapshots[index], blocknid) + } + + for snapshotdata, newblocks := range newsnapshots { + var newblocksarray pq.Int64Array + for _, b := range newblocks { + newblocksarray = append(newblocksarray, int64(b)) + } + + var newNID types.StateSnapshotNID + err = tx.QueryRow(` + INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids) + VALUES ($1, $2, $3) + ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2 + RETURNING state_snapshot_nid + `, newblocks.Hash(), snapshotdata.RoomNID, newblocksarray).Scan(&newNID) + if err != nil { + return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err) + } + + if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2`, newNID, snapshotdata.StateSnapshotNID); err != nil { + return fmt.Errorf("tx.Exec (update events): %w", err) + } + + if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2`, newNID, snapshotdata.StateSnapshotNID); err != nil { + return fmt.Errorf("tx.Exec (update rooms): %w", err) + } + } + } + + if _, err = tx.Exec(`DROP TABLE _roomserver_state_snapshots;`); err != nil { + return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err) + } + if _, err = tx.Exec(`DROP TABLE _roomserver_state_block;`); err != nil { + return fmt.Errorf("tx.Exec (delete old block table): %w", err) + } + + return nil +} + +func DownStateBlocksRefactor(tx *sql.Tx) error { + panic("Downgrading state storage is not supported") +} diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 0cf0bd22f..052b0b13b 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "fmt" + "sort" "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" @@ -88,6 +89,13 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id = ANY($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +const bulkSelectStateEventByNIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_nid = ANY($1)" + + " AND (CARDINALITY($2::bigint[]) = 0 OR event_type_nid = ANY($2))" + + " AND (CARDINALITY($3::bigint[]) = 0 OR event_state_key_nid = ANY($3))" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + const bulkSelectStateAtEventByIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id = ANY($1)" @@ -127,6 +135,7 @@ type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateEventByNIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt updateEventStateStmt *sql.Stmt selectEventSentToOutputStmt *sql.Stmt @@ -151,6 +160,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, @@ -238,6 +248,42 @@ func (s *eventStatements) BulkSelectStateEventByID( return results, nil } +// bulkSelectStateEventByNID lookups a list of state events by event NID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) BulkSelectStateEventByNID( + ctx context.Context, eventNIDs []types.EventNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + tuples := stateKeyTupleSorter(stateKeyTuples) + sort.Sort(tuples) + eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + if err = rows.Err(); err != nil { + return nil, err + } + return results, nil +} + // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index d618686f7..e75663fb7 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -39,53 +39,26 @@ const stateDataSchema = ` -- lookup a specific (type, state_key) pair for an event. It also makes it easy -- to read the state for a given state_block_nid ordered by (type, state_key) -- which in turn makes it easier to merge state data blocks. -CREATE SEQUENCE IF NOT EXISTS roomserver_state_block_nid_seq; CREATE TABLE IF NOT EXISTS roomserver_state_block ( - -- Local numeric ID for this state data. - state_block_nid bigint NOT NULL, - event_type_nid bigint NOT NULL, - event_state_key_nid bigint NOT NULL, - event_nid bigint NOT NULL, - UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + state_block_nid bigserial PRIMARY KEY, + state_block_hash BYTEA UNIQUE, + event_nids bigint[] NOT NULL ); ` const insertStateDataSQL = "" + - "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - " VALUES ($1, $2, $3, $4)" + "INSERT INTO roomserver_state_block (state_block_hash, event_nids)" + + " VALUES ($1, $2)" + + " ON CONFLICT (event_nids) DO UPDATE SET event_nids=$2" + + " RETURNING state_block_nid" -const selectNextStateBlockNIDSQL = "" + - "SELECT nextval('roomserver_state_block_nid_seq')" - -// Bulk state lookup by numeric state block ID. -// Sort by the state_block_nid, event_type_nid, event_state_key_nid -// This means that all the entries for a given state_block_nid will appear -// together in the list and those entries will sorted by event_type_nid -// and event_state_key_nid. This property makes it easier to merge two -// state data blocks together. const bulkSelectStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -// Bulk state lookup by numeric state block ID. -// Filters the rows in each block to the requested types and state keys. -// We would like to restrict to particular type state key pairs but we are -// restricted by the query language to pull the cross product of a list -// of types and a list state_keys. So we have to filter the result in the -// application to restrict it to the list of event types and state keys we -// actually wanted. -const bulkSelectFilteredStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" + - " AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + "SELECT state_block_nid, event_nids" + + " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" type stateBlockStatements struct { - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateBlockEntriesStmt *sql.Stmt - bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt + insertStateDataStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt } func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { @@ -97,85 +70,48 @@ func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, - {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, - {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, }.Prepare(db) } func (s *stateBlockStatements) BulkInsertStateData( ctx context.Context, txn *sql.Tx, - entries []types.StateEntry, -) (types.StateBlockNID, error) { - stateBlockNID, err := s.selectNextStateBlockNID(ctx) - if err != nil { - return 0, err + entries types.StateEntries, +) (id types.StateBlockNID, err error) { + entries = entries[:util.SortAndUnique(entries)] + var nids types.EventNIDs + for _, e := range entries { + nids = append(nids, e.EventNID) } - for _, entry := range entries { - _, err := s.insertStateDataStmt.ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return 0, err - } - } - return stateBlockNID, nil -} - -func (s *stateBlockStatements) selectNextStateBlockNID( - ctx context.Context, -) (types.StateBlockNID, error) { - var stateBlockNID int64 - err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID) - return types.StateBlockNID(stateBlockNID), err + err = s.insertStateDataStmt.QueryRowContext( + ctx, nids.Hash(), eventNIDsAsArray(nids), + ).Scan(&id) + return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - nids := make([]int64, len(stateBlockNIDs)) - for i := range stateBlockNIDs { - nids[i] = int64(stateBlockNIDs[i]) - } - rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids)) + ctx context.Context, stateBlockNIDs types.StateBlockNIDs, +) ([][]types.EventNID, error) { + rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") - results := make([]types.StateEntryList, len(stateBlockNIDs)) - // current is a pointer to the StateEntryList to append the state entries to. - var current *types.StateEntryList + results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err = rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { + for ; rows.Next(); i++ { + var stateBlockNID types.StateBlockNID + var result pq.Int64Array + if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we start appending to the next entry in the list. - current = &results[i] - current.StateBlockNID = types.StateBlockNID(stateBlockNID) - i++ + r := []types.EventNID{} + for _, e := range result { + r = append(r, types.EventNID(e)) } - current.StateEntries = append(current.StateEntries, entry) + results[i] = r } if err = rows.Err(); err != nil { return nil, err @@ -186,71 +122,6 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( return results, err } -func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) - // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. - sort.Sort(tuples) - - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext( - ctx, - stateBlockNIDsAsArray(stateBlockNIDs), - eventTypeNIDArray, - eventStateKeyNIDArray, - ) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") - - var results []types.StateEntryList - var current types.StateEntryList - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - - // We can use binary search here because we sorted the tuples earlier - if !tuples.contains(entry.StateKeyTuple) { - // The select will return the cross product of types and state keys. - // So we need to check if type of the entry is in the list. - continue - } - - if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we append the current entry to the results and start adding to a new one. - // The first time through the loop current will be empty. - if current.StateEntries != nil { - results = append(results, current) - } - current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} - } - current.StateEntries = append(current.StateEntries, entry) - } - // Add the last entry to the list if it is not empty. - if current.StateEntries != nil { - results = append(results, current) - } - return results, rows.Err() -} - func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 63175955f..f5b0c0cd4 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" ) const stateSnapshotSchema = ` @@ -38,21 +39,20 @@ const stateSnapshotSchema = ` -- because room state tends to accumulate small changes over time. Although if -- the list of deltas becomes too long it becomes more efficient to encode -- the full state under single state_block_nid. -CREATE SEQUENCE IF NOT EXISTS roomserver_state_snapshot_nid_seq; CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( - -- Local numeric ID for the state. - state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'), - -- Local numeric ID of the room this state is for. - -- Unused in normal operation, but useful for background work or ad-hoc debugging. - room_nid bigint NOT NULL, - -- List of state_block_nids, stored sorted by state_block_nid. - state_block_nids bigint[] NOT NULL + state_snapshot_nid bigserial PRIMARY KEY, + state_snapshot_hash BYTEA UNIQUE, + room_nid bigint NOT NULL, + state_block_nids bigint[] NOT NULL ); ` const insertStateSQL = "" + - "INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" + - " VALUES ($1, $2)" + + "INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT (room_nid, state_block_nids) DO UPDATE SET room_nid=$2" + + // Performing an update, above, ensures that the RETURNING statement + // below will always return a valid state snapshot ID " RETURNING state_snapshot_nid" // Bulk state data NID lookup. @@ -81,13 +81,15 @@ func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { } func (s *stateSnapshotStatements) InsertState( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs, ) (stateNID types.StateSnapshotNID, err error) { - nids := make([]int64, len(stateBlockNIDs)) - for i := range stateBlockNIDs { - nids[i] = int64(stateBlockNIDs[i]) + nids = nids[:util.SortAndUnique(nids)] + var id int64 + err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id) + if err != nil { + return 0, err } - err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + stateNID = types.StateSnapshotNID(id) return } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index bb3f841d0..a7eb4e100 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -17,6 +17,7 @@ package postgres import ( "database/sql" + "fmt" // Import the postgres database driver. _ "github.com/lib/pq" @@ -39,22 +40,23 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) var db *sql.DB var err error if db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err + return nil, fmt.Errorf("sqlutil.Open: %w", err) } // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns ms := membershipStatements{} if err := ms.execSchema(db); err != nil { - return nil, err + return nil, fmt.Errorf("ms.execSchema: %w", err) } m := sqlutil.NewMigrations() deltas.LoadAddForgottenColumn(m) + deltas.LoadStateBlocksRefactor(m) if err := m.RunDeltas(db, dbProperties); err != nil { - return nil, err + return nil, fmt.Errorf("m.RunDeltas: %w", err) } if err := d.prepare(db, cache); err != nil { - return nil, err + return nil, fmt.Errorf("d.prepare: %w", err) } return &d, nil diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 24b487723..c96746c0f 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -118,9 +118,24 @@ func (d *Database) StateEntriesForTuples( stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { - return d.StateBlockTable.BulkSelectFilteredStateBlockEntries( - ctx, stateBlockNIDs, stateKeyTuples, + entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( + ctx, stateBlockNIDs, ) + if err != nil { + return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) + } + lists := []types.StateEntryList{} + for i, entry := range entries { + entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) + } + lists = append(lists, types.StateEntryList{ + StateBlockNID: stateBlockNIDs[i], + StateEntries: entries, + }) + } + return lists, nil } func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { @@ -141,8 +156,28 @@ func (d *Database) AddState( stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { + if len(stateBlockNIDs) > 0 { + // Check to see if the event already appears in any of the existing state + // blocks. If it does then we should not add it again, as this will just + // result in excess state blocks and snapshots. + // TODO: Investigate why this is happening - probably input_events.go! + blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + if berr != nil { + return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr) + } + for i := len(state) - 1; i >= 0; i-- { + for _, events := range blocks { + for _, event := range events { + if state[i].EventNID == event { + state = append(state[:i], state[i+1:]...) + } + } + } + } + } err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if len(state) > 0 { + // If there's any state left to add then let's add new blocks. var stateBlockNID types.StateBlockNID stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) if err != nil { @@ -237,7 +272,24 @@ func (d *Database) StateBlockNIDs( func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { - return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( + ctx, stateBlockNIDs, + ) + if err != nil { + return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) + } + lists := []types.StateEntryList{} + for i, entry := range entries { + eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) + } + lists = append(lists, types.StateEntryList{ + StateBlockNID: stateBlockNIDs[i], + StateEntries: eventNIDs, + }) + } + return lists, nil } func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { diff --git a/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go index 33fe9e2a9..d08ab02d5 100644 --- a/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go +++ b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go @@ -24,6 +24,7 @@ import ( func LoadFromGoose() { goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) + goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor) } func LoadAddForgottenColumn(m *sqlutil.Migrations) { diff --git a/roomserver/storage/sqlite3/deltas/20210416150927_state_blocks_refactor.go b/roomserver/storage/sqlite3/deltas/20210416150927_state_blocks_refactor.go new file mode 100644 index 000000000..53cf9ae69 --- /dev/null +++ b/roomserver/storage/sqlite3/deltas/20210416150927_state_blocks_refactor.go @@ -0,0 +1,152 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" +) + +func LoadStateBlocksRefactor(m *sqlutil.Migrations) { + m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor) +} + +func UpStateBlocksRefactor(tx *sql.Tx) error { + if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + _, err := tx.Exec(` + CREATE TABLE IF NOT EXISTS roomserver_state_block ( + state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, + state_block_hash BLOB UNIQUE, + event_nids TEXT NOT NULL DEFAULT '[]' + ); + `) + if err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + _, err = tx.Exec(` + CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( + state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, + state_snapshot_hash BLOB UNIQUE, + room_nid INTEGER NOT NULL, + state_block_nids TEXT NOT NULL DEFAULT '[]' + ); + `) + if err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + snapshotrows, err := tx.Query(`SELECT state_snapshot_nid, room_nid, state_block_nids FROM _roomserver_state_snapshots;`) + if err != nil { + return fmt.Errorf("tx.Query: %w", err) + } + defer internal.CloseAndLogIfError(context.TODO(), snapshotrows, "rows.close() failed") + for snapshotrows.Next() { + var snapshot types.StateSnapshotNID + var room types.RoomNID + var jsonblocks string + var blocks []types.StateBlockNID + if err = snapshotrows.Scan(&snapshot, &room, &jsonblocks); err != nil { + return fmt.Errorf("rows.Scan: %w", err) + } + if err = json.Unmarshal([]byte(jsonblocks), &blocks); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) + } + + var newblocks types.StateBlockNIDs + for _, block := range blocks { + if err = func() error { + blockrows, berr := tx.Query(`SELECT event_nid FROM _roomserver_state_block WHERE state_block_nid = $1`, block) + if berr != nil { + return fmt.Errorf("tx.Query (event nids from old block): %w", berr) + } + defer internal.CloseAndLogIfError(context.TODO(), blockrows, "rows.close() failed") + events := types.EventNIDs{} + for blockrows.Next() { + var event types.EventNID + if err = blockrows.Scan(&event); err != nil { + return fmt.Errorf("rows.Scan: %w", err) + } + events = append(events, event) + } + events = events[:util.SortAndUnique(events)] + eventjson, eerr := json.Marshal(events) + if eerr != nil { + return fmt.Errorf("json.Marshal: %w", eerr) + } + + var blocknid types.StateBlockNID + err = tx.QueryRow(` + INSERT INTO roomserver_state_block (state_block_hash, event_nids) + VALUES ($1, $2) + ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2 + RETURNING state_block_nid + `, events.Hash(), eventjson).Scan(&blocknid) + if err != nil { + return fmt.Errorf("tx.QueryRow.Scan (insert new block): %w", err) + } + newblocks = append(newblocks, blocknid) + return nil + }(); err != nil { + return err + } + + newblocksjson, jerr := json.Marshal(newblocks) + if jerr != nil { + return fmt.Errorf("json.Marshal (new blocks): %w", jerr) + } + var newsnapshot types.StateSnapshotNID + err = tx.QueryRow(` + INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids) + VALUES ($1, $2, $3) + ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2 + RETURNING state_snapshot_nid + `, newblocks.Hash(), room, newblocksjson).Scan(&newsnapshot) + if err != nil { + return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err) + } + + if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2`, newsnapshot, snapshot); err != nil { + return fmt.Errorf("tx.Exec (update events): %w", err) + } + if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2`, newsnapshot, snapshot); err != nil { + return fmt.Errorf("tx.Exec (update rooms): %w", err) + } + } + } + + if _, err = tx.Exec(`DROP TABLE _roomserver_state_snapshots;`); err != nil { + return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err) + } + if _, err = tx.Exec(`DROP TABLE _roomserver_state_block;`); err != nil { + return fmt.Errorf("tx.Exec (delete old block table): %w", err) + } + + return nil +} + +func DownStateBlocksRefactor(tx *sql.Tx) error { + panic("Downgrading state storage is not supported") +} diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 53269657e..420a4845e 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "encoding/json" "fmt" + "sort" "strings" "github.com/matrix-org/dendrite/internal" @@ -63,6 +64,11 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id IN ($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +const bulkSelectStateEventByNIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_nid IN ($1)" + // Rest of query is built by BulkSelectStateEventByNID + const bulkSelectStateAtEventByIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id IN ($1)" @@ -232,6 +238,61 @@ func (s *eventStatements) BulkSelectStateEventByID( return results, err } +// bulkSelectStateEventByID lookups a list of state events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) BulkSelectStateEventByNID( + ctx context.Context, eventNIDs []types.EventNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + tuples := stateKeyTupleSorter(stateKeyTuples) + sort.Sort(tuples) + eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray)) + selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) + for _, v := range eventNIDs { + params = append(params, v) + } + if len(eventTypeNIDArray) > 0 { + selectOrig += " AND event_type_nid IN " + sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(params)) + for _, v := range eventTypeNIDArray { + params = append(params, v) + } + } + if len(eventStateKeyNIDArray) > 0 { + selectOrig += " AND event_state_key_nid IN " + sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(params)) + for _, v := range eventStateKeyNIDArray { + params = append(params, v) + } + } + selectOrig += " ORDER BY event_type_nid, event_state_key_nid ASC" + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, fmt.Errorf("s.db.Prepare: %w", err) + } + rows, err := selectStmt.QueryContext(ctx, params...) + if err != nil { + return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + return results, err +} + // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 2c544f2b0..508779dd5 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "encoding/json" "fmt" "sort" "strings" @@ -32,52 +33,27 @@ import ( const stateDataSchema = ` CREATE TABLE IF NOT EXISTS roomserver_state_block ( - state_block_nid INTEGER NOT NULL, - event_type_nid INTEGER NOT NULL, - event_state_key_nid INTEGER NOT NULL, - event_nid INTEGER NOT NULL, - UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, + state_block_hash BLOB UNIQUE, + event_nids TEXT NOT NULL DEFAULT '[]' ); ` -const insertStateDataSQL = "" + - "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - " VALUES ($1, $2, $3, $4)" - -const selectNextStateBlockNIDSQL = ` -SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block +const insertStateDataSQL = ` + INSERT INTO roomserver_state_block (state_block_hash, event_nids) + VALUES ($1, $2) + ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2 + RETURNING state_block_nid ` -// Bulk state lookup by numeric state block ID. -// Sort by the state_block_nid, event_type_nid, event_state_key_nid -// This means that all the entries for a given state_block_nid will appear -// together in the list and those entries will sorted by event_type_nid -// and event_state_key_nid. This property makes it easier to merge two -// state data blocks together. const bulkSelectStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -// Bulk state lookup by numeric state block ID. -// Filters the rows in each block to the requested types and state keys. -// We would like to restrict to particular type state key pairs but we are -// restricted by the query language to pull the cross product of a list -// of types and a list state_keys. So we have to filter the result in the -// application to restrict it to the list of event types and state keys we -// actually wanted. -const bulkSelectFilteredStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + "SELECT state_block_nid, event_nids" + + " FROM roomserver_state_block WHERE state_block_nid IN ($1)" type stateBlockStatements struct { - db *sql.DB - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateBlockEntriesStmt *sql.Stmt - bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt + db *sql.DB + insertStateDataStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt } func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { @@ -91,169 +67,69 @@ func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, - {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, - {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, }.Prepare(db) } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, txn *sql.Tx, - entries []types.StateEntry, -) (types.StateBlockNID, error) { - if len(entries) == 0 { - return 0, nil + ctx context.Context, + txn *sql.Tx, + entries types.StateEntries, +) (id types.StateBlockNID, err error) { + entries = entries[:util.SortAndUnique(entries)] + var nids types.EventNIDs + for _, e := range entries { + nids = append(nids, e.EventNID) } - var stateBlockNID types.StateBlockNID - err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + js, err := json.Marshal(nids) if err != nil { - return 0, err + return 0, fmt.Errorf("json.Marshal: %w", err) } - for _, entry := range entries { - _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return 0, err - } - } - return stateBlockNID, err + err = s.insertStateDataStmt.QueryRowContext( + ctx, nids.Hash(), js, + ).Scan(&id) + return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - nids := make([]interface{}, len(stateBlockNIDs)) - for k, v := range stateBlockNIDs { - nids[k] = v + ctx context.Context, stateBlockNIDs types.StateBlockNIDs, +) ([][]types.EventNID, error) { + intfs := make([]interface{}, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + intfs[i] = int64(stateBlockNIDs[i]) } - selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1) selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - rows, err := selectStmt.QueryContext(ctx, nids...) + rows, err := selectStmt.QueryContext(ctx, intfs...) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") - results := make([]types.StateEntryList, len(stateBlockNIDs)) - // current is a pointer to the StateEntryList to append the state entries to. - var current *types.StateEntryList + results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { + for ; rows.Next(); i++ { + var stateBlockNID types.StateBlockNID + var result json.RawMessage + if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we start appending to the next entry in the list. - current = &results[i] - current.StateBlockNID = types.StateBlockNID(stateBlockNID) - i++ + r := []types.EventNID{} + if err = json.Unmarshal(result, &r); err != nil { + return nil, fmt.Errorf("json.Unmarshal: %w", err) } - current.StateEntries = append(current.StateEntries, entry) + results[i] = r } - if i != len(nids) { - return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids)) - } - return results, nil -} - -func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) - // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. - sort.Sort(tuples) - - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) - - var params []interface{} - for _, val := range stateBlockNIDs { - params = append(params, int64(val)) - } - for _, val := range eventTypeNIDArray { - params = append(params, val) - } - for _, val := range eventStateKeyNIDArray { - params = append(params, val) - } - - rows, err := s.db.QueryContext( - ctx, - sqlStatement, - params..., - ) - if err != nil { + if err = rows.Err(); err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") - - var results []types.StateEntryList - var current types.StateEntryList - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - - // We can use binary search here because we sorted the tuples earlier - if !tuples.contains(entry.StateKeyTuple) { - // The select will return the cross product of types and state keys. - // So we need to check if type of the entry is in the list. - continue - } - - if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we append the current entry to the results and start adding to a new one. - // The first time through the loop current will be empty. - if current.StateEntries != nil { - results = append(results, current) - } - current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} - } - current.StateEntries = append(current.StateEntries, entry) + if i != len(stateBlockNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateBlockNIDs)) } - // Add the last entry to the list if it is not empty. - if current.StateEntries != nil { - results = append(results, current) - } - return results, nil + return results, err } type stateKeyTupleSorter []types.StateKeyTuple diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index bf49f62c2..15bf521f4 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -27,19 +27,24 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" ) const stateSnapshotSchema = ` CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, + state_snapshot_hash BLOB UNIQUE, room_nid INTEGER NOT NULL, state_block_nids TEXT NOT NULL DEFAULT '[]' ); ` const insertStateSQL = ` - INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) - VALUES ($1, $2);` + INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids) + VALUES ($1, $2, $3) + ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2 + RETURNING state_snapshot_nid +` // Bulk state data NID lookup. // Sorting by state_snapshot_nid means we can use binary search over the result @@ -70,22 +75,20 @@ func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { } func (s *stateSnapshotStatements) InsertState( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs, ) (stateNID types.StateSnapshotNID, err error) { + stateBlockNIDs = stateBlockNIDs[:util.SortAndUnique(stateBlockNIDs)] stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) if err != nil { return } insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) - res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) + var id int64 + err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id) if err != nil { return 0, err } - lastRowID, err := res.LastInsertId() - if err != nil { - return 0, err - } - stateNID = types.StateSnapshotNID(lastRowID) + stateNID = types.StateSnapshotNID(id) return } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 8e608a6db..351bbc316 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -61,6 +61,7 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) } m := sqlutil.NewMigrations() deltas.LoadAddForgottenColumn(m) + deltas.LoadStateBlocksRefactor(m) if err := m.RunDeltas(db, dbProperties); err != nil { return nil, err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 26bf5cf04..dd486873a 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -43,6 +43,7 @@ type Events interface { // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. @@ -81,14 +82,14 @@ type Transactions interface { } type StateSnapshot interface { - InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) + InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) } type StateBlock interface { - BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries []types.StateEntry) (types.StateBlockNID, error) - BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) - BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error) + BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) + //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) } type RoomAliases interface { diff --git a/roomserver/types/types.go b/roomserver/types/types.go index e866f6cbe..d7e03ad61 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -16,9 +16,11 @@ package types import ( + "encoding/json" "sort" "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/blake2b" ) // EventTypeNID is a numeric ID for an event type. @@ -40,6 +42,38 @@ type StateSnapshotNID int64 // These blocks of state data are combined to form the actual state. type StateBlockNID int64 +// EventNIDs is used to sort and dedupe event NIDs. +type EventNIDs []EventNID + +func (a EventNIDs) Len() int { return len(a) } +func (a EventNIDs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a EventNIDs) Less(i, j int) bool { return a[i] < a[j] } + +func (a EventNIDs) Hash() []byte { + j, err := json.Marshal(a) + if err != nil { + return nil + } + h := blake2b.Sum256(j) + return h[:] +} + +// StateBlockNIDs is used to sort and dedupe state block NIDs. +type StateBlockNIDs []StateBlockNID + +func (a StateBlockNIDs) Len() int { return len(a) } +func (a StateBlockNIDs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a StateBlockNIDs) Less(i, j int) bool { return a[i] < a[j] } + +func (a StateBlockNIDs) Hash() []byte { + j, err := json.Marshal(a) + if err != nil { + return nil + } + h := blake2b.Sum256(j) + return h[:] +} + // A StateKeyTuple is a pair of a numeric event type and a numeric state key. // It is used to lookup state entries. type StateKeyTuple struct { @@ -65,6 +99,12 @@ type StateEntry struct { EventNID EventNID } +type StateEntries []StateEntry + +func (a StateEntries) Len() int { return len(a) } +func (a StateEntries) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a StateEntries) Less(i, j int) bool { return a[i].EventNID < a[j].EventNID } + // LessThan returns true if this state entry is less than the other state entry. // The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries. func (a StateEntry) LessThan(b StateEntry) bool {