diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go new file mode 100644 index 000000000..fe2c21ba0 --- /dev/null +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -0,0 +1,98 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventJSONSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_event_json ( + event_nid INTEGER NOT NULL PRIMARY KEY, + event_json TEXT NOT NULL + ); +` + +const insertEventJSONSQL = ` + INSERT INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2) + ON CONFLICT DO NOTHING +` + +// Bulk event JSON lookup by numeric event ID. +// Sort by the numeric event ID. +// This means that we can use binary search to lookup by numeric event ID. +const bulkSelectEventJSONSQL = ` + SELECT event_nid, event_json FROM roomserver_event_json + WHERE event_nid IN ($1) + ORDER BY event_nid ASC" +` + +type eventJSONStatements struct { + insertEventJSONStmt *sql.Stmt + bulkSelectEventJSONStmt *sql.Stmt +} + +func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(eventJSONSchema) + if err != nil { + return + } + return statementList{ + {&s.insertEventJSONStmt, insertEventJSONSQL}, + {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, + }.prepare(db) +} + +func (s *eventJSONStatements) insertEventJSON( + ctx context.Context, eventNID types.EventNID, eventJSON []byte, +) error { + _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) + return err +} + +type eventJSONPair struct { + EventNID types.EventNID + EventJSON []byte +} + +func (s *eventJSONStatements) bulkSelectEventJSON( + ctx context.Context, eventNIDs []types.EventNID, +) ([]eventJSONPair, error) { + rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + // We know that we will only get as many results as event NIDs + // because of the unique constraint on event NIDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than NIDs so we adjust the length of the slice before returning it. + results := make([]eventJSONPair, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + var eventNID int64 + if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { + return nil, err + } + result.EventNID = types.EventNID(eventNID) + } + return results[:i], nil +} diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go new file mode 100644 index 000000000..1f199a0af --- /dev/null +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -0,0 +1,148 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventStateKeysSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_event_state_keys ( + event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT, + event_state_key TEXT NOT NULL UNIQUE + ); + INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key) + VALUES (1, '') + ON CONFLICT DO NOTHING; +` + +// Same as insertEventTypeNIDSQL +const insertEventStateKeyNIDSQL = ` + INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1) + ON CONFLICT DO NOTHING; + SELECT event_state_key_nid FROM roomserver_event_state_keys + WHERE rowid = last_insert_rowid(); +` + +const selectEventStateKeyNIDSQL = ` + SELECT event_state_key_nid FROM roomserver_event_state_keys + WHERE event_state_key = $1 +` + +// Bulk lookup from string state key to numeric ID for that state key. +// Takes an array of strings as the query parameter. +const bulkSelectEventStateKeyNIDSQL = ` + SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys + WHERE event_state_key IN ($1) +` + +// Bulk lookup from numeric ID to string state key for that state key. +// Takes an array of strings as the query parameter. +const bulkSelectEventStateKeySQL = ` + SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys + WHERE event_state_key_nid IN ($1) +` + +type eventStateKeyStatements struct { + insertEventStateKeyNIDStmt *sql.Stmt + selectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyStmt *sql.Stmt +} + +func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(eventStateKeysSchema) + if err != nil { + return + } + return statementList{ + {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, + {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, + {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, + {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, + }.prepare(db) +} + +func (s *eventStateKeyStatements) insertEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + stmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt) + err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) + return types.EventStateKeyNID(eventStateKeyNID), err +} + +func (s *eventStateKeyStatements) selectEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt) + err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) + return types.EventStateKeyNID(eventStateKeyNID), err +} + +func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( + ctx, pq.StringArray(eventStateKeys), + ) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) + for rows.Next() { + var stateKey string + var stateKeyNID int64 + if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { + return nil, err + } + result[stateKey] = types.EventStateKeyNID(stateKeyNID) + } + return result, nil +} + +func (s *eventStateKeyStatements) bulkSelectEventStateKey( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) + for i := range eventStateKeyNIDs { + nIDs[i] = int64(eventStateKeyNIDs[i]) + } + rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) + for rows.Next() { + var stateKey string + var stateKeyNID int64 + if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { + return nil, err + } + result[types.EventStateKeyNID(stateKeyNID)] = stateKey + } + return result, nil +} diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go new file mode 100644 index 000000000..da91a115b --- /dev/null +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -0,0 +1,126 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventTypesSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_event_types ( + event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT, + event_type TEXT NOT NULL UNIQUE + ); + INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES + (1, 'm.room.create'), + (2, 'm.room.power_levels'), + (3, 'm.room.join_rules'), + (4, 'm.room.third_party_invite'), + (5, 'm.room.member'), + (6, 'm.room.redaction'), + (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; +` + +// Assign a new numeric event type ID. +// The usual case is that the event type is not in the database. +// In that case the ID will be assigned using the next value from the sequence. +// We use `RETURNING` to tell postgres to return the assigned ID. +// But it's possible that the type was added in a query that raced with us. +// This will result in a conflict on the event_type_unique constraint, in this +// case we do nothing. Postgresql won't return a row in that case so we rely on +// the caller catching the sql.ErrNoRows error and running a select to get the row. +// We could get postgresql to return the row on a conflict by updating the row +// but it doesn't seem like a good idea to modify the rows just to make postgresql +// return it. Modifying the rows will cause postgres to assign a new tuple for the +// row even though the data doesn't change resulting in unncesssary modifications +// to the indexes. +const insertEventTypeNIDSQL = ` + INSERT INTO roomserver_event_types (event_type) VALUES ($1) + ON CONFLICT DO NOTHING; + SELECT event_type_nid FROM roomserver_event_types + WHERE rowid = last_insert_rowid(); +` + +const selectEventTypeNIDSQL = ` + SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1 +` + +// Bulk lookup from string event type to numeric ID for that event type. +// Takes an array of strings as the query parameter. +const bulkSelectEventTypeNIDSQL = ` + SELECT event_type, event_type_nid FROM roomserver_event_types + WHERE event_type IN ($1) +` + +type eventTypeStatements struct { + insertEventTypeNIDStmt *sql.Stmt + selectEventTypeNIDStmt *sql.Stmt + bulkSelectEventTypeNIDStmt *sql.Stmt +} + +func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(eventTypesSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, + {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, + {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, + }.prepare(db) +} + +func (s *eventTypeStatements) insertEventTypeNID( + ctx context.Context, eventType string, +) (types.EventTypeNID, error) { + var eventTypeNID int64 + err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + return types.EventTypeNID(eventTypeNID), err +} + +func (s *eventTypeStatements) selectEventTypeNID( + ctx context.Context, eventType string, +) (types.EventTypeNID, error) { + var eventTypeNID int64 + err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + return types.EventTypeNID(eventTypeNID), err +} + +func (s *eventTypeStatements) bulkSelectEventTypeNID( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + result := make(map[string]types.EventTypeNID, len(eventTypes)) + for rows.Next() { + var eventType string + var eventTypeNID int64 + if err := rows.Scan(&eventType, &eventTypeNID); err != nil { + return nil, err + } + result[eventType] = types.EventTypeNID(eventTypeNID) + } + return result, nil +} diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go new file mode 100644 index 000000000..990c7b558 --- /dev/null +++ b/roomserver/storage/sqlite3/events_table.go @@ -0,0 +1,387 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const eventsSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_events ( + event_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_nid INTEGER NOT NULL, + event_type_nid INTEGER NOT NULL, + event_state_key_nid INTEGER NOT NULL, + sent_to_output BOOLEAN NOT NULL DEFAULT FALSE, + state_snapshot_nid INTEGER NOT NULL DEFAULT 0, + depth INTEGER NOT NULL, + event_id TEXT NOT NULL UNIQUE, + reference_sha256 BLOB NOT NULL, + auth_event_nids TEXT NOT NULL + ); +` + +const insertEventSQL = ` + INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT DO NOTHING; + SELECT event_nid, state_snapshot_nid FROM roomserver_events + WHERE rowid = last_insert_rowid(); +` + +const selectEventSQL = "" + + "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" + +// Bulk lookup of events by string ID. +// Sort by the numeric IDs for event type and state key. +// This means we can use binary search to lookup entries by type and state key. +const bulkSelectStateEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_id IN ($1)" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + +const bulkSelectStateAtEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" + + " WHERE event_id IN ($1)" + +const updateEventStateSQL = "" + + "UPDATE roomserver_events SET state_snapshot_nid = $2 WHERE event_nid = $1" + +const selectEventSentToOutputSQL = "" + + "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" + +const updateEventSentToOutputSQL = "" + + "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" + +const selectEventIDSQL = "" + + "SELECT event_id FROM roomserver_events WHERE event_nid = $1" + +const bulkSelectStateAtEventAndReferenceSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + + " FROM roomserver_events WHERE event_nid IN ($1)" + +const bulkSelectEventReferenceSQL = "" + + "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)" + +const bulkSelectEventIDSQL = "" + + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" + +const bulkSelectEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" + +const selectMaxEventDepthSQL = "" + + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" + +type eventStatements struct { + insertEventStmt *sql.Stmt + selectEventStmt *sql.Stmt + bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateAtEventByIDStmt *sql.Stmt + updateEventStateStmt *sql.Stmt + selectEventSentToOutputStmt *sql.Stmt + updateEventSentToOutputStmt *sql.Stmt + selectEventIDStmt *sql.Stmt + bulkSelectStateAtEventAndReferenceStmt *sql.Stmt + bulkSelectEventReferenceStmt *sql.Stmt + bulkSelectEventIDStmt *sql.Stmt + bulkSelectEventNIDStmt *sql.Stmt + selectMaxEventDepthStmt *sql.Stmt +} + +func (s *eventStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(eventsSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertEventStmt, insertEventSQL}, + {&s.selectEventStmt, selectEventSQL}, + {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, + {&s.updateEventStateStmt, updateEventStateSQL}, + {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, + {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, + {&s.selectEventIDStmt, selectEventIDSQL}, + {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, + {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, + {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, + {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, + }.prepare(db) +} + +func (s *eventStatements) insertEvent( + ctx context.Context, + roomNID types.RoomNID, + eventTypeNID types.EventTypeNID, + eventStateKeyNID types.EventStateKeyNID, + eventID string, + referenceSHA256 []byte, + authEventNIDs []types.EventNID, + depth int64, +) (types.EventNID, types.StateSnapshotNID, error) { + var eventNID int64 + var stateNID int64 + err := s.insertEventStmt.QueryRowContext( + ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + ).Scan(&eventNID, &stateNID) + return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err +} + +func (s *eventStatements) selectEvent( + ctx context.Context, eventID string, +) (types.EventNID, types.StateSnapshotNID, error) { + var eventNID int64 + var stateNID int64 + err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) + return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err +} + +// bulkSelectStateEventByID lookups a list of state events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) bulkSelectStateEventByID( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + if i != len(eventIDs) { + // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. + // We don't know which ones were missing because we don't return the string IDs in the query. + // However it should be possible debug this by replaying queries or entries from the input kafka logs. + // If this turns out to be impossible and we do need the debug information here, it would be better + // to do it as a separate query rather than slowing down/complicating the common case. + return nil, types.MissingEventError( + fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)), + ) + } + return results, err +} + +// bulkSelectStateAtEventByID lookups the state at a list of events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError. +// If we do not have the state for any of the requested events it returns a types.MissingEventError. +func (s *eventStatements) bulkSelectStateAtEventByID( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.StateAtEvent, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + &result.BeforeStateSnapshotNID, + ); err != nil { + return nil, err + } + if result.BeforeStateSnapshotNID == 0 { + return nil, types.MissingEventError( + fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), + ) + } + } + if i != len(eventIDs) { + return nil, types.MissingEventError( + fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)), + ) + } + return results, err +} + +func (s *eventStatements) updateEventState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) + return err +} + +func (s *eventStatements) selectEventSentToOutput( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (sentToOutput bool, err error) { + stmt := common.TxStmt(txn, s.selectEventSentToOutputStmt) + err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) + return +} + +func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { + stmt := common.TxStmt(txn, s.updateEventSentToOutputStmt) + _, err := stmt.ExecContext(ctx, int64(eventNID)) + return err +} + +func (s *eventStatements) selectEventID( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (eventID string, err error) { + stmt := common.TxStmt(txn, s.selectEventIDStmt) + err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) + return +} + +func (s *eventStatements) bulkSelectStateAtEventAndReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]types.StateAtEventAndReference, error) { + stmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.StateAtEventAndReference, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + var ( + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + stateSnapshotNID int64 + eventID string + eventSHA256 []byte + ) + if err = rows.Scan( + &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, + ); err != nil { + return nil, err + } + result := &results[i] + result.EventTypeNID = types.EventTypeNID(eventTypeNID) + result.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + result.EventNID = types.EventNID(eventNID) + result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID) + result.EventID = eventID + result.EventSHA256 = eventSHA256 + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +func (s *eventStatements) bulkSelectEventReference( + ctx context.Context, eventNIDs []types.EventNID, +) ([]gomatrixserverlib.EventReference, error) { + rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil { + return nil, err + } + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +// bulkSelectEventID returns a map from numeric event ID to string event ID. +func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make(map[types.EventNID]string, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + var eventNID int64 + var eventID string + if err = rows.Scan(&eventNID, &eventID); err != nil { + return nil, err + } + results[types.EventNID(eventNID)] = eventID + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { + rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make(map[string]types.EventNID, len(eventIDs)) + for rows.Next() { + var eventID string + var eventNID int64 + if err = rows.Scan(&eventID, &eventNID); err != nil { + return nil, err + } + results[eventID] = types.EventNID(eventNID) + } + return results, nil +} + +func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { + var result int64 + stmt := s.selectMaxEventDepthStmt + err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result) + if err != nil { + return 0, err + } + return result, nil +} + +func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { + nids := make([]int64, len(eventNIDs)) + for i := range eventNIDs { + nids[i] = int64(eventNIDs[i]) + } + return nids +} diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go new file mode 100644 index 000000000..03b4bc5f1 --- /dev/null +++ b/roomserver/storage/sqlite3/invite_table.go @@ -0,0 +1,140 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const inviteSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_invites ( + invite_event_id TEXT PRIMARY KEY, + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + retired BOOLEAN NOT NULL DEFAULT FALSE, + invite_event_json TEXT NOT NULL + ); + + CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid) + WHERE NOT retired; +` +const insertInviteEventSQL = "" + + "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," + + " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" + + " ON CONFLICT DO NOTHING" + +const selectInviteActiveForUserInRoomSQL = "" + + "SELECT sender_nid FROM roomserver_invites" + + " WHERE target_nid = $1 AND room_nid = $2" + + " AND NOT retired" + +// Retire every active invite for a user in a room. +// Ideally we'd know which invite events were retired by a given update so we +// wouldn't need to remove every active invite. +// However the matrix protocol doesn't give us a way to reliably identify the +// invites that were retired, so we are forced to retire all of them. +const updateInviteRetiredSQL = ` + UPDATE roomserver_invites SET retired = TRUE + WHERE room_nid = $1 AND target_nid = $2 AND NOT retired; + SELECT invite_event_id FROM roomserver_invites + WHERE rowid = last_insert_rowid(); +` + +type inviteStatements struct { + insertInviteEventStmt *sql.Stmt + selectInviteActiveForUserInRoomStmt *sql.Stmt + updateInviteRetiredStmt *sql.Stmt +} + +func (s *inviteStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(inviteSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, + {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, + }.prepare(db) +} + +func (s *inviteStatements) insertInviteEvent( + ctx context.Context, + txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + targetUserNID, senderUserNID types.EventStateKeyNID, + inviteEventJSON []byte, +) (bool, error) { + result, err := common.TxStmt(txn, s.insertInviteEventStmt).ExecContext( + ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, + ) + if err != nil { + return false, err + } + count, err := result.RowsAffected() + if err != nil { + return false, err + } + return count != 0, nil +} + +func (s *inviteStatements) updateInviteRetired( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (eventIDs []string, err error) { + stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) + rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) + if err != nil { + return nil, err + } + defer (func() { err = rows.Close() })() + for rows.Next() { + var inviteEventID string + if err := rows.Scan(&inviteEventID); err != nil { + return nil, err + } + eventIDs = append(eventIDs, inviteEventID) + } + return +} + +// selectInviteActiveForUserInRoom returns a list of sender state key NIDs +func (s *inviteStatements) selectInviteActiveForUserInRoom( + ctx context.Context, + targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, +) ([]types.EventStateKeyNID, error) { + rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + ctx, targetUserNID, roomNID, + ) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + var result []types.EventStateKeyNID + for rows.Next() { + var senderUserNID int64 + if err := rows.Scan(&senderUserNID); err != nil { + return nil, err + } + result = append(result, types.EventStateKeyNID(senderUserNID)) + } + return result, nil +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go new file mode 100644 index 000000000..eaf87d2b0 --- /dev/null +++ b/roomserver/storage/sqlite3/membership_table.go @@ -0,0 +1,173 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +type membershipState int64 + +const ( + membershipStateLeaveOrBan membershipState = 1 + membershipStateInvite membershipState = 2 + membershipStateJoin membershipState = 3 +) + +const membershipSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + UNIQUE (room_nid, target_nid) + ); +` + +// Insert a row in to membership table so that it can be locked by the +// SELECT FOR UPDATE +const insertMembershipSQL = "" + + "INSERT INTO roomserver_membership (room_nid, target_nid)" + + " VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectMembershipFromRoomAndTargetSQL = "" + + "SELECT membership_nid, event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND target_nid = $2" + +const selectMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + +const selectMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + +const selectMembershipForUpdateSQL = "" + + "SELECT membership_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND target_nid = $2" + +const updateMembershipSQL = "" + + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" + + " WHERE room_nid = $1 AND target_nid = $2" + +type membershipStatements struct { + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt +} + +func (s *membershipStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(membershipSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertMembershipStmt, insertMembershipSQL}, + {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, + {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, + {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.updateMembershipStmt, updateMembershipSQL}, + }.prepare(db) +} + +func (s *membershipStatements) insertMembership( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) error { + stmt := common.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) + return err +} + +func (s *membershipStatements) selectMembershipForUpdate( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (membership membershipState, err error) { + err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( + ctx, roomNID, targetUserNID, + ).Scan(&membership) + return +} + +func (s *membershipStatements) selectMembershipFromRoomAndTarget( + ctx context.Context, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (eventNID types.EventNID, membership membershipState, err error) { + err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + ctx, roomNID, targetUserNID, + ).Scan(&membership, &eventNID) + return +} + +func (s *membershipStatements) selectMembershipsFromRoom( + ctx context.Context, roomNID types.RoomNID, +) (eventNIDs []types.EventNID, err error) { + rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID) + if err != nil { + return + } + + for rows.Next() { + var eNID types.EventNID + if err = rows.Scan(&eNID); err != nil { + return + } + eventNIDs = append(eventNIDs, eNID) + } + return +} +func (s *membershipStatements) selectMembershipsFromRoomAndMembership( + ctx context.Context, + roomNID types.RoomNID, membership membershipState, +) (eventNIDs []types.EventNID, err error) { + stmt := s.selectMembershipsFromRoomAndMembershipStmt + rows, err := stmt.QueryContext(ctx, roomNID, membership) + if err != nil { + return + } + + for rows.Next() { + var eNID types.EventNID + if err = rows.Scan(&eNID); err != nil { + return + } + eventNIDs = append(eventNIDs, eNID) + } + return +} + +func (s *membershipStatements) updateMembership( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + senderUserNID types.EventStateKeyNID, membership membershipState, + eventNID types.EventNID, +) error { + _, err := common.TxStmt(txn, s.updateMembershipStmt).ExecContext( + ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, + ) + return err +} diff --git a/roomserver/storage/sqlite3/prepare.go b/roomserver/storage/sqlite3/prepare.go new file mode 100644 index 000000000..482dfa2b9 --- /dev/null +++ b/roomserver/storage/sqlite3/prepare.go @@ -0,0 +1,36 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" +) + +// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. +type statementList []struct { + statement **sql.Stmt + sql string +} + +// prepare the SQL for each statement in the list and assign the result to the prepared statement. +func (s statementList) prepare(db *sql.DB) (err error) { + for _, statement := range s { + if *statement.statement, err = db.Prepare(statement.sql); err != nil { + return + } + } + return +} diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go new file mode 100644 index 000000000..9ed64a38e --- /dev/null +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -0,0 +1,92 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const previousEventSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_previous_events ( + previous_event_id TEXT NOT NULL, + previous_reference_sha256 BLOB NOT NULL, + event_nids TEXT NOT NULL, + UNIQUE (previous_event_id, previous_reference_sha256) + ); +` + +// Insert an entry into the previous_events table. +// If there is already an entry indicating that an event references that previous event then +// add the event NID to the list to indicate that this event references that previous event as well. +// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room. +// The lock is necessary to avoid data races when checking whether an event is already referenced by another event. +const insertPreviousEventSQL = ` + INSERT OR REPLACE INTO roomserver_previous_events + (previous_event_id, previous_reference_sha256, event_nids) + VALUES ($1, $2, $3) +` + +// Check if the event is referenced by another event in the table. +// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. +const selectPreviousEventExistsSQL = ` + SELECT 1 FROM roomserver_previous_events + WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 +` + +type previousEventStatements struct { + insertPreviousEventStmt *sql.Stmt + selectPreviousEventExistsStmt *sql.Stmt +} + +func (s *previousEventStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(previousEventSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertPreviousEventStmt, insertPreviousEventSQL}, + {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, + }.prepare(db) +} + +func (s *previousEventStatements) insertPreviousEvent( + ctx context.Context, + txn *sql.Tx, + previousEventID string, + previousEventReferenceSHA256 []byte, + eventNID types.EventNID, +) error { + stmt := common.TxStmt(txn, s.insertPreviousEventStmt) + _, err := stmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ) + return err +} + +// Check if the event reference exists +// Returns sql.ErrNoRows if the event reference doesn't exist. +func (s *previousEventStatements) selectPreviousEventExists( + ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, +) error { + var ok int64 + stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt) + return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) +} diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go new file mode 100644 index 000000000..8c9498fd2 --- /dev/null +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -0,0 +1,128 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" +) + +const roomAliasesSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_room_aliases ( + alias TEXT NOT NULL PRIMARY KEY, + room_id TEXT NOT NULL, + creator_id TEXT NOT NULL + ); + + CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id); +` + +const insertRoomAliasSQL = ` + INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3) +` + +const selectRoomIDFromAliasSQL = ` + SELECT room_id FROM roomserver_room_aliases WHERE alias = $1 +` + +const selectAliasesFromRoomIDSQL = ` + SELECT alias FROM roomserver_room_aliases WHERE room_id = $1 +` + +const selectCreatorIDFromAliasSQL = ` + SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1 +` + +const deleteRoomAliasSQL = ` + DELETE FROM roomserver_room_aliases WHERE alias = $1 +` + +type roomAliasesStatements struct { + insertRoomAliasStmt *sql.Stmt + selectRoomIDFromAliasStmt *sql.Stmt + selectAliasesFromRoomIDStmt *sql.Stmt + selectCreatorIDFromAliasStmt *sql.Stmt + deleteRoomAliasStmt *sql.Stmt +} + +func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(roomAliasesSchema) + if err != nil { + return + } + return statementList{ + {&s.insertRoomAliasStmt, insertRoomAliasSQL}, + {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, + {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, + {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, + {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, + }.prepare(db) +} + +func (s *roomAliasesStatements) insertRoomAlias( + ctx context.Context, alias string, roomID string, creatorUserID string, +) (err error) { + _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) + return +} + +func (s *roomAliasesStatements) selectRoomIDFromAlias( + ctx context.Context, alias string, +) (roomID string, err error) { + err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *roomAliasesStatements) selectAliasesFromRoomID( + ctx context.Context, roomID string, +) (aliases []string, err error) { + aliases = []string{} + rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + if err != nil { + return + } + + for rows.Next() { + var alias string + if err = rows.Scan(&alias); err != nil { + return + } + + aliases = append(aliases, alias) + } + + return +} + +func (s *roomAliasesStatements) selectCreatorIDFromAlias( + ctx context.Context, alias string, +) (creatorID string, err error) { + err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *roomAliasesStatements) deleteRoomAlias( + ctx context.Context, alias string, +) (err error) { + _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) + return +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go new file mode 100644 index 000000000..5d1fe6810 --- /dev/null +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -0,0 +1,149 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const roomsSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_rooms ( + room_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_id TEXT NOT NULL UNIQUE, + latest_event_nids TEXT NOT NULL DEFAULT '', + last_event_sent_nid INTEGER NOT NULL DEFAULT 0, + state_snapshot_nid INTEGER NOT NULL DEFAULT 0 + ); +` + +// Same as insertEventTypeNIDSQL +const insertRoomNIDSQL = ` + INSERT INTO roomserver_rooms (room_id) VALUES ($1) + ON CONFLICT DO NOTHING; + SELECT room_nid FROM roomserver_rooms + WHERE rowid = last_insert_rowid(); +` + +const selectRoomNIDSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" + +const selectLatestEventNIDsSQL = "" + + "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" + +const selectLatestEventNIDsForUpdateSQL = "" + + "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" + +const updateLatestEventNIDsSQL = "" + + "UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1" + +type roomStatements struct { + insertRoomNIDStmt *sql.Stmt + selectRoomNIDStmt *sql.Stmt + selectLatestEventNIDsStmt *sql.Stmt + selectLatestEventNIDsForUpdateStmt *sql.Stmt + updateLatestEventNIDsStmt *sql.Stmt +} + +func (s *roomStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(roomsSchema) + if err != nil { + return + } + return statementList{ + {&s.insertRoomNIDStmt, insertRoomNIDSQL}, + {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, + {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, + {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, + }.prepare(db) +} + +func (s *roomStatements) insertRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := common.TxStmt(txn, s.insertRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + +func (s *roomStatements) selectRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := common.TxStmt(txn, s.selectRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + +func (s *roomStatements) selectLatestEventNIDs( + ctx context.Context, roomNID types.RoomNID, +) ([]types.EventNID, types.StateSnapshotNID, error) { + var nids pq.Int64Array + var stateSnapshotNID int64 + stmt := s.selectLatestEventNIDsStmt + err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) + if err != nil { + return nil, 0, err + } + eventNIDs := make([]types.EventNID, len(nids)) + for i := range nids { + eventNIDs[i] = types.EventNID(nids[i]) + } + return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil +} + +func (s *roomStatements) selectLatestEventsNIDsForUpdate( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { + var nids pq.Int64Array + var lastEventSentNID int64 + var stateSnapshotNID int64 + stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) + err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) + if err != nil { + return nil, 0, 0, err + } + eventNIDs := make([]types.EventNID, len(nids)) + for i := range nids { + eventNIDs[i] = types.EventNID(nids[i]) + } + return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil +} + +func (s *roomStatements) updateLatestEventNIDs( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + eventNIDs []types.EventNID, + lastEventSentNID types.EventNID, + stateSnapshotNID types.StateSnapshotNID, +) error { + stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt) + _, err := stmt.ExecContext( + ctx, + roomNID, + eventNIDsAsArray(eventNIDs), + int64(lastEventSentNID), + int64(stateSnapshotNID), + ) + return err +} diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go new file mode 100644 index 000000000..0d49432b8 --- /dev/null +++ b/roomserver/storage/sqlite3/sql.go @@ -0,0 +1,60 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" +) + +type statements struct { + eventTypeStatements + eventStateKeyStatements + roomStatements + eventStatements + eventJSONStatements + stateSnapshotStatements + stateBlockStatements + previousEventStatements + roomAliasesStatements + inviteStatements + membershipStatements + transactionStatements +} + +func (s *statements) prepare(db *sql.DB) error { + var err error + + for _, prepare := range []func(db *sql.DB) error{ + s.eventTypeStatements.prepare, + s.eventStateKeyStatements.prepare, + s.roomStatements.prepare, + s.eventStatements.prepare, + s.eventJSONStatements.prepare, + s.stateSnapshotStatements.prepare, + s.stateBlockStatements.prepare, + s.previousEventStatements.prepare, + s.roomAliasesStatements.prepare, + s.inviteStatements.prepare, + s.membershipStatements.prepare, + s.transactionStatements.prepare, + } { + if err = prepare(db); err != nil { + return err + } + } + + return nil +} diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go new file mode 100644 index 000000000..223c6b998 --- /dev/null +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -0,0 +1,271 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "sort" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" +) + +const stateDataSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_state_block ( + state_block_nid INTEGER NOT NULL, + event_type_nid INTEGER NOT NULL, + event_state_key_nid INTEGER NOT NULL, + event_nid INTEGER NOT NULL, + UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + ); +` + +const insertStateDataSQL = "" + + "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + + " VALUES ($1, $2, $3, $4)" + +const selectNextStateBlockNIDSQL = ` + SELECT seq+1 FROM sqlite_sequence WHERE name = 'roomserver_state_block' +` + +// Bulk state lookup by numeric state block ID. +// Sort by the state_block_nid, event_type_nid, event_state_key_nid +// This means that all the entries for a given state_block_nid will appear +// together in the list and those entries will sorted by event_type_nid +// and event_state_key_nid. This property makes it easier to merge two +// state data blocks together. +const bulkSelectStateBlockEntriesSQL = "" + + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + +// Bulk state lookup by numeric state block ID. +// Filters the rows in each block to the requested types and state keys. +// We would like to restrict to particular type state key pairs but we are +// restricted by the query language to pull the cross product of a list +// of types and a list state_keys. So we have to filter the result in the +// application to restrict it to the list of event types and state keys we +// actually wanted. +const bulkSelectFilteredStateBlockEntriesSQL = "" + + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + + " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + +type stateBlockStatements struct { + insertStateDataStmt *sql.Stmt + selectNextStateBlockNIDStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt + bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt +} + +func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(stateDataSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertStateDataStmt, insertStateDataSQL}, + {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, + {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, + {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, + }.prepare(db) +} + +func (s *stateBlockStatements) bulkInsertStateData( + ctx context.Context, + stateBlockNID types.StateBlockNID, + entries []types.StateEntry, +) error { + for _, entry := range entries { + _, err := s.insertStateDataStmt.ExecContext( + ctx, + int64(stateBlockNID), + int64(entry.EventTypeNID), + int64(entry.EventStateKeyNID), + int64(entry.EventNID), + ) + if err != nil { + return err + } + } + return nil +} + +func (s *stateBlockStatements) selectNextStateBlockNID( + ctx context.Context, +) (types.StateBlockNID, error) { + var stateBlockNID int64 + err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID) + return types.StateBlockNID(stateBlockNID), err +} + +func (s *stateBlockStatements) bulkSelectStateBlockEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + results := make([]types.StateEntryList, len(stateBlockNIDs)) + // current is a pointer to the StateEntryList to append the state entries to. + var current *types.StateEntryList + i := 0 + for rows.Next() { + var ( + stateBlockNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) + if err := rows.Scan( + &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, + ); err != nil { + return nil, err + } + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + // The state entry row is for a different state data block to the current one. + // So we start appending to the next entry in the list. + current = &results[i] + current.StateBlockNID = types.StateBlockNID(stateBlockNID) + i++ + } + current.StateEntries = append(current.StateEntries, entry) + } + if i != len(stateBlockNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs)) + } + return results, nil +} + +func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + tuples := stateKeyTupleSorter(stateKeyTuples) + // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. + sort.Sort(tuples) + + eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext( + ctx, + stateBlockNIDsAsArray(stateBlockNIDs), + eventTypeNIDArray, + eventStateKeyNIDArray, + ) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + var results []types.StateEntryList + var current types.StateEntryList + for rows.Next() { + var ( + stateBlockNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) + if err := rows.Scan( + &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, + ); err != nil { + return nil, err + } + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + + // We can use binary search here because we sorted the tuples earlier + if !tuples.contains(entry.StateKeyTuple) { + // The select will return the cross product of types and state keys. + // So we need to check if type of the entry is in the list. + continue + } + + if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + // The state entry row is for a different state data block to the current one. + // So we append the current entry to the results and start adding to a new one. + // The first time through the loop current will be empty. + if current.StateEntries != nil { + results = append(results, current) + } + current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} + } + current.StateEntries = append(current.StateEntries, entry) + } + // Add the last entry to the list if it is not empty. + if current.StateEntries != nil { + results = append(results, current) + } + return results, nil +} + +func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + return pq.Int64Array(nids) +} + +type stateKeyTupleSorter []types.StateKeyTuple + +func (s stateKeyTupleSorter) Len() int { return len(s) } +func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Check whether a tuple is in the list. Assumes that the list is sorted. +func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { + i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) + return i < len(s) && s[i] == value +} + +// List the unique eventTypeNIDs and eventStateKeyNIDs. +// Assumes that the list is sorted. +func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) { + eventTypeNIDs = make(pq.Int64Array, len(s)) + eventStateKeyNIDs = make(pq.Int64Array, len(s)) + for i := range s { + eventTypeNIDs[i] = int64(s[i].EventTypeNID) + eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) + } + eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] + eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] + return +} + +type int64Sorter []int64 + +func (s int64Sorter) Len() int { return len(s) } +func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } +func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/state_block_table_test.go b/roomserver/storage/sqlite3/state_block_table_test.go new file mode 100644 index 000000000..98439f5c0 --- /dev/null +++ b/roomserver/storage/sqlite3/state_block_table_test.go @@ -0,0 +1,86 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "sort" + "testing" + + "github.com/matrix-org/dendrite/roomserver/types" +) + +func TestStateKeyTupleSorter(t *testing.T) { + input := stateKeyTupleSorter{ + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 1}, + } + want := []types.StateKeyTuple{ + {EventTypeNID: 1, EventStateKeyNID: 1}, + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + } + doNotWant := []types.StateKeyTuple{ + {EventTypeNID: 0, EventStateKeyNID: 0}, + {EventTypeNID: 1, EventStateKeyNID: 3}, + {EventTypeNID: 2, EventStateKeyNID: 1}, + {EventTypeNID: 3, EventStateKeyNID: 1}, + } + wantTypeNIDs := []int64{1, 2} + wantStateKeyNIDs := []int64{1, 2, 4} + + // Sort the input and check it's in the right order. + sort.Sort(input) + gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() + + for i := range want { + if input[i] != want[i] { + t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) + } + + if !input.contains(want[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) + } + } + + for i := range doNotWant { + if input.contains(doNotWant[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) + } + } + + if len(wantTypeNIDs) != len(gotTypeNIDs) { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + + for i := range wantTypeNIDs { + if wantTypeNIDs[i] != gotTypeNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } + + if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { + t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) + } + + for i := range wantStateKeyNIDs { + if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } +} diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go new file mode 100644 index 000000000..c69d772d0 --- /dev/null +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -0,0 +1,106 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const stateSnapshotSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( + state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_nid INTEGER NOT NULL, + state_block_nids TEXT NOT NULL + ); +` + +const insertStateSQL = ` + INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) + VALUES ($1, $2); + SELECT state_snapshot_nid FROM roomserver_state_snapshots + WHERE rowid = last_insert_rowid(); +` + +// Bulk state data NID lookup. +// Sorting by state_snapshot_nid means we can use binary search over the result +// to lookup the state data NIDs for a state snapshot NID. +const bulkSelectStateBlockNIDsSQL = "" + + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" + +type stateSnapshotStatements struct { + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt +} + +func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(stateSnapshotSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertStateStmt, insertStateSQL}, + {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + }.prepare(db) +} + +func (s *stateSnapshotStatements) insertState( + ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, +) (stateNID types.StateSnapshotNID, err error) { + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + return +} + +func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + nids := make([]int64, len(stateNIDs)) + for i := range stateNIDs { + nids[i] = int64(stateNIDs[i]) + } + rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.StateBlockNIDList, len(stateNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + var stateBlockNIDs pq.Int64Array + if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { + return nil, err + } + result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs)) + for k := range stateBlockNIDs { + result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k]) + } + } + if i != len(stateNIDs) { + return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) + } + return results, nil +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go new file mode 100644 index 000000000..5bff7e94f --- /dev/null +++ b/roomserver/storage/sqlite3/storage.go @@ -0,0 +1,723 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "net/url" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + _ "github.com/mattn/go-sqlite3" +) + +// A Database is used to store room events and stream offsets. +type Database struct { + statements statements + db *sql.DB +} + +// Open a postgres database. +func Open(dataSourceName string) (*Database, error) { + var d Database + uri, err := url.Parse(dataSourceName) + if err != nil { + return nil, err + } + if uri.Opaque != "" { // file:filename.db + if d.db, err = sql.Open("sqlite3", uri.Opaque); err != nil { + return nil, err + } + } + if uri.Path != "" { // file:///path/to/filename.db + if d.db, err = sql.Open("sqlite3", uri.Path); err != nil { + return nil, err + } + } + if err = d.statements.prepare(d.db); err != nil { + return nil, err + } + return &d, nil +} + +// StoreEvent implements input.EventDatabase +func (d *Database) StoreEvent( + ctx context.Context, event gomatrixserverlib.Event, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, +) (types.RoomNID, types.StateAtEvent, error) { + var ( + roomNID types.RoomNID + eventTypeNID types.EventTypeNID + eventStateKeyNID types.EventStateKeyNID + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error + ) + + if txnAndSessionID != nil { + if err = d.statements.insertTransaction( + ctx, txnAndSessionID.TransactionID, + txnAndSessionID.SessionID, event.Sender(), event.EventID(), + ); err != nil { + return 0, types.StateAtEvent{}, err + } + } + + if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil { + return 0, types.StateAtEvent{}, err + } + + if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil { + return 0, types.StateAtEvent{}, err + } + + eventStateKey := event.StateKey() + // Assigned a numeric ID for the state_key if there is one present. + // Otherwise set the numeric ID for the state_key to 0. + if eventStateKey != nil { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil { + return 0, types.StateAtEvent{}, err + } + } + + if eventNID, stateNID, err = d.statements.insertEvent( + ctx, + roomNID, + eventTypeNID, + eventStateKeyNID, + event.EventID(), + event.EventReference().EventSHA256, + authEventNIDs, + event.Depth(), + ); err != nil { + if err == sql.ErrNoRows { + // We've already inserted the event so select the numeric event ID + eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) + } + if err != nil { + return 0, types.StateAtEvent{}, err + } + } + + if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { + return 0, types.StateAtEvent{}, err + } + + return roomNID, types.StateAtEvent{ + BeforeStateSnapshotNID: stateNID, + StateEntry: types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypeNID, + EventStateKeyNID: eventStateKeyNID, + }, + EventNID: eventNID, + }, + }, nil +} + +func (d *Database) assignRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + // Check if we already have a numeric ID in the database. + roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + } + } + return roomNID, err +} + +func (d *Database) assignEventTypeNID( + ctx context.Context, eventType string, +) (types.EventTypeNID, error) { + // Check if we already have a numeric ID in the database. + eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType) + } + } + return eventTypeNID, err +} + +func (d *Database) assignStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + // Check if we already have a numeric ID in the database. + eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + } + } + return eventStateKeyNID, err +} + +// StateEntriesForEventIDs implements input.EventDatabase +func (d *Database) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return d.statements.bulkSelectStateEventByID(ctx, eventIDs) +} + +// EventTypeNIDs implements state.RoomStateDatabase +func (d *Database) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.statements.bulkSelectEventTypeNID(ctx, eventTypes) +} + +// EventStateKeyNIDs implements state.RoomStateDatabase +func (d *Database) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys) +} + +// EventStateKeys implements query.RoomserverQueryAPIDatabase +func (d *Database) EventStateKeys( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs) +} + +// EventNIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) EventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return d.statements.bulkSelectEventNID(ctx, eventIDs) +} + +// Events implements input.EventDatabase +func (d *Database) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) + if err != nil { + return nil, err + } + results := make([]types.Event, len(eventJSONs)) + for i, eventJSON := range eventJSONs { + result := &results[i] + result.EventNID = eventJSON.EventNID + // TODO: Use NewEventFromTrustedJSON for efficiency + result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON) + if err != nil { + return nil, err + } + } + return results, nil +} + +// AddState implements input.EventDatabase +func (d *Database) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (types.StateSnapshotNID, error) { + if len(state) > 0 { + stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx) + if err != nil { + return 0, err + } + if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil { + return 0, err + } + stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + } + + return d.statements.insertState(ctx, roomNID, stateBlockNIDs) +} + +// SetState implements input.EventDatabase +func (d *Database) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return d.statements.updateEventState(ctx, eventNID, stateNID) +} + +// StateAtEventIDs implements input.EventDatabase +func (d *Database) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) +} + +// StateBlockNIDs implements state.RoomStateDatabase +func (d *Database) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs) +} + +// StateEntries implements state.RoomStateDatabase +func (d *Database) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) +} + +// SnapshotNIDFromEventID implements state.RoomStateDatabase +func (d *Database) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.statements.selectEvent(ctx, eventID) + return stateNID, err +} + +// EventIDs implements input.RoomEventDatabase +func (d *Database) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return d.statements.bulkSelectEventID(ctx, eventNIDs) +} + +// GetLatestEventsForUpdate implements input.EventDatabase +func (d *Database) GetLatestEventsForUpdate( + ctx context.Context, roomNID types.RoomNID, +) (types.RoomRecentEventsUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + } + return &roomRecentEventsUpdater{ + transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil +} + +// GetTransactionEventID implements input.EventDatabase +func (d *Database) GetTransactionEventID( + ctx context.Context, transactionID string, + sessionID int64, userID string, +) (string, error) { + eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID) + if err == sql.ErrNoRows { + return "", nil + } + return eventID, err +} + +type roomRecentEventsUpdater struct { + transaction + d *Database + roomNID types.RoomNID + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + for _, ref := range previousEventReferences { + if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return err + } + } + return nil +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { + err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, err +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { + return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { + return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) +} + +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) +} + +// RoomNID implements query.RoomserverQueryAPIDB +func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { + roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID) + if err == sql.ErrNoRows { + return 0, nil + } + return roomNID, err +} + +// LatestEventIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, +) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { + eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) + if err != nil { + return nil, 0, 0, err + } + references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) + if err != nil { + return nil, 0, 0, err + } + depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) + if err != nil { + return nil, 0, 0, err + } + return references, currentStateSnapshotNID, depth, nil +} + +// GetInvitesForUser implements query.RoomserverQueryAPIDatabase +func (d *Database) GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (senderUserIDs []types.EventStateKeyNID, err error) { + return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) +} + +// SetRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { + return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID) +} + +// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { + return d.statements.selectRoomIDFromAlias(ctx, alias) +} + +// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB +func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { + return d.statements.selectAliasesFromRoomID(ctx, roomID) +} + +// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetCreatorIDForAlias( + ctx context.Context, alias string, +) (string, error) { + return d.statements.selectCreatorIDFromAlias(ctx, alias) +} + +// RemoveRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { + return d.statements.deleteRoomAlias(ctx, alias) +} + +// StateEntriesForTuples implements state.RoomStateDatabase +func (d *Database) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.statements.bulkSelectFilteredStateBlockEntries( + ctx, stateBlockNIDs, stateKeyTuples, + ) +} + +// MembershipUpdater implements input.RoomEventDatabase +func (d *Database) MembershipUpdater( + ctx context.Context, roomID, targetUserID string, +) (types.MembershipUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + succeeded := false + defer func() { + if !succeeded { + txn.Rollback() // nolint: errcheck + } + }() + + roomNID, err := d.assignRoomNID(ctx, txn, roomID) + if err != nil { + return nil, err + } + + targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) + if err != nil { + return nil, err + } + + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + succeeded = true + return updater, nil +} + +type membershipUpdater struct { + transaction + d *Database + roomNID types.RoomNID + targetUserNID types.EventStateKeyNID + membership membershipState +} + +func (d *Database) membershipUpdaterTxn( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (types.MembershipUpdater, error) { + + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + return nil, err + } + + membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + return &membershipUpdater{ + transaction{ctx, txn}, d, roomNID, targetUserNID, membership, + }, nil +} + +// IsInvite implements types.MembershipUpdater +func (u *membershipUpdater) IsInvite() bool { + return u.membership == membershipStateInvite +} + +// IsJoin implements types.MembershipUpdater +func (u *membershipUpdater) IsJoin() bool { + return u.membership == membershipStateJoin +} + +// IsLeave implements types.MembershipUpdater +func (u *membershipUpdater) IsLeave() bool { + return u.membership == membershipStateLeaveOrBan +} + +// SetToInvite implements types.MembershipUpdater +func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + if err != nil { + return false, err + } + inserted, err := u.d.statements.insertInviteEvent( + u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return false, err + } + if u.membership != membershipStateInvite { + if err = u.d.statements.updateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + ); err != nil { + return false, err + } + } + return inserted, nil +} + +// SetToJoin implements types.MembershipUpdater +func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { + var inviteEventIDs []string + + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return nil, err + } + + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.statements.updateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + } + + // Look up the NID of the new join event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != membershipStateJoin || isUpdate { + if err = u.d.statements.updateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateJoin, nIDs[eventID], + ); err != nil { + return nil, err + } + } + + return inviteEventIDs, nil +} + +// SetToLeave implements types.MembershipUpdater +func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return nil, err + } + inviteEventIDs, err := u.d.statements.updateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + + // Look up the NID of the new leave event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != membershipStateLeaveOrBan { + if err = u.d.statements.updateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateLeaveOrBan, nIDs[eventID], + ); err != nil { + return nil, err + } + } + return inviteEventIDs, nil +} + +// GetMembership implements query.RoomserverQueryAPIDB +func (d *Database) GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, +) (membershipEventNID types.EventNID, stillInRoom bool, err error) { + requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) + if err != nil { + return + } + + senderMembershipEventNID, senderMembership, err := + d.statements.selectMembershipFromRoomAndTarget( + ctx, roomNID, requestSenderUserNID, + ) + if err == sql.ErrNoRows { + // The user has never been a member of that room + return 0, false, nil + } else if err != nil { + return + } + + return senderMembershipEventNID, senderMembership == membershipStateJoin, nil +} + +// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB +func (d *Database) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, +) ([]types.EventNID, error) { + if joinOnly { + return d.statements.selectMembershipsFromRoomAndMembership( + ctx, roomNID, membershipStateJoin, + ) + } + + return d.statements.selectMembershipsFromRoom(ctx, roomNID) +} + +// EventsFromIDs implements query.RoomserverQueryAPIEventDB +func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.EventNIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + var nids []types.EventNID + for _, nid := range nidMap { + nids = append(nids, nid) + } + + return d.Events(ctx, nids) +} + +type transaction struct { + ctx context.Context + txn *sql.Tx +} + +// Commit implements types.Transaction +func (t *transaction) Commit() error { + return t.txn.Commit() +} + +// Rollback implements types.Transaction +func (t *transaction) Rollback() error { + return t.txn.Rollback() +} diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go new file mode 100644 index 000000000..4bb82a210 --- /dev/null +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -0,0 +1,82 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" +) + +const transactionsSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_transactions ( + transaction_id TEXT NOT NULL, + session_id INTEGER NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + PRIMARY KEY (transaction_id, session_id, user_id) + ); +` +const insertTransactionSQL = ` + INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id) + VALUES ($1, $2, $3, $4) +` + +const selectTransactionEventIDSQL = ` + SELECT event_id FROM roomserver_transactions + WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3 +` + +type transactionStatements struct { + insertTransactionStmt *sql.Stmt + selectTransactionEventIDStmt *sql.Stmt +} + +func (s *transactionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(transactionsSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertTransactionStmt, insertTransactionSQL}, + {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, + }.prepare(db) +} + +func (s *transactionStatements) insertTransaction( + ctx context.Context, + transactionID string, + sessionID int64, + userID string, + eventID string, +) (err error) { + _, err = s.insertTransactionStmt.ExecContext( + ctx, transactionID, sessionID, userID, eventID, + ) + return +} + +func (s *transactionStatements) selectTransactionEventID( + ctx context.Context, + transactionID string, + sessionID int64, + userID string, +) (eventID string, err error) { + err = s.selectTransactionEventIDStmt.QueryRowContext( + ctx, transactionID, sessionID, userID, + ).Scan(&eventID) + return +} diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 325d96e99..c3bd2d0ae 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -66,6 +67,8 @@ func Open(dataSourceName string) (Database, error) { switch uri.Scheme { case "postgres": return postgres.Open(dataSourceName) + case "file": + return sqlite3.Open(dataSourceName) default: return nil, errors.New("unknown schema") }