From 4b42a0addb37981bfdc58238bc801460643f2733 Mon Sep 17 00:00:00 2001 From: Alex Flatow Date: Thu, 6 May 2021 14:41:27 +1000 Subject: [PATCH 1/2] storage.go --- .../cosmosdb/appservice_events_table.go | 267 +++++++++ appservice/storage/cosmosdb/storage.go | 119 ++++ .../storage/cosmosdb/txn_id_counter_table.go | 69 +++ appservice/storage/storage.go | 3 + appservice/storage/storage_wasm.go | 2 + dendrite-config-cosmosdb.yaml | 394 ++++++++++++++ .../storage/cosmosdb/blacklist_table.go | 107 ++++ .../storage/cosmosdb/inbound_peeks_table.go | 176 ++++++ .../storage/cosmosdb/joined_hosts_table.go | 219 ++++++++ .../storage/cosmosdb/outbound_peeks_table.go | 176 ++++++ .../storage/cosmosdb/queue_edus_table.go | 207 +++++++ .../storage/cosmosdb/queue_json_table.go | 136 +++++ .../storage/cosmosdb/queue_pdus_table.go | 235 ++++++++ federationsender/storage/cosmosdb/storage.go | 95 ++++ federationsender/storage/storage.go | 3 + federationsender/storage/storage_wasm.go | 2 + internal/cosmosdbutil/connection.go | 12 + internal/sqlutil/migrate.go | 7 +- internal/sqlutil/trace.go | 5 + .../storage/cosmosdb/device_keys_table.go | 199 +++++++ .../storage/cosmosdb/key_changes_table.go | 104 ++++ .../storage/cosmosdb/one_time_keys_table.go | 203 +++++++ .../storage/cosmosdb/stale_device_lists.go | 121 ++++ keyserver/storage/cosmosdb/storage.go | 52 ++ keyserver/storage/storage.go | 3 + keyserver/storage/storage_wasm.go | 2 + .../cosmosdb/media_repository_table.go | 150 +++++ mediaapi/storage/cosmosdb/prepare.go | 38 ++ mediaapi/storage/cosmosdb/sql.go | 38 ++ mediaapi/storage/cosmosdb/storage.go | 124 +++++ mediaapi/storage/cosmosdb/thumbnail_table.go | 171 ++++++ mediaapi/storage/storage.go | 3 + mediaapi/storage/storage_wasm.go | 2 + .../storage/cosmosdb/event_json_table.go | 107 ++++ .../cosmosdb/event_state_keys_table.go | 163 ++++++ .../storage/cosmosdb/event_types_table.go | 161 ++++++ roomserver/storage/cosmosdb/events_table.go | 515 ++++++++++++++++++ roomserver/storage/cosmosdb/invite_table.go | 159 ++++++ .../storage/cosmosdb/membership_table.go | 306 +++++++++++ .../storage/cosmosdb/previous_events_table.go | 131 +++++ .../storage/cosmosdb/published_table.go | 105 ++++ .../storage/cosmosdb/redactions_table.go | 123 +++++ .../storage/cosmosdb/room_aliases_table.go | 141 +++++ roomserver/storage/cosmosdb/rooms_table.go | 296 ++++++++++ .../storage/cosmosdb/state_block_table.go | 289 ++++++++++ .../cosmosdb/state_block_table_test.go | 86 +++ .../storage/cosmosdb/state_snapshot_table.go | 126 +++++ roomserver/storage/cosmosdb/storage.go | 187 +++++++ .../storage/cosmosdb/transactions_table.go | 91 ++++ roomserver/storage/storage.go | 3 + roomserver/storage/storage_wasm.go | 2 + setup/config/config.go | 5 + setup/mscs/msc2836/storage.go | 83 +++ setup/mscs/msc2946/storage.go | 43 ++ signingkeyserver/storage/cosmosdb/keydb.go | 99 ++++ .../storage/cosmosdb/server_key_table.go | 159 ++++++ signingkeyserver/storage/keydb.go | 3 + .../storage/cosmosdb/account_data_table.go | 156 ++++++ .../cosmosdb/backwards_extremities_table.go | 125 +++++ .../cosmosdb/current_room_state_table.go | 324 +++++++++++ syncapi/storage/cosmosdb/filter_table.go | 140 +++++ syncapi/storage/cosmosdb/filtering.go | 82 +++ syncapi/storage/cosmosdb/invites_table.go | 185 +++++++ syncapi/storage/cosmosdb/memberships_table.go | 119 ++++ .../cosmosdb/output_room_events_table.go | 477 ++++++++++++++++ .../output_room_events_topology_table.go | 179 ++++++ syncapi/storage/cosmosdb/peeks_table.go | 206 +++++++ syncapi/storage/cosmosdb/receipt_table.go | 141 +++++ .../storage/cosmosdb/send_to_device_table.go | 160 ++++++ syncapi/storage/cosmosdb/stream_id_table.go | 94 ++++ syncapi/storage/cosmosdb/syncserver.go | 128 +++++ syncapi/storage/storage.go | 3 + syncapi/storage/storage_wasm.go | 2 + .../accounts/cosmosdb/account_data_table.go | 134 +++++ .../accounts/cosmosdb/accounts_table.go | 187 +++++++ .../accounts/cosmosdb/constraint_wasm.go | 21 + .../storage/accounts/cosmosdb/openid_table.go | 86 +++ .../accounts/cosmosdb/profile_table.go | 143 +++++ userapi/storage/accounts/cosmosdb/storage.go | 408 ++++++++++++++ .../accounts/cosmosdb/threepid_table.go | 133 +++++ userapi/storage/accounts/storage.go | 3 + userapi/storage/accounts/storage_wasm.go | 2 + .../storage/devices/cosmosdb/devices_table.go | 322 +++++++++++ userapi/storage/devices/cosmosdb/storage.go | 214 ++++++++ userapi/storage/devices/storage.go | 3 + userapi/storage/devices/storage_wasm.go | 2 + 86 files changed, 11005 insertions(+), 1 deletion(-) create mode 100644 appservice/storage/cosmosdb/appservice_events_table.go create mode 100644 appservice/storage/cosmosdb/storage.go create mode 100644 appservice/storage/cosmosdb/txn_id_counter_table.go create mode 100644 dendrite-config-cosmosdb.yaml create mode 100644 federationsender/storage/cosmosdb/blacklist_table.go create mode 100644 federationsender/storage/cosmosdb/inbound_peeks_table.go create mode 100644 federationsender/storage/cosmosdb/joined_hosts_table.go create mode 100644 federationsender/storage/cosmosdb/outbound_peeks_table.go create mode 100644 federationsender/storage/cosmosdb/queue_edus_table.go create mode 100644 federationsender/storage/cosmosdb/queue_json_table.go create mode 100644 federationsender/storage/cosmosdb/queue_pdus_table.go create mode 100644 federationsender/storage/cosmosdb/storage.go create mode 100644 internal/cosmosdbutil/connection.go create mode 100644 keyserver/storage/cosmosdb/device_keys_table.go create mode 100644 keyserver/storage/cosmosdb/key_changes_table.go create mode 100644 keyserver/storage/cosmosdb/one_time_keys_table.go create mode 100644 keyserver/storage/cosmosdb/stale_device_lists.go create mode 100644 keyserver/storage/cosmosdb/storage.go create mode 100644 mediaapi/storage/cosmosdb/media_repository_table.go create mode 100644 mediaapi/storage/cosmosdb/prepare.go create mode 100644 mediaapi/storage/cosmosdb/sql.go create mode 100644 mediaapi/storage/cosmosdb/storage.go create mode 100644 mediaapi/storage/cosmosdb/thumbnail_table.go create mode 100644 roomserver/storage/cosmosdb/event_json_table.go create mode 100644 roomserver/storage/cosmosdb/event_state_keys_table.go create mode 100644 roomserver/storage/cosmosdb/event_types_table.go create mode 100644 roomserver/storage/cosmosdb/events_table.go create mode 100644 roomserver/storage/cosmosdb/invite_table.go create mode 100644 roomserver/storage/cosmosdb/membership_table.go create mode 100644 roomserver/storage/cosmosdb/previous_events_table.go create mode 100644 roomserver/storage/cosmosdb/published_table.go create mode 100644 roomserver/storage/cosmosdb/redactions_table.go create mode 100644 roomserver/storage/cosmosdb/room_aliases_table.go create mode 100644 roomserver/storage/cosmosdb/rooms_table.go create mode 100644 roomserver/storage/cosmosdb/state_block_table.go create mode 100644 roomserver/storage/cosmosdb/state_block_table_test.go create mode 100644 roomserver/storage/cosmosdb/state_snapshot_table.go create mode 100644 roomserver/storage/cosmosdb/storage.go create mode 100644 roomserver/storage/cosmosdb/transactions_table.go create mode 100644 signingkeyserver/storage/cosmosdb/keydb.go create mode 100644 signingkeyserver/storage/cosmosdb/server_key_table.go create mode 100644 syncapi/storage/cosmosdb/account_data_table.go create mode 100644 syncapi/storage/cosmosdb/backwards_extremities_table.go create mode 100644 syncapi/storage/cosmosdb/current_room_state_table.go create mode 100644 syncapi/storage/cosmosdb/filter_table.go create mode 100644 syncapi/storage/cosmosdb/filtering.go create mode 100644 syncapi/storage/cosmosdb/invites_table.go create mode 100644 syncapi/storage/cosmosdb/memberships_table.go create mode 100644 syncapi/storage/cosmosdb/output_room_events_table.go create mode 100644 syncapi/storage/cosmosdb/output_room_events_topology_table.go create mode 100644 syncapi/storage/cosmosdb/peeks_table.go create mode 100644 syncapi/storage/cosmosdb/receipt_table.go create mode 100644 syncapi/storage/cosmosdb/send_to_device_table.go create mode 100644 syncapi/storage/cosmosdb/stream_id_table.go create mode 100644 syncapi/storage/cosmosdb/syncserver.go create mode 100644 userapi/storage/accounts/cosmosdb/account_data_table.go create mode 100644 userapi/storage/accounts/cosmosdb/accounts_table.go create mode 100644 userapi/storage/accounts/cosmosdb/constraint_wasm.go create mode 100644 userapi/storage/accounts/cosmosdb/openid_table.go create mode 100644 userapi/storage/accounts/cosmosdb/profile_table.go create mode 100644 userapi/storage/accounts/cosmosdb/storage.go create mode 100644 userapi/storage/accounts/cosmosdb/threepid_table.go create mode 100644 userapi/storage/devices/cosmosdb/devices_table.go create mode 100644 userapi/storage/devices/cosmosdb/storage.go diff --git a/appservice/storage/cosmosdb/appservice_events_table.go b/appservice/storage/cosmosdb/appservice_events_table.go new file mode 100644 index 000000000..f69940870 --- /dev/null +++ b/appservice/storage/cosmosdb/appservice_events_table.go @@ -0,0 +1,267 @@ +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +const appserviceEventsSchema = ` +-- Stores events to be sent to application services +CREATE TABLE IF NOT EXISTS appservice_events ( + -- An auto-incrementing id unique to each event in the table + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The ID of the application service the event will be sent to + as_id TEXT NOT NULL, + -- JSON representation of the event + headered_event_json TEXT NOT NULL, + -- The ID of the transaction that this event is a part of + txn_id INTEGER NOT NULL +); + +CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id); +` + +const selectEventsByApplicationServiceIDSQL = "" + + "SELECT id, headered_event_json, txn_id " + + "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" + +const countEventsByApplicationServiceIDSQL = "" + + "SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" + +const insertEventSQL = "" + + "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " + + "VALUES ($1, $2, $3)" + +const updateTxnIDForEventsSQL = "" + + "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3" + +const deleteEventsBeforeAndIncludingIDSQL = "" + + "DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2" + +const ( + // A transaction ID number that no transaction should ever have. Used for + // checking again the default value. + invalidTxnID = -2 +) + +type eventsStatements struct { + db *sql.DB + writer sqlutil.Writer + selectEventsByApplicationServiceIDStmt *sql.Stmt + countEventsByApplicationServiceIDStmt *sql.Stmt + insertEventStmt *sql.Stmt + updateTxnIDForEventsStmt *sql.Stmt + deleteEventsBeforeAndIncludingIDStmt *sql.Stmt +} + +func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { + s.db = db + s.writer = writer + _, err = db.Exec(appserviceEventsSchema) + if err != nil { + return + } + + if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil { + return + } + if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil { + return + } + if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { + return + } + if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil { + return + } + if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil { + return + } + + return +} + +// selectEventsByApplicationServiceID takes in an application service ID and +// returns a slice of events that need to be sent to that application service, +// as well as an int later used to remove these same events from the database +// once successfully sent to an application service. +func (s *eventsStatements) selectEventsByApplicationServiceID( + ctx context.Context, + applicationServiceID string, + limit int, +) ( + txnID, maxID int, + events []gomatrixserverlib.HeaderedEvent, + eventsRemaining bool, + err error, +) { + defer func() { + if err != nil { + log.WithFields(log.Fields{ + "appservice": applicationServiceID, + }).WithError(err).Fatalf("appservice unable to select new events to send") + } + }() + // Retrieve events from the database. Unsuccessfully sent events first + eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID) + if err != nil { + return + } + defer checkNamedErr(eventRows.Close, &err) + events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) + if err != nil { + return + } + + return +} + +// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil +func checkNamedErr(fn func() error, err *error) { + if e := fn(); e != nil && *err == nil { + *err = e + } +} + +func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { + // Get current time for use in calculating event age + nowMilli := time.Now().UnixNano() / int64(time.Millisecond) + + // Iterate through each row and store event contents + // If txn_id changes dramatically, we've switched from collecting old events to + // new ones. Send back those events first. + lastTxnID := invalidTxnID + for eventsProcessed := 0; eventRows.Next(); { + var event gomatrixserverlib.HeaderedEvent + var eventJSON []byte + var id int + err = eventRows.Scan( + &id, + &eventJSON, + &txnID, + ) + if err != nil { + return nil, 0, 0, false, err + } + + // Unmarshal eventJSON + if err = json.Unmarshal(eventJSON, &event); err != nil { + return nil, 0, 0, false, err + } + + // If txnID has changed on this event from the previous event, then we've + // reached the end of a transaction's events. Return only those events. + if lastTxnID > invalidTxnID && lastTxnID != txnID { + return events, maxID, lastTxnID, true, nil + } + lastTxnID = txnID + + // Limit events that aren't part of an old transaction + if txnID == -1 { + // Return if we've hit the limit + if eventsProcessed++; eventsProcessed > limit { + return events, maxID, lastTxnID, true, nil + } + } + + if id > maxID { + maxID = id + } + + // Portion of the event that is unsigned due to rapid change + // TODO: Consider removing age as not many app services use it + if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil { + return nil, 0, 0, false, err + } + + events = append(events, event) + } + + return +} + +// countEventsByApplicationServiceID inserts an event mapped to its corresponding application service +// IDs into the db. +func (s *eventsStatements) countEventsByApplicationServiceID( + ctx context.Context, + appServiceID string, +) (int, error) { + var count int + err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count) + if err != nil && err != sql.ErrNoRows { + return 0, err + } + + return count, nil +} + +// insertEvent inserts an event mapped to its corresponding application service +// IDs into the db. +func (s *eventsStatements) insertEvent( + ctx context.Context, + appServiceID string, + event *gomatrixserverlib.HeaderedEvent, +) (err error) { + // Convert event to JSON before inserting + eventJSON, err := json.Marshal(event) + if err != nil { + return err + } + + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.insertEventStmt.ExecContext( + ctx, + appServiceID, + eventJSON, + -1, // No transaction ID yet + ) + return err + }) +} + +// updateTxnIDForEvents sets the transactionID for a collection of events. Done +// before sending them to an AppService. Referenced before sending to make sure +// we aren't constructing multiple transactions with the same events. +func (s *eventsStatements) updateTxnIDForEvents( + ctx context.Context, + appserviceID string, + maxID, txnID int, +) (err error) { + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID) + return err + }) +} + +// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database. +func (s *eventsStatements) deleteEventsBeforeAndIncludingID( + ctx context.Context, + appserviceID string, + eventTableID int, +) (err error) { + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID) + return err + }) +} diff --git a/appservice/storage/cosmosdb/storage.go b/appservice/storage/cosmosdb/storage.go new file mode 100644 index 000000000..3639010e1 --- /dev/null +++ b/appservice/storage/cosmosdb/storage.go @@ -0,0 +1,119 @@ +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + // Import SQLite database driver + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + _ "github.com/mattn/go-sqlite3" +) + +// Database stores events intended to be later sent to application services +type Database struct { + sqlutil.PartitionOffsetStatements + events eventsStatements + txnID txnStatements + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { + var result Database + var err error + if result.db, err = sqlutil.Open(dbProperties); err != nil { + return nil, err + } + result.writer = sqlutil.NewExclusiveWriter() + if err = result.prepare(); err != nil { + return nil, err + } + if err = result.PartitionOffsetStatements.Prepare(result.db, result.writer, "appservice"); err != nil { + return nil, err + } + return &result, nil +} + +func (d *Database) prepare() error { + if err := d.events.prepare(d.db, d.writer); err != nil { + return err + } + + return d.txnID.prepare(d.db, d.writer) +} + +// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database +// for a transaction worker to pull and later send to an application service. +func (d *Database) StoreEvent( + ctx context.Context, + appServiceID string, + event *gomatrixserverlib.HeaderedEvent, +) error { + return d.events.insertEvent(ctx, appServiceID, event) +} + +// GetEventsWithAppServiceID returns a slice of events and their IDs intended to +// be sent to an application service given its ID. +func (d *Database) GetEventsWithAppServiceID( + ctx context.Context, + appServiceID string, + limit int, +) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) { + return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit) +} + +// CountEventsWithAppServiceID returns the number of events destined for an +// application service given its ID. +func (d *Database) CountEventsWithAppServiceID( + ctx context.Context, + appServiceID string, +) (int, error) { + return d.events.countEventsByApplicationServiceID(ctx, appServiceID) +} + +// UpdateTxnIDForEvents takes in an application service ID and a +// and stores them in the DB, unless the pair already exists, in +// which case it updates them. +func (d *Database) UpdateTxnIDForEvents( + ctx context.Context, + appserviceID string, + maxID, txnID int, +) error { + return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID) +} + +// RemoveEventsBeforeAndIncludingID removes all events from the database that +// are less than or equal to a given maximum ID. IDs here are implemented as a +// serial, thus this should always delete events in chronological order. +func (d *Database) RemoveEventsBeforeAndIncludingID( + ctx context.Context, + appserviceID string, + eventTableID int, +) error { + return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID) +} + +// GetLatestTxnID returns the latest available transaction id +func (d *Database) GetLatestTxnID( + ctx context.Context, +) (int, error) { + return d.txnID.selectTxnID(ctx) +} diff --git a/appservice/storage/cosmosdb/txn_id_counter_table.go b/appservice/storage/cosmosdb/txn_id_counter_table.go new file mode 100644 index 000000000..73b13f3db --- /dev/null +++ b/appservice/storage/cosmosdb/txn_id_counter_table.go @@ -0,0 +1,69 @@ +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const txnIDSchema = ` +-- Keeps a count of the current transaction ID +CREATE TABLE IF NOT EXISTS appservice_counters ( + name TEXT PRIMARY KEY NOT NULL, + last_id INTEGER DEFAULT 1 +); +INSERT OR IGNORE INTO appservice_counters (name, last_id) VALUES('txn_id', 1); +` + +const selectTxnIDSQL = ` + SELECT last_id FROM appservice_counters WHERE name='txn_id'; + UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id'; +` + +type txnStatements struct { + db *sql.DB + writer sqlutil.Writer + selectTxnIDStmt *sql.Stmt +} + +func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { + s.db = db + s.writer = writer + _, err = db.Exec(txnIDSchema) + if err != nil { + return + } + + if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil { + return + } + + return +} + +// selectTxnID selects the latest ascending transaction ID +func (s *txnStatements) selectTxnID( + ctx context.Context, +) (txnID int, err error) { + err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID) + return err + }) + return +} diff --git a/appservice/storage/storage.go b/appservice/storage/storage.go index b0df2b7dc..876bf4d0c 100644 --- a/appservice/storage/storage.go +++ b/appservice/storage/storage.go @@ -17,6 +17,7 @@ package storage import ( + "github.com/matrix-org/dendrite/appservice/storage/cosmosdb" "fmt" "github.com/matrix-org/dendrite/appservice/storage/postgres" @@ -28,6 +29,8 @@ import ( // and sets DB connection parameters func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return cosmosdb.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsPostgres(): diff --git a/appservice/storage/storage_wasm.go b/appservice/storage/storage_wasm.go index 07d0e9ee1..fa04e8f98 100644 --- a/appservice/storage/storage_wasm.go +++ b/appservice/storage/storage_wasm.go @@ -23,6 +23,8 @@ import ( func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return nil, fmt.Errorf("can't use CosmosDB implementation") case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsPostgres(): diff --git a/dendrite-config-cosmosdb.yaml b/dendrite-config-cosmosdb.yaml new file mode 100644 index 000000000..4f61b5362 --- /dev/null +++ b/dendrite-config-cosmosdb.yaml @@ -0,0 +1,394 @@ +# This is the Dendrite configuration file. +# +# The configuration is split up into sections - each Dendrite component has a +# configuration section, in addition to the "global" section which applies to +# all components. +# +# At a minimum, to get started, you will need to update the settings in the +# "global" section for your deployment, and you will need to check that the +# database "connection_string" line in each component section is correct. +# +# Each component with a "database" section can accept the following formats +# for "connection_string": +# SQLite: file:filename.db +# file:///path/to/filename.db +# PostgreSQL: postgresql://user:pass@hostname/database?params=... +# CosmosDB: cosmosdb:filename.db +# cosmosdb:///path/to/filename.db +# +# SQLite is embedded into Dendrite and therefore no further prerequisites are +# needed for the database when using SQLite mode. However, performance with +# PostgreSQL is significantly better and recommended for multi-user deployments. +# SQLite is typically around 20-30% slower than PostgreSQL when tested with a +# small number of users and likely will perform worse still with a higher volume +# of users. +# +# The "max_open_conns" and "max_idle_conns" settings configure the maximum +# number of open/idle database connections. The value 0 will use the database +# engine default, and a negative value will use unlimited connections. The +# "conn_max_lifetime" option controls the maximum length of time a database +# connection can be idle in seconds - a negative value is unlimited. + +# The version of the configuration file. +version: 1 + +# Global Matrix configuration. This configuration applies to all components. +global: + # The domain name of this homeserver. + server_name: localhost + + # The path to the signing private key file, used to sign requests and events. + # Note that this is NOT the same private key as used for TLS! To generate a + # signing key, use "./bin/generate-keys --private-key matrix_key.pem". + private_key: matrix_key.pem + + # The paths and expiry timestamps (as a UNIX timestamp in millisecond precision) + # to old signing private keys that were formerly in use on this domain. These + # keys will not be used for federation request or event signing, but will be + # provided to any other homeserver that asks when trying to verify old events. + # old_private_keys: + # - private_key: old_matrix_key.pem + # expired_at: 1601024554498 + + # How long a remote server can cache our server signing key before requesting it + # again. Increasing this number will reduce the number of requests made by other + # servers for our key but increases the period that a compromised key will be + # considered valid by other homeservers. + key_validity_period: 168h0m0s + + # Lists of domains that the server will trust as identity servers to verify third + # party identifiers such as phone numbers and email addresses. + trusted_third_party_id_servers: + - matrix.org + - vector.im + + # Disables federation. Dendrite will not be able to make any outbound HTTP requests + # to other servers and the federation API will not be exposed. + disable_federation: false + + # Configuration for Kafka/Naffka. + kafka: + # List of Kafka broker addresses to connect to. This is not needed if using + # Naffka in monolith mode. + addresses: + - localhost:2181 + + # The prefix to use for Kafka topic names for this homeserver. Change this only if + # you are running more than one Dendrite homeserver on the same Kafka deployment. + topic_prefix: Dendrite + + # Whether to use Naffka instead of Kafka. This is only available in monolith + # mode, but means that you can run a single-process server without requiring + # Kafka. + use_naffka: true + + # The max size a Kafka message is allowed to use. + # You only need to change this value, if you encounter issues with too large messages. + # Must be less than/equal to "max.message.bytes" configured in Kafka. + # Defaults to 8388608 bytes. + # max_message_bytes: 8388608 + + # Naffka database options. Not required when using Kafka. + naffka_database: + connection_string: cosmosdb:naffka.db + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + + # Configuration for Prometheus metric collection. + metrics: + # Whether or not Prometheus metrics are enabled. + enabled: false + + # HTTP basic authentication to protect access to monitoring. + basic_auth: + username: metrics + password: metrics + + # DNS cache options. The DNS cache may reduce the load on DNS servers + # if there is no local caching resolver available for use. + dns_cache: + # Whether or not the DNS cache is enabled. + enabled: false + + # Maximum number of entries to hold in the DNS cache, and + # for how long those items should be considered valid in seconds. + cache_size: 256 + cache_lifetime: 300 + +# Configuration for the Appservice API. +app_service_api: + internal_api: + listen: http://localhost:7777 + connect: http://localhost:7777 + database: + connection_string: cosmosdb:appservice.db + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + + # Disable the validation of TLS certificates of appservices. This is + # not recommended in production since it may allow appservice traffic + # to be sent to an unverified endpoint. + disable_tls_validation: false + + # Appservice configuration files to load into this homeserver. + config_files: [] + +# Configuration for the Client API. +client_api: + internal_api: + listen: http://localhost:7771 + connect: http://localhost:7771 + external_api: + listen: http://[::]:8071 + + # Prevents new users from being able to register on this homeserver, except when + # using the registration shared secret below. + registration_disabled: false + + # If set, allows registration by anyone who knows the shared secret, regardless of + # whether registration is otherwise disabled. + registration_shared_secret: "" + + # Whether to require reCAPTCHA for registration. + enable_registration_captcha: false + + # Settings for ReCAPTCHA. + recaptcha_public_key: "" + recaptcha_private_key: "" + recaptcha_bypass_secret: "" + recaptcha_siteverify_api: "" + + # TURN server information that this homeserver should send to clients. + turn: + turn_user_lifetime: "" + turn_uris: [] + turn_shared_secret: "" + turn_username: "" + turn_password: "" + + # Settings for rate-limited endpoints. Rate limiting will kick in after the + # threshold number of "slots" have been taken by requests from a specific + # host. Each "slot" will be released after the cooloff time in milliseconds. + rate_limiting: + enabled: true + threshold: 5 + cooloff_ms: 500 + +# Configuration for the EDU server. +edu_server: + internal_api: + listen: http://localhost:7778 + connect: http://localhost:7778 + +# Configuration for the Federation API. +federation_api: + internal_api: + listen: http://localhost:7772 + connect: http://localhost:7772 + external_api: + listen: http://[::]:8072 + + # List of paths to X.509 certificates to be used by the external federation listeners. + # These certificates will be used to calculate the TLS fingerprints and other servers + # will expect the certificate to match these fingerprints. Certificates must be in PEM + # format. + federation_certificates: [] + +# Configuration for the Federation Sender. +federation_sender: + internal_api: + listen: http://localhost:7775 + connect: http://localhost:7775 + database: + connection_string: cosmosdb:federationsender.db + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + + # How many times we will try to resend a failed transaction to a specific server. The + # backoff is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds etc. + send_max_retries: 16 + + # Disable the validation of TLS certificates of remote federated homeservers. Do not + # enable this option in production as it presents a security risk! + disable_tls_validation: false + + # Use the following proxy server for outbound federation traffic. + proxy_outbound: + enabled: false + protocol: http + host: localhost + port: 8080 + +# Configuration for the Key Server (for end-to-end encryption). +key_server: + internal_api: + listen: http://localhost:7779 + connect: http://localhost:7779 + database: + connection_string: cosmosdb:keyserver.db + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + +# Configuration for the Media API. +media_api: + internal_api: + listen: http://localhost:7774 + connect: http://localhost:7774 + external_api: + listen: http://[::]:8074 + database: + connection_string: cosmosdb:mediaapi.db + max_open_conns: 5 + max_idle_conns: 2 + conn_max_lifetime: -1 + + # Storage path for uploaded media. May be relative or absolute. + base_path: ./media_store + + # The maximum allowed file size (in bytes) for media uploads to this homeserver + # (0 = unlimited). If using a reverse proxy, ensure it allows requests at + # least this large (e.g. client_max_body_size in nginx.) + max_file_size_bytes: 10485760 + + # Whether to dynamically generate thumbnails if needed. + dynamic_thumbnails: false + + # The maximum number of simultaneous thumbnail generators to run. + max_thumbnail_generators: 10 + + # A list of thumbnail sizes to be generated for media content. + thumbnail_sizes: + - width: 32 + height: 32 + method: crop + - width: 96 + height: 96 + method: crop + - width: 640 + height: 480 + method: scale + +# Configuration for experimental MSC's +mscs: + # A list of enabled MSC's + # Currently valid values are: + # - msc2836 (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) + # - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946) + mscs: [] + database: + connection_string: cosmosdb:mscs.db + max_open_conns: 5 + max_idle_conns: 2 + conn_max_lifetime: -1 + +# Configuration for the Room Server. +room_server: + internal_api: + listen: http://localhost:7770 + connect: http://localhost:7770 + database: + connection_string: cosmosdb:roomserver.db + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + +# Configuration for the Signing Key Server (for server signing keys). +signing_key_server: + internal_api: + listen: http://localhost:7780 + connect: http://localhost:7780 + database: + connection_string: cosmosdb:signingkeyserver.db + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + + # Perspective keyservers to use as a backup when direct key fetches fail. This may + # be required to satisfy key requests for servers that are no longer online when + # joining some rooms. + key_perspectives: + - server_name: matrix.org + keys: + - key_id: ed25519:auto + public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw + - key_id: ed25519:a_RXGa + public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ + + # This option will control whether Dendrite will prefer to look up keys directly + # or whether it should try perspective servers first, using direct fetches as a + # last resort. + prefer_direct_fetch: false + +# Configuration for the Sync API. +sync_api: + internal_api: + listen: http://localhost:7773 + connect: http://localhost:7773 + external_api: + listen: http://[::]:8073 + database: + connection_string: cosmosdb:syncapi.db + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + + # This option controls which HTTP header to inspect to find the real remote IP + # address of the client. This is likely required if Dendrite is running behind + # a reverse proxy server. + # real_ip_header: X-Real-IP + +# Configuration for the User API. +user_api: + # The cost when hashing passwords on registration/login. Default: 10. Min: 4, Max: 31 + # See https://pkg.go.dev/golang.org/x/crypto/bcrypt for more information. + # Setting this lower makes registration/login consume less CPU resources at the cost of security + # should the database be compromised. Setting this higher makes registration/login consume more + # CPU resources but makes it harder to brute force password hashes. + # This value can be low if performing tests or on embedded Dendrite instances (e.g WASM builds) + # bcrypt_cost: 10 + internal_api: + listen: http://localhost:7781 + connect: http://localhost:7781 + account_database: + connection_string: cosmosdb:userapi_accounts.db + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + device_database: + connection_string: cosmosdb:userapi_devices.db + max_open_conns: 10 + max_idle_conns: 2 + conn_max_lifetime: -1 + # The length of time that a token issued for a relying party from + # /_matrix/client/r0/user/{userId}/openid/request_token endpoint + # is considered to be valid in milliseconds. + # The default lifetime is 3600000ms (60 minutes). + # openid_token_lifetime_ms: 3600000 + +# Configuration for Opentracing. +# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on +# how this works and how to set it up. +tracing: + enabled: false + jaeger: + serviceName: "" + disabled: false + rpc_metrics: false + tags: [] + sampler: null + reporter: null + headers: null + baggage_restrictions: null + throttler: null + +# Logging configuration, in addition to the standard logging that is sent to +# stdout by Dendrite. +logging: +- type: file + level: info + params: + path: ./logs diff --git a/federationsender/storage/cosmosdb/blacklist_table.go b/federationsender/storage/cosmosdb/blacklist_table.go new file mode 100644 index 000000000..f4488a8e8 --- /dev/null +++ b/federationsender/storage/cosmosdb/blacklist_table.go @@ -0,0 +1,107 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const blacklistSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_blacklist ( + -- The blacklisted server name + server_name TEXT NOT NULL, + UNIQUE (server_name) +); +` + +const insertBlacklistSQL = "" + + "INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectBlacklistSQL = "" + + "SELECT server_name FROM federationsender_blacklist WHERE server_name = $1" + +const deleteBlacklistSQL = "" + + "DELETE FROM federationsender_blacklist WHERE server_name = $1" + +type blacklistStatements struct { + db *sql.DB + insertBlacklistStmt *sql.Stmt + selectBlacklistStmt *sql.Stmt + deleteBlacklistStmt *sql.Stmt +} + +func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { + s = &blacklistStatements{ + db: db, + } + _, err = db.Exec(blacklistSchema) + if err != nil { + return + } + + if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil { + return + } + if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil { + return + } + if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil { + return + } + return +} + +// insertRoom inserts the room if it didn't already exist. +// If the room didn't exist then last_event_id is set to the empty string. +func (s *blacklistStatements) InsertBlacklist( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +// selectRoomForUpdate locks the row for the room and returns the last_event_id. +// The row must already exist in the table. Callers can ensure that the row +// exists by calling insertRoom first. +func (s *blacklistStatements) SelectBlacklist( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is blacklisted, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +// updateRoom updates the last_event_id for the room. selectRoomForUpdate should +// have already been called earlier within the transaction. +func (s *blacklistStatements) DeleteBlacklist( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} diff --git a/federationsender/storage/cosmosdb/inbound_peeks_table.go b/federationsender/storage/cosmosdb/inbound_peeks_table.go new file mode 100644 index 000000000..88d9b4a86 --- /dev/null +++ b/federationsender/storage/cosmosdb/inbound_peeks_table.go @@ -0,0 +1,176 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const inboundPeeksSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks ( + room_id TEXT NOT NULL, + server_name TEXT NOT NULL, + peek_id TEXT NOT NULL, + creation_ts INTEGER NOT NULL, + renewed_ts INTEGER NOT NULL, + renewal_interval INTEGER NOT NULL, + UNIQUE (room_id, server_name, peek_id) +); +` + +const insertInboundPeekSQL = "" + + "INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)" + +const selectInboundPeekSQL = "" + + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" + +const selectInboundPeeksSQL = "" + + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" + +const renewInboundPeekSQL = "" + + "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" + +const deleteInboundPeekSQL = "" + + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + +const deleteInboundPeeksSQL = "" + + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" + +type inboundPeeksStatements struct { + db *sql.DB + insertInboundPeekStmt *sql.Stmt + selectInboundPeekStmt *sql.Stmt + selectInboundPeeksStmt *sql.Stmt + renewInboundPeekStmt *sql.Stmt + deleteInboundPeekStmt *sql.Stmt + deleteInboundPeeksStmt *sql.Stmt +} + +func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err error) { + s = &inboundPeeksStatements{ + db: db, + } + _, err = db.Exec(inboundPeeksSchema) + if err != nil { + return + } + + if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { + return + } + if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { + return + } + if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { + return + } + if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil { + return + } + if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil { + return + } + if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil { + return + } + return +} + +func (s *inboundPeeksStatements) InsertInboundPeek( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, +) (err error) { + nowMilli := time.Now().UnixNano() / int64(time.Millisecond) + stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt) + _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval) + return +} + +func (s *inboundPeeksStatements) RenewInboundPeek( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, +) (err error) { + nowMilli := time.Now().UnixNano() / int64(time.Millisecond) + _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) + return +} + +func (s *inboundPeeksStatements) SelectInboundPeek( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, +) (*types.InboundPeek, error) { + row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID) + inboundPeek := types.InboundPeek{} + err := row.Scan( + &inboundPeek.RoomID, + &inboundPeek.ServerName, + &inboundPeek.PeekID, + &inboundPeek.CreationTimestamp, + &inboundPeek.RenewedTimestamp, + &inboundPeek.RenewalInterval, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + return &inboundPeek, nil +} + +func (s *inboundPeeksStatements) SelectInboundPeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) (inboundPeeks []types.InboundPeek, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectInboundPeeks: rows.close() failed") + + for rows.Next() { + inboundPeek := types.InboundPeek{} + if err = rows.Scan( + &inboundPeek.RoomID, + &inboundPeek.ServerName, + &inboundPeek.PeekID, + &inboundPeek.CreationTimestamp, + &inboundPeek.RenewedTimestamp, + &inboundPeek.RenewalInterval, + ); err != nil { + return + } + inboundPeeks = append(inboundPeeks, inboundPeek) + } + + return inboundPeeks, rows.Err() +} + +func (s *inboundPeeksStatements) DeleteInboundPeek( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) + return +} + +func (s *inboundPeeksStatements) DeleteInboundPeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID) + return +} diff --git a/federationsender/storage/cosmosdb/joined_hosts_table.go b/federationsender/storage/cosmosdb/joined_hosts_table.go new file mode 100644 index 000000000..b903d1b7b --- /dev/null +++ b/federationsender/storage/cosmosdb/joined_hosts_table.go @@ -0,0 +1,219 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const joinedHostsSchema = ` +-- The joined_hosts table stores a list of m.room.member event ids in the +-- current state for each room where the membership is "join". +-- There will be an entry for every user that is joined to the room. +CREATE TABLE IF NOT EXISTS federationsender_joined_hosts ( + -- The string ID of the room. + room_id TEXT NOT NULL, + -- The event ID of the m.room.member join event. + event_id TEXT NOT NULL, + -- The domain part of the user ID the m.room.member event is for. + server_name TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx + ON federationsender_joined_hosts (event_id); + +CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx + ON federationsender_joined_hosts (room_id) +` + +const insertJoinedHostsSQL = "" + + "INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" + + " VALUES ($1, $2, $3)" + +const deleteJoinedHostsSQL = "" + + "DELETE FROM federationsender_joined_hosts WHERE event_id = $1" + +const deleteJoinedHostsForRoomSQL = "" + + "DELETE FROM federationsender_joined_hosts WHERE room_id = $1" + +const selectJoinedHostsSQL = "" + + "SELECT event_id, server_name FROM federationsender_joined_hosts" + + " WHERE room_id = $1" + +const selectAllJoinedHostsSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" + +const selectJoinedHostsForRoomsSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" + +type joinedHostsStatements struct { + db *sql.DB + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + deleteJoinedHostsForRoomStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { + s = &joinedHostsStatements{ + db: db, + } + _, err = db.Exec(joinedHostsSchema) + if err != nil { + return + } + if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil { + return + } + if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil { + return + } + if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil { + return + } + if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { + return + } + if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { + return + } + return +} + +func (s *joinedHostsStatements) InsertJoinedHosts( + ctx context.Context, + txn *sql.Tx, + roomID, eventID string, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) + _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) + return err +} + +func (s *joinedHostsStatements) DeleteJoinedHosts( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) error { + for _, eventID := range eventIDs { + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) + if _, err := stmt.ExecContext(ctx, eventID); err != nil { + return err + } + } + return nil +} + +func (s *joinedHostsStatements) DeleteJoinedHostsForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) + _, err := stmt.ExecContext(ctx, roomID) + return err +} + +func (s *joinedHostsStatements) SelectJoinedHostsWithTx( + ctx context.Context, txn *sql.Tx, roomID string, +) ([]types.JoinedHost, error) { + stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) + return joinedHostsFromStmt(ctx, stmt, roomID) +} + +func (s *joinedHostsStatements) SelectJoinedHosts( + ctx context.Context, roomID string, +) ([]types.JoinedHost, error) { + return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) +} + +func (s *joinedHostsStatements) SelectAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName string + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(serverName)) + } + + return result, rows.Err() +} + +func (s *joinedHostsStatements) SelectJoinedHostsForRooms( + ctx context.Context, roomIDs []string, +) ([]gomatrixserverlib.ServerName, error) { + iRoomIDs := make([]interface{}, len(roomIDs)) + for i := range roomIDs { + iRoomIDs[i] = roomIDs[i] + } + + sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName string + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(serverName)) + } + + return result, rows.Err() +} + +func joinedHostsFromStmt( + ctx context.Context, stmt *sql.Stmt, roomID string, +) ([]types.JoinedHost, error) { + rows, err := stmt.QueryContext(ctx, roomID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed") + + var result []types.JoinedHost + for rows.Next() { + var eventID, serverName string + if err = rows.Scan(&eventID, &serverName); err != nil { + return nil, err + } + result = append(result, types.JoinedHost{ + MemberEventID: eventID, + ServerName: gomatrixserverlib.ServerName(serverName), + }) + } + + return result, nil +} diff --git a/federationsender/storage/cosmosdb/outbound_peeks_table.go b/federationsender/storage/cosmosdb/outbound_peeks_table.go new file mode 100644 index 000000000..0da9344d2 --- /dev/null +++ b/federationsender/storage/cosmosdb/outbound_peeks_table.go @@ -0,0 +1,176 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const outboundPeeksSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks ( + room_id TEXT NOT NULL, + server_name TEXT NOT NULL, + peek_id TEXT NOT NULL, + creation_ts INTEGER NOT NULL, + renewed_ts INTEGER NOT NULL, + renewal_interval INTEGER NOT NULL, + UNIQUE (room_id, server_name, peek_id) +); +` + +const insertOutboundPeekSQL = "" + + "INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)" + +const selectOutboundPeekSQL = "" + + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" + +const selectOutboundPeeksSQL = "" + + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + +const renewOutboundPeekSQL = "" + + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" + +const deleteOutboundPeekSQL = "" + + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + +const deleteOutboundPeeksSQL = "" + + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" + +type outboundPeeksStatements struct { + db *sql.DB + insertOutboundPeekStmt *sql.Stmt + selectOutboundPeekStmt *sql.Stmt + selectOutboundPeeksStmt *sql.Stmt + renewOutboundPeekStmt *sql.Stmt + deleteOutboundPeekStmt *sql.Stmt + deleteOutboundPeeksStmt *sql.Stmt +} + +func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err error) { + s = &outboundPeeksStatements{ + db: db, + } + _, err = db.Exec(outboundPeeksSchema) + if err != nil { + return + } + + if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { + return + } + if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { + return + } + if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { + return + } + if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { + return + } + if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { + return + } + if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { + return + } + return +} + +func (s *outboundPeeksStatements) InsertOutboundPeek( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, +) (err error) { + nowMilli := time.Now().UnixNano() / int64(time.Millisecond) + stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt) + _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval) + return +} + +func (s *outboundPeeksStatements) RenewOutboundPeek( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, +) (err error) { + nowMilli := time.Now().UnixNano() / int64(time.Millisecond) + _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) + return +} + +func (s *outboundPeeksStatements) SelectOutboundPeek( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, +) (*types.OutboundPeek, error) { + row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) + outboundPeek := types.OutboundPeek{} + err := row.Scan( + &outboundPeek.RoomID, + &outboundPeek.ServerName, + &outboundPeek.PeekID, + &outboundPeek.CreationTimestamp, + &outboundPeek.RenewedTimestamp, + &outboundPeek.RenewalInterval, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + return &outboundPeek, nil +} + +func (s *outboundPeeksStatements) SelectOutboundPeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) (outboundPeeks []types.OutboundPeek, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectOutboundPeeks: rows.close() failed") + + for rows.Next() { + outboundPeek := types.OutboundPeek{} + if err = rows.Scan( + &outboundPeek.RoomID, + &outboundPeek.ServerName, + &outboundPeek.PeekID, + &outboundPeek.CreationTimestamp, + &outboundPeek.RenewedTimestamp, + &outboundPeek.RenewalInterval, + ); err != nil { + return + } + outboundPeeks = append(outboundPeeks, outboundPeek) + } + + return outboundPeeks, rows.Err() +} + +func (s *outboundPeeksStatements) DeleteOutboundPeek( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) + return +} + +func (s *outboundPeeksStatements) DeleteOutboundPeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID) + return +} diff --git a/federationsender/storage/cosmosdb/queue_edus_table.go b/federationsender/storage/cosmosdb/queue_edus_table.go new file mode 100644 index 000000000..530e0c088 --- /dev/null +++ b/federationsender/storage/cosmosdb/queue_edus_table.go @@ -0,0 +1,207 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const queueEDUsSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( + -- The type of the event (informational). + edu_type TEXT NOT NULL, + -- The domain part of the user ID the EDU event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the federationsender_queue_edus_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx + ON federationsender_queue_edus (json_nid, server_name); +` + +const insertQueueEDUSQL = "" + + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEDUsSQL = "" + + "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)" + +const selectQueueEDUSQL = "" + + "SELECT json_nid FROM federationsender_queue_edus" + + " WHERE server_name = $1" + + " LIMIT $2" + +const selectQueueEDUReferenceJSONCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_edus" + + " WHERE json_nid = $1" + +const selectQueueEDUCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_edus" + + " WHERE server_name = $1" + +const selectQueueServerNamesSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_queue_edus" + +type queueEDUsStatements struct { + db *sql.DB + insertQueueEDUStmt *sql.Stmt + selectQueueEDUStmt *sql.Stmt + selectQueueEDUReferenceJSONCountStmt *sql.Stmt + selectQueueEDUCountStmt *sql.Stmt + selectQueueEDUServerNamesStmt *sql.Stmt +} + +func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { + s = &queueEDUsStatements{ + db: db, + } + _, err = db.Exec(queueEDUsSchema) + if err != nil { + return + } + if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil { + return + } + if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil { + return + } + if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { + return + } + if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil { + return + } + if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { + return + } + return +} + +func (s *queueEDUsStatements) InsertQueueEDU( + ctx context.Context, + txn *sql.Tx, + eduType string, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) + _, err := stmt.ExecContext( + ctx, + eduType, // the EDU type + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *queueEDUsStatements) DeleteQueueEDUs( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + params := make([]interface{}, len(jsonNIDs)+1) + params[0] = serverName + for k, v := range jsonNIDs { + params[k+1] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *queueEDUsStatements) SelectQueueEDUs( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, nil +} + +func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( + ctx context.Context, txn *sql.Tx, jsonNID int64, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt) + err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + if err == sql.ErrNoRows { + return -1, nil + } + return count, err +} + +func (s *queueEDUsStatements) SelectQueueEDUCount( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} + +func (s *queueEDUsStatements) SelectQueueEDUServerNames( + ctx context.Context, txn *sql.Tx, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName gomatrixserverlib.ServerName + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, serverName) + } + + return result, rows.Err() +} diff --git a/federationsender/storage/cosmosdb/queue_json_table.go b/federationsender/storage/cosmosdb/queue_json_table.go new file mode 100644 index 000000000..74cee2b17 --- /dev/null +++ b/federationsender/storage/cosmosdb/queue_json_table.go @@ -0,0 +1,136 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const queueJSONSchema = ` +-- The queue_retry_json table contains event contents that +-- we failed to send. +CREATE TABLE IF NOT EXISTS federationsender_queue_json ( + -- The JSON NID. This allows the federationsender_queue_retry table to + -- cross-reference to find the JSON blob. + json_nid INTEGER PRIMARY KEY AUTOINCREMENT, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); +` + +const insertJSONSQL = "" + + "INSERT INTO federationsender_queue_json (json_body)" + + " VALUES ($1)" + +const deleteJSONSQL = "" + + "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)" + +const selectJSONSQL = "" + + "SELECT json_nid, json_body FROM federationsender_queue_json" + + " WHERE json_nid IN ($1)" + +type queueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic + //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { + s = &queueJSONStatements{ + db: db, + } + _, err = db.Exec(queueJSONSchema) + if err != nil { + return + } + if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil { + return + } + return +} + +func (s *queueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (lastid int64, err error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return 0, fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err = res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("res.LastInsertId: %w", err) + } + return +} + +func (s *queueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(nids)) + for k, v := range nids { + iNIDs[k] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) + return err +} + +func (s *queueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) + selectStmt, err := txn.Prepare(selectSQL) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(jsonNIDs)) + for k, v := range jsonNIDs { + iNIDs[k] = v + } + + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, selectStmt) + rows, err := stmt.QueryContext(ctx, iNIDs...) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err) + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/federationsender/storage/cosmosdb/queue_pdus_table.go b/federationsender/storage/cosmosdb/queue_pdus_table.go new file mode 100644 index 000000000..8ca0a1fde --- /dev/null +++ b/federationsender/storage/cosmosdb/queue_pdus_table.go @@ -0,0 +1,235 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const queuePDUsSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_queue_pdus ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The domain part of the user ID the m.room.member event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the federationsender_queue_pdus_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx + ON federationsender_queue_pdus (json_nid, server_name); +` + +const insertQueuePDUSQL = "" + + "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueuePDUsSQL = "" + + "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)" + +const selectQueueNextTransactionIDSQL = "" + + "SELECT transaction_id FROM federationsender_queue_pdus" + + " WHERE server_name = $1" + + " ORDER BY transaction_id ASC" + + " LIMIT 1" + +const selectQueuePDUsSQL = "" + + "SELECT json_nid FROM federationsender_queue_pdus" + + " WHERE server_name = $1" + + " LIMIT $2" + +const selectQueuePDUsReferenceJSONCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_pdus" + + " WHERE json_nid = $1" + +const selectQueuePDUsCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_pdus" + + " WHERE server_name = $1" + +const selectQueuePDUsServerNamesSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" + +type queuePDUsStatements struct { + db *sql.DB + insertQueuePDUStmt *sql.Stmt + selectQueueNextTransactionIDStmt *sql.Stmt + selectQueuePDUsStmt *sql.Stmt + selectQueueReferenceJSONCountStmt *sql.Stmt + selectQueuePDUsCountStmt *sql.Stmt + selectQueueServerNamesStmt *sql.Stmt + // deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { + s = &queuePDUsStatements{ + db: db, + } + _, err = db.Exec(queuePDUsSchema) + if err != nil { + return + } + if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil { + return + } + //if s.deleteQueuePDUsStmt, err = db.Prepare(deleteQueuePDUsSQL); err != nil { + // return + //} + if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil { + return + } + if s.selectQueuePDUsStmt, err = db.Prepare(selectQueuePDUsSQL); err != nil { + return + } + if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { + return + } + if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { + return + } + if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil { + return + } + return +} + +func (s *queuePDUsStatements) InsertQueuePDU( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *queuePDUsStatements) DeleteQueuePDUs( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + params := make([]interface{}, len(jsonNIDs)+1) + params[0] = serverName + for k, v := range jsonNIDs { + params[k+1] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (gomatrixserverlib.TransactionID, error) { + var transactionID gomatrixserverlib.TransactionID + stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID) + if err == sql.ErrNoRows { + return "", nil + } + return transactionID, err +} + +func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( + ctx context.Context, txn *sql.Tx, jsonNID int64, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt) + err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + if err == sql.ErrNoRows { + return -1, nil + } + return count, err +} + +func (s *queuePDUsStatements) SelectQueuePDUCount( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} + +func (s *queuePDUsStatements) SelectQueuePDUs( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *queuePDUsStatements) SelectQueuePDUServerNames( + ctx context.Context, txn *sql.Tx, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName gomatrixserverlib.ServerName + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, serverName) + } + + return result, rows.Err() +} diff --git a/federationsender/storage/cosmosdb/storage.go b/federationsender/storage/cosmosdb/storage.go new file mode 100644 index 000000000..da429046b --- /dev/null +++ b/federationsender/storage/cosmosdb/storage.go @@ -0,0 +1,95 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "database/sql" + + _ "github.com/mattn/go-sqlite3" + + "github.com/matrix-org/dendrite/federationsender/storage/shared" + "github.com/matrix-org/dendrite/federationsender/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" +) + +// Database stores information needed by the federation sender +type Database struct { + shared.Database + sqlutil.PartitionOffsetStatements + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) { + var d Database + var err error + if d.db, err = sqlutil.Open(dbProperties); err != nil { + return nil, err + } + d.writer = sqlutil.NewExclusiveWriter() + joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) + if err != nil { + return nil, err + } + queuePDUs, err := NewSQLiteQueuePDUsTable(d.db) + if err != nil { + return nil, err + } + queueEDUs, err := NewSQLiteQueueEDUsTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewSQLiteQueueJSONTable(d.db) + if err != nil { + return nil, err + } + blacklist, err := NewSQLiteBlacklistTable(d.db) + if err != nil { + return nil, err + } + outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) + if err != nil { + return nil, err + } + inboundPeeks, err := NewSQLiteInboundPeeksTable(d.db) + if err != nil { + return nil, err + } + m := sqlutil.NewMigrations() + deltas.LoadRemoveRoomsTable(m) + if err = m.RunDeltas(d.db, dbProperties); err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + Cache: cache, + Writer: d.writer, + FederationSenderJoinedHosts: joinedHosts, + FederationSenderQueuePDUs: queuePDUs, + FederationSenderQueueEDUs: queueEDUs, + FederationSenderQueueJSON: queueJSON, + FederationSenderBlacklist: blacklist, + FederationSenderOutboundPeeks: outboundPeeks, + FederationSenderInboundPeeks: inboundPeeks, + } + if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil { + return nil, err + } + return &d, nil +} diff --git a/federationsender/storage/storage.go b/federationsender/storage/storage.go index 5462c3523..5402ff16e 100644 --- a/federationsender/storage/storage.go +++ b/federationsender/storage/storage.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/federationsender/storage/postgres" "github.com/matrix-org/dendrite/federationsender/storage/sqlite3" + "github.com/matrix-org/dendrite/federationsender/storage/cosmosdb" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/setup/config" ) @@ -28,6 +29,8 @@ import ( // NewDatabase opens a new database func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return cosmosdb.NewDatabase(dbProperties, cache) case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties, cache) case dbProperties.ConnectionString.IsPostgres(): diff --git a/federationsender/storage/storage_wasm.go b/federationsender/storage/storage_wasm.go index bc52bd9bb..5364c74f8 100644 --- a/federationsender/storage/storage_wasm.go +++ b/federationsender/storage/storage_wasm.go @@ -25,6 +25,8 @@ import ( // NewDatabase opens a new database func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return nil, fmt.Errorf("can't use CosmosDB implementation") case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties, cache) case dbProperties.ConnectionString.IsPostgres(): diff --git a/internal/cosmosdbutil/connection.go b/internal/cosmosdbutil/connection.go new file mode 100644 index 000000000..6e96cb9d7 --- /dev/null +++ b/internal/cosmosdbutil/connection.go @@ -0,0 +1,12 @@ +package cosmosdbutil + +import ( + "github.com/matrix-org/dendrite/setup/config" + "strings" +) + +func GetConnectionString(d *config.DataSource) config.DataSource { + var connString string + connString = string(*d) + return config.DataSource(strings.Replace(connString, "cosmosdb:", "file:", 1)) +} \ No newline at end of file diff --git a/internal/sqlutil/migrate.go b/internal/sqlutil/migrate.go index 62b1c8fad..e08f5ac73 100644 --- a/internal/sqlutil/migrate.go +++ b/internal/sqlutil/migrate.go @@ -48,10 +48,15 @@ func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error if err != nil { return fmt.Errorf("RunDeltas: Failed to collect migrations: %w", err) } - if props.ConnectionString.IsPostgres() { + if props.ConnectionString.IsCosmosDB() { if err = goose.SetDialect("postgres"); err != nil { return err } + } else if props.ConnectionString.IsPostgres() { + //HACK: Not supported + if err = goose.SetDialect("cosmosdb"); err != nil { + return err + } } else if props.ConnectionString.IsSQLite() { if err = goose.SetDialect("sqlite3"); err != nil { return err diff --git a/internal/sqlutil/trace.go b/internal/sqlutil/trace.go index ad0044559..7489f03ce 100644 --- a/internal/sqlutil/trace.go +++ b/internal/sqlutil/trace.go @@ -103,6 +103,11 @@ func Open(dbProperties *config.DatabaseOptions) (*sql.DB, error) { var err error var driverName, dsn string switch { + case dbProperties.ConnectionString.IsCosmosDB(): + //HACK: Not supported + driverName = "cosmosdb" + dsn = string(dbProperties.ConnectionString) + return nil, fmt.Errorf("CosmosDB %q", dbProperties.ConnectionString) case dbProperties.ConnectionString.IsSQLite(): driverName = SQLiteDriverName() dsn, err = ParseFileURI(dbProperties.ConnectionString) diff --git a/keyserver/storage/cosmosdb/device_keys_table.go b/keyserver/storage/cosmosdb/device_keys_table.go new file mode 100644 index 000000000..67d4da201 --- /dev/null +++ b/keyserver/storage/cosmosdb/device_keys_table.go @@ -0,0 +1,199 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/storage/tables" +) + +var deviceKeysSchema = ` +-- Stores device keys for users +CREATE TABLE IF NOT EXISTS keyserver_device_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + stream_id BIGINT NOT NULL, + display_name TEXT, + -- Clobber based on tuple of user/device. + UNIQUE (user_id, device_id) +); +` + +const upsertDeviceKeysSQL = "" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (user_id, device_id)" + + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" + +const selectDeviceKeysSQL = "" + + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const selectBatchDeviceKeysSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" + +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" + +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + +type deviceKeysStatements struct { + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt +} + +func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { + s := &deviceKeysStatements{ + db: db, + } + _, err := db.Exec(deviceKeysSchema) + if err != nil { + return nil, err + } + if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { + return nil, err + } + if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { + return nil, err + } + if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { + return nil, err + } + if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { + return nil, err + } + if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { + deviceIDMap := make(map[string]bool) + for _, d := range deviceIDs { + deviceIDMap[d] = true + } + rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + var result []api.DeviceMessage + for rows.Next() { + var dk api.DeviceMessage + dk.UserID = userID + var keyJSON string + var streamID int + var displayName sql.NullString + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { + return nil, err + } + dk.KeyJSON = []byte(keyJSON) + dk.StreamID = streamID + if displayName.Valid { + dk.DisplayName = displayName.String + } + // include the key if we want all keys (no device) or it was asked + if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { + result = append(result, dk) + } + } + return result, rows.Err() +} + +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + for i, key := range keys { + var keyJSONStr string + var streamID int + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) + if err != nil && err != sql.ErrNoRows { + return err + } + // this will be '' when there is no device + keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } + } + return nil +} + +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { + // nullable if there are no results + var nullStream sql.NullInt32 + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int32 + } + return +} + +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + iStreamIDs := make([]interface{}, len(streamIDs)+1) + iStreamIDs[0] = userID + for i := range streamIDs { + iStreamIDs[i+1] = streamIDs[i] + } + query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) + // nullable if there are no results + var count sql.NullInt32 + err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int32), nil + } + return 0, nil +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, + ) + if err != nil { + return err + } + } + return nil +} diff --git a/keyserver/storage/cosmosdb/key_changes_table.go b/keyserver/storage/cosmosdb/key_changes_table.go new file mode 100644 index 000000000..08eef3619 --- /dev/null +++ b/keyserver/storage/cosmosdb/key_changes_table.go @@ -0,0 +1,104 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "math" + + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + offset BIGINT NOT NULL, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (partition, offset) +); +` + +// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. +// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will +// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too. +const upsertKeyChangeSQL = "" + + "INSERT INTO keyserver_key_changes (partition, offset, user_id)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT (partition, offset)" + + " DO UPDATE SET user_id = $3" + +// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just +// take the max offset value as the latest offset. +const selectKeyChangesSQL = "" + + "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return nil, err + } + if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil { + return nil, err + } + if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { + _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + return err +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, partition int32, fromOffset, toOffset int64, +) (userIDs []string, latestOffset int64, err error) { + if toOffset == sarama.OffsetNewest { + toOffset = math.MaxInt64 + } + latestOffset = fromOffset + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/keyserver/storage/cosmosdb/one_time_keys_table.go b/keyserver/storage/cosmosdb/one_time_keys_table.go new file mode 100644 index 000000000..942c7532a --- /dev/null +++ b/keyserver/storage/cosmosdb/one_time_keys_table.go @@ -0,0 +1,203 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/storage/tables" +) + +var oneTimeKeysSchema = ` +-- Stores one-time public keys for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on 4-uple of user/device/key/algorithm. + UNIQUE (user_id, device_id, key_id, algorithm) +); +` + +const upsertKeysSQL = "" + + "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (user_id, device_id, key_id, algorithm)" + + " DO UPDATE SET key_json = $6" + +const selectKeysSQL = "" + + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" + +const selectKeysCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" + +const deleteOneTimeKeySQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" + +const selectKeyByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + +type oneTimeKeysStatements struct { + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteOneTimeKeyStmt *sql.Stmt +} + +func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { + s := &oneTimeKeysStatements{ + db: db, + } + _, err := db.Exec(oneTimeKeysSchema) + if err != nil { + return nil, err + } + if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { + return nil, err + } + if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { + return nil, err + } + if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { + return nil, err + } + if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { + return nil, err + } + if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") + + wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) + for _, ka := range keyIDsWithAlgorithms { + wantSet[ka] = true + } + + result := make(map[string]json.RawMessage) + for rows.Next() { + var keyID string + var algorithm string + var keyJSONStr string + if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { + return nil, err + } + keyIDWithAlgo := algorithm + ":" + keyID + if wantSet[keyIDWithAlgo] { + result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) + } + } + return result, rows.Err() +} + +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + +func (s *oneTimeKeysStatements) InsertOneTimeKeys( + ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys, +) (*api.OneTimeKeysCount, error) { + now := time.Now().Unix() + counts := &api.OneTimeKeysCount{ + DeviceID: keys.DeviceID, + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + + return counts, rows.Err() +} + +func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + if err != nil { + return nil, err + } + if keyJSON == "" { + return nil, nil + } + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} diff --git a/keyserver/storage/cosmosdb/stale_device_lists.go b/keyserver/storage/cosmosdb/stale_device_lists.go new file mode 100644 index 000000000..2c4e0d8e2 --- /dev/null +++ b/keyserver/storage/cosmosdb/stale_device_lists.go @@ -0,0 +1,121 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + +type staleDeviceListsStatements struct { + db *sql.DB + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt +} + +func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{ + db: db, + } + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/keyserver/storage/cosmosdb/storage.go b/keyserver/storage/cosmosdb/storage.go new file mode 100644 index 000000000..ba000cb24 --- /dev/null +++ b/keyserver/storage/cosmosdb/storage.go @@ -0,0 +1,52 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/storage/shared" + "github.com/matrix-org/dendrite/setup/config" +) + +func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + otk, err := NewSqliteOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewSqliteDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewSqliteKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewSqliteStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + return &shared.Database{ + DB: db, + Writer: sqlutil.NewExclusiveWriter(), + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + }, nil +} diff --git a/keyserver/storage/storage.go b/keyserver/storage/storage.go index 8f05d0030..3a7b37a1c 100644 --- a/keyserver/storage/storage.go +++ b/keyserver/storage/storage.go @@ -19,6 +19,7 @@ package storage import ( "fmt" + "github.com/matrix-org/dendrite/keyserver/storage/cosmosdb" "github.com/matrix-org/dendrite/keyserver/storage/postgres" "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" "github.com/matrix-org/dendrite/setup/config" @@ -28,6 +29,8 @@ import ( // and sets postgres connection parameters func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return cosmosdb.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsPostgres(): diff --git a/keyserver/storage/storage_wasm.go b/keyserver/storage/storage_wasm.go index 8b31bfd01..93d19a24a 100644 --- a/keyserver/storage/storage_wasm.go +++ b/keyserver/storage/storage_wasm.go @@ -23,6 +23,8 @@ import ( func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return nil, fmt.Errorf("can't use CosmosDB implementation") case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsPostgres(): diff --git a/mediaapi/storage/cosmosdb/media_repository_table.go b/mediaapi/storage/cosmosdb/media_repository_table.go new file mode 100644 index 000000000..b4f1b40fc --- /dev/null +++ b/mediaapi/storage/cosmosdb/media_repository_table.go @@ -0,0 +1,150 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const mediaSchema = ` +-- The media_repository table holds metadata for each media file stored and accessible to the local server, +-- the actual file is stored separately. +CREATE TABLE IF NOT EXISTS mediaapi_media_repository ( + -- The id used to refer to the media. + -- For uploads to this server this is a base64-encoded sha256 hash of the file data + -- For media from remote servers, this can be any unique identifier string + media_id TEXT NOT NULL, + -- The origin of the media as requested by the client. Should be a homeserver domain. + media_origin TEXT NOT NULL, + -- The MIME-type of the media file as specified when uploading. + content_type TEXT NOT NULL, + -- Size of the media file in bytes. + file_size_bytes INTEGER NOT NULL, + -- When the content was uploaded in UNIX epoch ms. + creation_ts INTEGER NOT NULL, + -- The file name with which the media was uploaded. + upload_name TEXT NOT NULL, + -- Alternate RFC 4648 unpadded base64 encoding string representation of a SHA-256 hash sum of the file data. + base64hash TEXT NOT NULL, + -- The user who uploaded the file. Should be a Matrix user ID. + user_id TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_media_repository_index ON mediaapi_media_repository (media_id, media_origin); +` + +const insertMediaSQL = ` +INSERT INTO mediaapi_media_repository (media_id, media_origin, content_type, file_size_bytes, creation_ts, upload_name, base64hash, user_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +` + +const selectMediaSQL = ` +SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user_id FROM mediaapi_media_repository WHERE media_id = $1 AND media_origin = $2 +` + +const selectMediaByHashSQL = ` +SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_id FROM mediaapi_media_repository WHERE base64hash = $1 AND media_origin = $2 +` + +type mediaStatements struct { + db *sql.DB + writer sqlutil.Writer + insertMediaStmt *sql.Stmt + selectMediaStmt *sql.Stmt + selectMediaByHashStmt *sql.Stmt +} + +func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { + s.db = db + s.writer = writer + + _, err = db.Exec(mediaSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertMediaStmt, insertMediaSQL}, + {&s.selectMediaStmt, selectMediaSQL}, + {&s.selectMediaByHashStmt, selectMediaByHashSQL}, + }.prepare(db) +} + +func (s *mediaStatements) insertMedia( + ctx context.Context, mediaMetadata *types.MediaMetadata, +) error { + mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertMediaStmt) + _, err := stmt.ExecContext( + ctx, + mediaMetadata.MediaID, + mediaMetadata.Origin, + mediaMetadata.ContentType, + mediaMetadata.FileSizeBytes, + mediaMetadata.CreationTimestamp, + mediaMetadata.UploadName, + mediaMetadata.Base64Hash, + mediaMetadata.UserID, + ) + return err + }) +} + +func (s *mediaStatements) selectMedia( + ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +) (*types.MediaMetadata, error) { + mediaMetadata := types.MediaMetadata{ + MediaID: mediaID, + Origin: mediaOrigin, + } + err := s.selectMediaStmt.QueryRowContext( + ctx, mediaMetadata.MediaID, mediaMetadata.Origin, + ).Scan( + &mediaMetadata.ContentType, + &mediaMetadata.FileSizeBytes, + &mediaMetadata.CreationTimestamp, + &mediaMetadata.UploadName, + &mediaMetadata.Base64Hash, + &mediaMetadata.UserID, + ) + return &mediaMetadata, err +} + +func (s *mediaStatements) selectMediaByHash( + ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +) (*types.MediaMetadata, error) { + mediaMetadata := types.MediaMetadata{ + Base64Hash: mediaHash, + Origin: mediaOrigin, + } + err := s.selectMediaStmt.QueryRowContext( + ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, + ).Scan( + &mediaMetadata.ContentType, + &mediaMetadata.FileSizeBytes, + &mediaMetadata.CreationTimestamp, + &mediaMetadata.UploadName, + &mediaMetadata.MediaID, + &mediaMetadata.UserID, + ) + return &mediaMetadata, err +} diff --git a/mediaapi/storage/cosmosdb/prepare.go b/mediaapi/storage/cosmosdb/prepare.go new file mode 100644 index 000000000..930416d28 --- /dev/null +++ b/mediaapi/storage/cosmosdb/prepare.go @@ -0,0 +1,38 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// FIXME: This should be made internal! + +package cosmosdb + +import ( + "database/sql" +) + +// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. +type statementList []struct { + statement **sql.Stmt + sql string +} + +// prepare the SQL for each statement in the list and assign the result to the prepared statement. +func (s statementList) prepare(db *sql.DB) (err error) { + for _, statement := range s { + if *statement.statement, err = db.Prepare(statement.sql); err != nil { + return + } + } + return +} diff --git a/mediaapi/storage/cosmosdb/sql.go b/mediaapi/storage/cosmosdb/sql.go new file mode 100644 index 000000000..df63ba800 --- /dev/null +++ b/mediaapi/storage/cosmosdb/sql.go @@ -0,0 +1,38 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +type statements struct { + media mediaStatements + thumbnail thumbnailStatements +} + +func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { + if err = s.media.prepare(db, writer); err != nil { + return + } + if err = s.thumbnail.prepare(db, writer); err != nil { + return + } + + return +} diff --git a/mediaapi/storage/cosmosdb/storage.go b/mediaapi/storage/cosmosdb/storage.go new file mode 100644 index 000000000..b05373868 --- /dev/null +++ b/mediaapi/storage/cosmosdb/storage.go @@ -0,0 +1,124 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + // Import the postgres database driver. + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + _ "github.com/mattn/go-sqlite3" +) + +// Database is used to store metadata about a repository of media files. +type Database struct { + statements statements + db *sql.DB + writer sqlutil.Writer +} + +// Open opens a postgres database. +func Open(dbProperties *config.DatabaseOptions) (*Database, error) { + d := Database{ + writer: sqlutil.NewExclusiveWriter(), + } + var err error + if d.db, err = sqlutil.Open(dbProperties); err != nil { + return nil, err + } + if err = d.statements.prepare(d.db, d.writer); err != nil { + return nil, err + } + return &d, nil +} + +// StoreMediaMetadata inserts the metadata about the uploaded media into the database. +// Returns an error if the combination of MediaID and Origin are not unique in the table. +func (d *Database) StoreMediaMetadata( + ctx context.Context, mediaMetadata *types.MediaMetadata, +) error { + return d.statements.media.insertMedia(ctx, mediaMetadata) +} + +// GetMediaMetadata returns metadata about media stored on this server. +// The media could have been uploaded to this server or fetched from another server and cached here. +// Returns nil metadata if there is no metadata associated with this media. +func (d *Database) GetMediaMetadata( + ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +) (*types.MediaMetadata, error) { + mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin) + if err != nil && err == sql.ErrNoRows { + return nil, nil + } + return mediaMetadata, err +} + +// GetMediaMetadataByHash returns metadata about media stored on this server. +// The media could have been uploaded to this server or fetched from another server and cached here. +// Returns nil metadata if there is no metadata associated with this media. +func (d *Database) GetMediaMetadataByHash( + ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +) (*types.MediaMetadata, error) { + mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin) + if err != nil && err == sql.ErrNoRows { + return nil, nil + } + return mediaMetadata, err +} + +// StoreThumbnail inserts the metadata about the thumbnail into the database. +// Returns an error if the combination of MediaID and Origin are not unique in the table. +func (d *Database) StoreThumbnail( + ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, +) error { + return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata) +} + +// GetThumbnail returns metadata about a specific thumbnail. +// The media could have been uploaded to this server or fetched from another server and cached here. +// Returns nil metadata if there is no metadata associated with this thumbnail. +func (d *Database) GetThumbnail( + ctx context.Context, + mediaID types.MediaID, + mediaOrigin gomatrixserverlib.ServerName, + width, height int, + resizeMethod string, +) (*types.ThumbnailMetadata, error) { + thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail( + ctx, mediaID, mediaOrigin, width, height, resizeMethod, + ) + if err != nil && err == sql.ErrNoRows { + return nil, nil + } + return thumbnailMetadata, err +} + +// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server. +// The media could have been uploaded to this server or fetched from another server and cached here. +// Returns nil metadata if there are no thumbnails associated with this media. +func (d *Database) GetThumbnails( + ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +) ([]*types.ThumbnailMetadata, error) { + thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin) + if err != nil && err == sql.ErrNoRows { + return nil, nil + } + return thumbnails, err +} diff --git a/mediaapi/storage/cosmosdb/thumbnail_table.go b/mediaapi/storage/cosmosdb/thumbnail_table.go new file mode 100644 index 000000000..cc5f00ddc --- /dev/null +++ b/mediaapi/storage/cosmosdb/thumbnail_table.go @@ -0,0 +1,171 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const thumbnailSchema = ` +-- The mediaapi_thumbnail table holds metadata for each thumbnail file stored and accessible to the local server, +-- the actual file is stored separately. +CREATE TABLE IF NOT EXISTS mediaapi_thumbnail ( + media_id TEXT NOT NULL, + media_origin TEXT NOT NULL, + content_type TEXT NOT NULL, + file_size_bytes INTEGER NOT NULL, + creation_ts INTEGER NOT NULL, + width INTEGER NOT NULL, + height INTEGER NOT NULL, + resize_method TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_thumbnail_index ON mediaapi_thumbnail (media_id, media_origin, width, height, resize_method); +` + +const insertThumbnailSQL = ` +INSERT INTO mediaapi_thumbnail (media_id, media_origin, content_type, file_size_bytes, creation_ts, width, height, resize_method) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +` + +// Note: this selects one specific thumbnail +const selectThumbnailSQL = ` +SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 AND width = $3 AND height = $4 AND resize_method = $5 +` + +// Note: this selects all thumbnails for a media_origin and media_id +const selectThumbnailsSQL = ` +SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 +` + +type thumbnailStatements struct { + db *sql.DB + writer sqlutil.Writer + insertThumbnailStmt *sql.Stmt + selectThumbnailStmt *sql.Stmt + selectThumbnailsStmt *sql.Stmt +} + +func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { + _, err = db.Exec(thumbnailSchema) + if err != nil { + return + } + s.db = db + s.writer = writer + + return statementList{ + {&s.insertThumbnailStmt, insertThumbnailSQL}, + {&s.selectThumbnailStmt, selectThumbnailSQL}, + {&s.selectThumbnailsStmt, selectThumbnailsSQL}, + }.prepare(db) +} + +func (s *thumbnailStatements) insertThumbnail( + ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, +) error { + thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt) + _, err := stmt.ExecContext( + ctx, + thumbnailMetadata.MediaMetadata.MediaID, + thumbnailMetadata.MediaMetadata.Origin, + thumbnailMetadata.MediaMetadata.ContentType, + thumbnailMetadata.MediaMetadata.FileSizeBytes, + thumbnailMetadata.MediaMetadata.CreationTimestamp, + thumbnailMetadata.ThumbnailSize.Width, + thumbnailMetadata.ThumbnailSize.Height, + thumbnailMetadata.ThumbnailSize.ResizeMethod, + ) + return err + }) +} + +func (s *thumbnailStatements) selectThumbnail( + ctx context.Context, + mediaID types.MediaID, + mediaOrigin gomatrixserverlib.ServerName, + width, height int, + resizeMethod string, +) (*types.ThumbnailMetadata, error) { + thumbnailMetadata := types.ThumbnailMetadata{ + MediaMetadata: &types.MediaMetadata{ + MediaID: mediaID, + Origin: mediaOrigin, + }, + ThumbnailSize: types.ThumbnailSize{ + Width: width, + Height: height, + ResizeMethod: resizeMethod, + }, + } + err := s.selectThumbnailStmt.QueryRowContext( + ctx, + thumbnailMetadata.MediaMetadata.MediaID, + thumbnailMetadata.MediaMetadata.Origin, + thumbnailMetadata.ThumbnailSize.Width, + thumbnailMetadata.ThumbnailSize.Height, + thumbnailMetadata.ThumbnailSize.ResizeMethod, + ).Scan( + &thumbnailMetadata.MediaMetadata.ContentType, + &thumbnailMetadata.MediaMetadata.FileSizeBytes, + &thumbnailMetadata.MediaMetadata.CreationTimestamp, + ) + return &thumbnailMetadata, err +} + +func (s *thumbnailStatements) selectThumbnails( + ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +) ([]*types.ThumbnailMetadata, error) { + rows, err := s.selectThumbnailsStmt.QueryContext( + ctx, mediaID, mediaOrigin, + ) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectThumbnails: rows.close() failed") + + var thumbnails []*types.ThumbnailMetadata + for rows.Next() { + thumbnailMetadata := types.ThumbnailMetadata{ + MediaMetadata: &types.MediaMetadata{ + MediaID: mediaID, + Origin: mediaOrigin, + }, + } + err = rows.Scan( + &thumbnailMetadata.MediaMetadata.ContentType, + &thumbnailMetadata.MediaMetadata.FileSizeBytes, + &thumbnailMetadata.MediaMetadata.CreationTimestamp, + &thumbnailMetadata.ThumbnailSize.Width, + &thumbnailMetadata.ThumbnailSize.Height, + &thumbnailMetadata.ThumbnailSize.ResizeMethod, + ) + if err != nil { + return nil, err + } + thumbnails = append(thumbnails, &thumbnailMetadata) + } + + return thumbnails, rows.Err() +} diff --git a/mediaapi/storage/storage.go b/mediaapi/storage/storage.go index a976f795b..316bb1a89 100644 --- a/mediaapi/storage/storage.go +++ b/mediaapi/storage/storage.go @@ -19,6 +19,7 @@ package storage import ( "fmt" + "github.com/matrix-org/dendrite/mediaapi/storage/cosmosdb" "github.com/matrix-org/dendrite/mediaapi/storage/postgres" "github.com/matrix-org/dendrite/mediaapi/storage/sqlite3" "github.com/matrix-org/dendrite/setup/config" @@ -27,6 +28,8 @@ import ( // Open opens a postgres database. func Open(dbProperties *config.DatabaseOptions) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return cosmosdb.Open(dbProperties) case dbProperties.ConnectionString.IsSQLite(): return sqlite3.Open(dbProperties) case dbProperties.ConnectionString.IsPostgres(): diff --git a/mediaapi/storage/storage_wasm.go b/mediaapi/storage/storage_wasm.go index a6e997b2a..0cfc560df 100644 --- a/mediaapi/storage/storage_wasm.go +++ b/mediaapi/storage/storage_wasm.go @@ -24,6 +24,8 @@ import ( // Open opens a postgres database. func Open(dbProperties *config.DatabaseOptions) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return nil, fmt.Errorf("can't use CosmosDB implementation") case dbProperties.ConnectionString.IsSQLite(): return sqlite3.Open(dbProperties) case dbProperties.ConnectionString.IsPostgres(): diff --git a/roomserver/storage/cosmosdb/event_json_table.go b/roomserver/storage/cosmosdb/event_json_table.go new file mode 100644 index 000000000..05b6b1b62 --- /dev/null +++ b/roomserver/storage/cosmosdb/event_json_table.go @@ -0,0 +1,107 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventJSONSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_event_json ( + event_nid INTEGER NOT NULL PRIMARY KEY, + event_json TEXT NOT NULL + ); +` + +const insertEventJSONSQL = ` + INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2) +` + +// Bulk event JSON lookup by numeric event ID. +// Sort by the numeric event ID. +// This means that we can use binary search to lookup by numeric event ID. +const bulkSelectEventJSONSQL = ` + SELECT event_nid, event_json FROM roomserver_event_json + WHERE event_nid IN ($1) + ORDER BY event_nid ASC +` + +type eventJSONStatements struct { + db *sql.DB + insertEventJSONStmt *sql.Stmt + bulkSelectEventJSONStmt *sql.Stmt +} + +func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{ + db: db, + } + _, err := db.Exec(eventJSONSchema) + if err != nil { + return nil, err + } + return s, shared.StatementList{ + {&s.insertEventJSONStmt, insertEventJSONSQL}, + {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, + }.Prepare(db) +} + +func (s *eventJSONStatements) InsertEventJSON( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, +) error { + _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) + return err +} + +func (s *eventJSONStatements) BulkSelectEventJSON( + ctx context.Context, eventNIDs []types.EventNID, +) ([]tables.EventJSONPair, error) { + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) + + rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventJSON: rows.close() failed") + + // We know that we will only get as many results as event NIDs + // because of the unique constraint on event NIDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than NIDs so we adjust the length of the slice before returning it. + results := make([]tables.EventJSONPair, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + var eventNID int64 + if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { + return nil, err + } + result.EventNID = types.EventNID(eventNID) + } + return results[:i], nil +} diff --git a/roomserver/storage/cosmosdb/event_state_keys_table.go b/roomserver/storage/cosmosdb/event_state_keys_table.go new file mode 100644 index 000000000..a9307f68a --- /dev/null +++ b/roomserver/storage/cosmosdb/event_state_keys_table.go @@ -0,0 +1,163 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventStateKeysSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_event_state_keys ( + event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT, + event_state_key TEXT NOT NULL UNIQUE + ); + INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key) + VALUES (1, '') + ON CONFLICT DO NOTHING; +` + +// Same as insertEventTypeNIDSQL +const insertEventStateKeyNIDSQL = ` + INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1) + ON CONFLICT DO NOTHING; +` + +const selectEventStateKeyNIDSQL = ` + SELECT event_state_key_nid FROM roomserver_event_state_keys + WHERE event_state_key = $1 +` + +// Bulk lookup from string state key to numeric ID for that state key. +// Takes an array of strings as the query parameter. +const bulkSelectEventStateKeySQL = ` + SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys + WHERE event_state_key IN ($1) +` + +// Bulk lookup from numeric ID to string state key for that state key. +// Takes an array of strings as the query parameter. +const bulkSelectEventStateKeyNIDSQL = ` + SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys + WHERE event_state_key_nid IN ($1) +` + +type eventStateKeyStatements struct { + db *sql.DB + insertEventStateKeyNIDStmt *sql.Stmt + selectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyStmt *sql.Stmt +} + +func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { + s := &eventStateKeyStatements{ + db: db, + } + _, err := db.Exec(eventStateKeysSchema) + if err != nil { + return nil, err + } + return s, shared.StatementList{ + {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, + {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, + {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, + {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, + }.Prepare(db) +} + +func (s *eventStateKeyStatements) InsertEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) + res, err := insertStmt.ExecContext(ctx, eventStateKey) + if err != nil { + return 0, err + } + eventStateKeyNID, err := res.LastInsertId() + if err != nil { + return 0, err + } + return types.EventStateKeyNID(eventStateKeyNID), err +} + +func (s *eventStateKeyStatements) SelectEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + stmt := sqlutil.TxStmt(txn, s.selectEventStateKeyNIDStmt) + err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) + return types.EventStateKeyNID(eventStateKeyNID), err +} + +func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + iEventStateKeys := make([]interface{}, len(eventStateKeys)) + for k, v := range eventStateKeys { + iEventStateKeys[k] = v + } + selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1) + + rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed") + result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) + for rows.Next() { + var stateKey string + var stateKeyNID int64 + if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { + return nil, err + } + result[stateKey] = types.EventStateKeyNID(stateKeyNID) + } + return result, nil +} + +func (s *eventStateKeyStatements) BulkSelectEventStateKey( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) + for k, v := range eventStateKeyNIDs { + iEventStateKeyNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1) + + rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed") + result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) + for rows.Next() { + var stateKey string + var stateKeyNID int64 + if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { + return nil, err + } + result[types.EventStateKeyNID(stateKeyNID)] = stateKey + } + return result, nil +} diff --git a/roomserver/storage/cosmosdb/event_types_table.go b/roomserver/storage/cosmosdb/event_types_table.go new file mode 100644 index 000000000..a63b537ce --- /dev/null +++ b/roomserver/storage/cosmosdb/event_types_table.go @@ -0,0 +1,161 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventTypesSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_event_types ( + event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT, + event_type TEXT NOT NULL UNIQUE + ); + INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES + (1, 'm.room.create'), + (2, 'm.room.power_levels'), + (3, 'm.room.join_rules'), + (4, 'm.room.third_party_invite'), + (5, 'm.room.member'), + (6, 'm.room.redaction'), + (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; +` + +// Assign a new numeric event type ID. +// The usual case is that the event type is not in the database. +// In that case the ID will be assigned using the next value from the sequence. +// We use `RETURNING` to tell postgres to return the assigned ID. +// But it's possible that the type was added in a query that raced with us. +// This will result in a conflict on the event_type_unique constraint, in this +// case we do nothing. Postgresql won't return a row in that case so we rely on +// the caller catching the sql.ErrNoRows error and running a select to get the row. +// We could get postgresql to return the row on a conflict by updating the row +// but it doesn't seem like a good idea to modify the rows just to make postgresql +// return it. Modifying the rows will cause postgres to assign a new tuple for the +// row even though the data doesn't change resulting in unncesssary modifications +// to the indexes. +const insertEventTypeNIDSQL = ` + INSERT INTO roomserver_event_types (event_type) VALUES ($1) + ON CONFLICT DO NOTHING; +` + +const insertEventTypeNIDResultSQL = ` + SELECT event_type_nid FROM roomserver_event_types + WHERE rowid = last_insert_rowid(); +` + +const selectEventTypeNIDSQL = ` + SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1 +` + +// Bulk lookup from string event type to numeric ID for that event type. +// Takes an array of strings as the query parameter. +const bulkSelectEventTypeNIDSQL = ` + SELECT event_type, event_type_nid FROM roomserver_event_types + WHERE event_type IN ($1) +` + +type eventTypeStatements struct { + db *sql.DB + insertEventTypeNIDStmt *sql.Stmt + insertEventTypeNIDResultStmt *sql.Stmt + selectEventTypeNIDStmt *sql.Stmt + bulkSelectEventTypeNIDStmt *sql.Stmt +} + +func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { + s := &eventTypeStatements{ + db: db, + } + _, err := db.Exec(eventTypesSchema) + if err != nil { + return nil, err + } + + return s, shared.StatementList{ + {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, + {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL}, + {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, + {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, + }.Prepare(db) +} + +func (s *eventTypeStatements) InsertEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, +) (types.EventTypeNID, error) { + var eventTypeNID int64 + insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt) + resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt) + _, err := insertStmt.ExecContext(ctx, eventType) + if err != nil { + return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) + } + if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil { + return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err) + } + return types.EventTypeNID(eventTypeNID), err +} + +func (s *eventTypeStatements) SelectEventTypeNID( + ctx context.Context, tx *sql.Tx, eventType string, +) (types.EventTypeNID, error) { + var eventTypeNID int64 + selectStmt := sqlutil.TxStmt(tx, s.selectEventTypeNIDStmt) + err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + return types.EventTypeNID(eventTypeNID), err +} + +func (s *eventTypeStatements) BulkSelectEventTypeNID( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + /////////////// + iEventTypes := make([]interface{}, len(eventTypes)) + for k, v := range eventTypes { + iEventTypes[k] = v + } + selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventTypes)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + rows, err := selectPrep.QueryContext(ctx, iEventTypes...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed") + + result := make(map[string]types.EventTypeNID, len(eventTypes)) + for rows.Next() { + var eventType string + var eventTypeNID int64 + if err := rows.Scan(&eventType, &eventTypeNID); err != nil { + return nil, err + } + result[eventType] = types.EventTypeNID(eventTypeNID) + } + return result, nil +} diff --git a/roomserver/storage/cosmosdb/events_table.go b/roomserver/storage/cosmosdb/events_table.go new file mode 100644 index 000000000..d8c83cad1 --- /dev/null +++ b/roomserver/storage/cosmosdb/events_table.go @@ -0,0 +1,515 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const eventsSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_events ( + event_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_nid INTEGER NOT NULL, + event_type_nid INTEGER NOT NULL, + event_state_key_nid INTEGER NOT NULL, + sent_to_output BOOLEAN NOT NULL DEFAULT FALSE, + state_snapshot_nid INTEGER NOT NULL DEFAULT 0, + depth INTEGER NOT NULL, + event_id TEXT NOT NULL UNIQUE, + reference_sha256 BLOB NOT NULL, + auth_event_nids TEXT NOT NULL DEFAULT '[]', + is_rejected BOOLEAN NOT NULL DEFAULT FALSE + ); +` + +const insertEventSQL = ` + INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT DO NOTHING; +` + +const selectEventSQL = "" + + "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" + +// Bulk lookup of events by string ID. +// Sort by the numeric IDs for event type and state key. +// This means we can use binary search to lookup entries by type and state key. +const bulkSelectStateEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_id IN ($1)" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + +const bulkSelectStateAtEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + + " WHERE event_id IN ($1)" + +const updateEventStateSQL = "" + + "UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2" + +const selectEventSentToOutputSQL = "" + + "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" + +const updateEventSentToOutputSQL = "" + + "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" + +const selectEventIDSQL = "" + + "SELECT event_id FROM roomserver_events WHERE event_nid = $1" + +const bulkSelectStateAtEventAndReferenceSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + + " FROM roomserver_events WHERE event_nid IN ($1)" + +const bulkSelectEventReferenceSQL = "" + + "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)" + +const bulkSelectEventIDSQL = "" + + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" + +const bulkSelectEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" + +const selectMaxEventDepthSQL = "" + + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" + +const selectRoomNIDsForEventNIDsSQL = "" + + "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)" + +type eventStatements struct { + db *sql.DB + insertEventStmt *sql.Stmt + selectEventStmt *sql.Stmt + bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateAtEventByIDStmt *sql.Stmt + updateEventStateStmt *sql.Stmt + selectEventSentToOutputStmt *sql.Stmt + updateEventSentToOutputStmt *sql.Stmt + selectEventIDStmt *sql.Stmt + bulkSelectStateAtEventAndReferenceStmt *sql.Stmt + bulkSelectEventReferenceStmt *sql.Stmt + bulkSelectEventIDStmt *sql.Stmt + bulkSelectEventNIDStmt *sql.Stmt + //selectRoomNIDsForEventNIDsStmt *sql.Stmt +} + +func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { + s := &eventStatements{ + db: db, + } + _, err := db.Exec(eventsSchema) + if err != nil { + return nil, err + } + + return s, shared.StatementList{ + {&s.insertEventStmt, insertEventSQL}, + {&s.selectEventStmt, selectEventSQL}, + {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, + {&s.updateEventStateStmt, updateEventStateSQL}, + {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, + {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, + {&s.selectEventIDStmt, selectEventIDSQL}, + {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, + {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, + {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, + {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, + }.Prepare(db) +} + +func (s *eventStatements) InsertEvent( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + eventTypeNID types.EventTypeNID, + eventStateKeyNID types.EventStateKeyNID, + eventID string, + referenceSHA256 []byte, + authEventNIDs []types.EventNID, + depth int64, + isRejected bool, +) (types.EventNID, types.StateSnapshotNID, error) { + // attempt to insert: the last_row_id is the event NID + var eventNID int64 + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + result, err := insertStmt.ExecContext( + ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, + ) + if err != nil { + return 0, 0, err + } + modified, err := result.RowsAffected() + if modified == 0 && err == nil { + return 0, 0, sql.ErrNoRows + } + eventNID, err = result.LastInsertId() + return types.EventNID(eventNID), 0, err +} + +func (s *eventStatements) SelectEvent( + ctx context.Context, txn *sql.Tx, eventID string, +) (types.EventNID, types.StateSnapshotNID, error) { + var eventNID int64 + var stateNID int64 + selectStmt := sqlutil.TxStmt(txn, s.selectEventStmt) + err := selectStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) + return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err +} + +// bulkSelectStateEventByID lookups a list of state events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) BulkSelectStateEventByID( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + /////////////// + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + rows, err := selectStmt.QueryContext(ctx, iEventIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + if i != len(eventIDs) { + // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. + // We don't know which ones were missing because we don't return the string IDs in the query. + // However it should be possible debug this by replaying queries or entries from the input kafka logs. + // If this turns out to be impossible and we do need the debug information here, it would be better + // to do it as a separate query rather than slowing down/complicating the internal case. + return nil, types.MissingEventError( + fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)), + ) + } + return results, err +} + +// bulkSelectStateAtEventByID lookups the state at a list of events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError. +// If we do not have the state for any of the requested events it returns a types.MissingEventError. +func (s *eventStatements) BulkSelectStateAtEventByID( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + /////////////// + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + rows, err := selectStmt.QueryContext(ctx, iEventIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventByID: rows.close() failed") + results := make([]types.StateAtEvent, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + &result.BeforeStateSnapshotNID, + &result.IsRejected, + ); err != nil { + return nil, err + } + if result.BeforeStateSnapshotNID == 0 { + return nil, types.MissingEventError( + fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), + ) + } + } + if i != len(eventIDs) { + return nil, types.MissingEventError( + fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)), + ) + } + return results, err +} + +func (s *eventStatements) UpdateEventState( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt) + _, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + return err +} + +func (s *eventStatements) SelectEventSentToOutput( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (sentToOutput bool, err error) { + selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt) + err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) + return +} + +func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { + updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) + _, err := updateStmt.ExecContext(ctx, int64(eventNID)) + return err +} + +func (s *eventStatements) SelectEventID( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (eventID string, err error) { + selectStmt := sqlutil.TxStmt(txn, s.selectEventIDStmt) + err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) + return +} + +func (s *eventStatements) BulkSelectStateAtEventAndReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]types.StateAtEventAndReference, error) { + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + ////////////// + + rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) + if err != nil { + return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") + results := make([]types.StateAtEventAndReference, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + var ( + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + stateSnapshotNID int64 + eventID string + eventSHA256 []byte + ) + if err = rows.Scan( + &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, + ); err != nil { + return nil, err + } + result := &results[i] + result.EventTypeNID = types.EventTypeNID(eventTypeNID) + result.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + result.EventNID = types.EventNID(eventNID) + result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID) + result.EventID = eventID + result.EventSHA256 = eventSHA256 + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +func (s *eventStatements) BulkSelectEventReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]gomatrixserverlib.EventReference, error) { + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := sqlutil.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed") + results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil { + return nil, err + } + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +// bulkSelectEventID returns a map from numeric event ID to string event ID. +func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") + results := make(map[types.EventNID]string, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + var eventNID int64 + var eventID string + if err = rows.Scan(&eventNID, &eventID); err != nil { + return nil, err + } + results[types.EventNID(eventNID)] = eventID + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { + /////////////// + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + rows, err := selectStmt.QueryContext(ctx, iEventIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") + results := make(map[string]types.EventNID, len(eventIDs)) + for rows.Next() { + var eventID string + var eventNID int64 + if err = rows.Scan(&eventID, &eventNID); err != nil { + return nil, err + } + results[eventID] = types.EventNID(eventNID) + } + return results, nil +} + +func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { + var result int64 + iEventIDs := make([]interface{}, len(eventNIDs)) + for i, v := range eventNIDs { + iEventIDs[i] = v + } + sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + sqlPrep, err := s.db.Prepare(sqlStr) + if err != nil { + return 0, err + } + err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result) + if err != nil { + return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err) + } + return result, nil +} + +func (s *eventStatements) SelectRoomNIDsForEventNIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]types.RoomNID, error) { + sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) + sqlPrep, err := s.db.Prepare(sqlStr) + if err != nil { + return nil, err + } + iEventNIDs := make([]interface{}, len(eventNIDs)) + for i, v := range eventNIDs { + iEventNIDs[i] = v + } + rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed") + result := make(map[types.EventNID]types.RoomNID) + for rows.Next() { + var eventNID types.EventNID + var roomNID types.RoomNID + if err = rows.Scan(&eventNID, &roomNID); err != nil { + return nil, err + } + result[eventNID] = roomNID + } + return result, nil +} + +func eventNIDsAsArray(eventNIDs []types.EventNID) string { + b, _ := json.Marshal(eventNIDs) + return string(b) +} diff --git a/roomserver/storage/cosmosdb/invite_table.go b/roomserver/storage/cosmosdb/invite_table.go new file mode 100644 index 000000000..2e3bf328e --- /dev/null +++ b/roomserver/storage/cosmosdb/invite_table.go @@ -0,0 +1,159 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const inviteSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_invites ( + invite_event_id TEXT PRIMARY KEY, + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + retired BOOLEAN NOT NULL DEFAULT FALSE, + invite_event_json TEXT NOT NULL + ); + + CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid) + WHERE NOT retired; +` +const insertInviteEventSQL = "" + + "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," + + " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" + + " ON CONFLICT DO NOTHING" + +const selectInviteActiveForUserInRoomSQL = "" + + "SELECT invite_event_id, sender_nid FROM roomserver_invites" + + " WHERE target_nid = $1 AND room_nid = $2" + + " AND NOT retired" + +// Retire every active invite for a user in a room. +// Ideally we'd know which invite events were retired by a given update so we +// wouldn't need to remove every active invite. +// However the matrix protocol doesn't give us a way to reliably identify the +// invites that were retired, so we are forced to retire all of them. +const updateInviteRetiredSQL = ` + UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired +` + +const selectInvitesAboutToRetireSQL = ` +SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired +` + +type inviteStatements struct { + db *sql.DB + insertInviteEventStmt *sql.Stmt + selectInviteActiveForUserInRoomStmt *sql.Stmt + updateInviteRetiredStmt *sql.Stmt + selectInvitesAboutToRetireStmt *sql.Stmt +} + +func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{ + db: db, + } + _, err := db.Exec(inviteSchema) + if err != nil { + return nil, err + } + + return s, shared.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, + {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, + {&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL}, + }.Prepare(db) +} + +func (s *inviteStatements) InsertInviteEvent( + ctx context.Context, + txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + targetUserNID, senderUserNID types.EventStateKeyNID, + inviteEventJSON []byte, +) (bool, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) + result, err := stmt.ExecContext( + ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, + ) + if err != nil { + return false, err + } + count, err = result.RowsAffected() + if err != nil { + return false, err + } + return count != 0, err +} + +func (s *inviteStatements) UpdateInviteRetired( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (eventIDs []string, err error) { + // gather all the event IDs we will retire + stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) + rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed") + for rows.Next() { + var inviteEventID string + if err = rows.Scan(&inviteEventID); err != nil { + return + } + eventIDs = append(eventIDs, inviteEventID) + } + // now retire the invites + stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) + _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) + return +} + +// selectInviteActiveForUserInRoom returns a list of sender state key NIDs +func (s *inviteStatements) SelectInviteActiveForUserInRoom( + ctx context.Context, + targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, +) ([]types.EventStateKeyNID, []string, error) { + rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + ctx, targetUserNID, roomNID, + ) + if err != nil { + return nil, nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed") + var result []types.EventStateKeyNID + var eventIDs []string + for rows.Next() { + var eventID string + var senderUserNID int64 + if err := rows.Scan(&eventID, &senderUserNID); err != nil { + return nil, nil, err + } + result = append(result, types.EventStateKeyNID(senderUserNID)) + eventIDs = append(eventIDs, eventID) + } + return result, eventIDs, nil +} diff --git a/roomserver/storage/cosmosdb/membership_table.go b/roomserver/storage/cosmosdb/membership_table.go new file mode 100644 index 000000000..a318d6caf --- /dev/null +++ b/roomserver/storage/cosmosdb/membership_table.go @@ -0,0 +1,306 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const membershipSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, + forgotten BOOLEAN NOT NULL DEFAULT false, + UNIQUE (room_nid, target_nid) + ); +` + +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + + " GROUP BY target_nid" + +// Insert a row in to membership table so that it can be locked by the +// SELECT FOR UPDATE +const insertMembershipSQL = "" + + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT DO NOTHING" + +const selectMembershipFromRoomAndTargetSQL = "" + + "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + + " WHERE room_nid = $1 AND target_nid = $2" + +const selectMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" + +const selectLocalMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + + " AND target_local = true and forgotten = false" + +const selectMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 and forgotten = false" + +const selectLocalMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + + " AND target_local = true and forgotten = false" + +const selectMembershipForUpdateSQL = "" + + "SELECT membership_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND target_nid = $2" + +const updateMembershipSQL = "" + + "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" + + " WHERE room_nid = $5 AND target_nid = $6" + +const updateMembershipForgetRoom = "" + + "UPDATE roomserver_membership SET forgotten = $1" + + " WHERE room_nid = $2 AND target_nid = $3" + +const selectRoomsWithMembershipSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" + +// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is +// joined to. Since this information is used to populate the user directory, we will +// only return users that the user would ordinarily be able to see anyway. +var selectKnownUsersSQL = "" + + "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + + "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + " WHERE room_nid IN (" + + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" + +type membershipStatements struct { + db *sql.DB + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + selectLocalMembershipsFromRoomStmt *sql.Stmt + selectRoomsWithMembershipStmt *sql.Stmt + updateMembershipStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt + updateMembershipForgetRoomStmt *sql.Stmt +} + +func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{ + db: db, + } + + return s, shared.StatementList{ + {&s.insertMembershipStmt, insertMembershipSQL}, + {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, + {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, + {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL}, + {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, + {&s.updateMembershipStmt, updateMembershipSQL}, + {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, + {&s.selectKnownUsersStmt, selectKnownUsersSQL}, + {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, + }.Prepare(db) +} + +func (s *membershipStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(membershipSchema) + return err +} + +func (s *membershipStatements) InsertMembership( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + localTarget bool, +) error { + stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) + return err +} + +func (s *membershipStatements) SelectMembershipForUpdate( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (membership tables.MembershipState, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt) + err = stmt.QueryRowContext( + ctx, roomNID, targetUserNID, + ).Scan(&membership) + return +} + +func (s *membershipStatements) SelectMembershipFromRoomAndTarget( + ctx context.Context, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { + err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + ctx, roomNID, targetUserNID, + ).Scan(&membership, &eventNID, &forgotten) + return +} + +func (s *membershipStatements) SelectMembershipsFromRoom( + ctx context.Context, + roomNID types.RoomNID, localOnly bool, +) (eventNIDs []types.EventNID, err error) { + var selectStmt *sql.Stmt + if localOnly { + selectStmt = s.selectLocalMembershipsFromRoomStmt + } else { + selectStmt = s.selectMembershipsFromRoomStmt + } + rows, err := selectStmt.QueryContext(ctx, roomNID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoom: rows.close() failed") + + for rows.Next() { + var eNID types.EventNID + if err = rows.Scan(&eNID); err != nil { + return + } + eventNIDs = append(eventNIDs, eNID) + } + return +} + +func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( + ctx context.Context, + roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, +) (eventNIDs []types.EventNID, err error) { + var stmt *sql.Stmt + if localOnly { + stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt + } else { + stmt = s.selectMembershipsFromRoomAndMembershipStmt + } + rows, err := stmt.QueryContext(ctx, roomNID, membership) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoomAndMembership: rows.close() failed") + + for rows.Next() { + var eNID types.EventNID + if err = rows.Scan(&eNID); err != nil { + return + } + eventNIDs = append(eventNIDs, eNID) + } + return +} + +func (s *membershipStatements) UpdateMembership( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, + eventNID types.EventNID, forgotten bool, +) error { + stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) + _, err := stmt.ExecContext( + ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID, + ) + return err +} + +func (s *membershipStatements) SelectRoomsWithMembership( + ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, +) ([]types.RoomNID, error) { + rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} + +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { + iRoomNIDs := make([]interface{}, len(roomNIDs)) + for i, v := range roomNIDs { + iRoomNIDs[i] = v + } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) + rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + result := make(map[types.EventStateKeyNID]int) + for rows.Next() { + var userID types.EventStateKeyNID + var count int + if err := rows.Scan(&userID, &count); err != nil { + return nil, err + } + result[userID] = count + } + return result, rows.Err() +} + +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + result := []string{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} + +func (s *membershipStatements) UpdateForgetMembership( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + forget bool, +) error { + _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( + ctx, forget, roomNID, targetUserNID, + ) + return err +} diff --git a/roomserver/storage/cosmosdb/previous_events_table.go b/roomserver/storage/cosmosdb/previous_events_table.go new file mode 100644 index 000000000..1062ab1cf --- /dev/null +++ b/roomserver/storage/cosmosdb/previous_events_table.go @@ -0,0 +1,131 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +// TODO: previous_reference_sha256 was NOT NULL before but it broke sytest because +// sytest sends no SHA256 sums in the prev_events references in the soft-fail tests. +// In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it +// seems to care that it's empty and therefore hits a NOT NULL constraint on insert. +// We should really work out what the right thing to do here is. +const previousEventSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_previous_events ( + previous_event_id TEXT NOT NULL, + previous_reference_sha256 BLOB, + event_nids TEXT NOT NULL, + UNIQUE (previous_event_id, previous_reference_sha256) + ); +` + +// Insert an entry into the previous_events table. +// If there is already an entry indicating that an event references that previous event then +// add the event NID to the list to indicate that this event references that previous event as well. +// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room. +// The lock is necessary to avoid data races when checking whether an event is already referenced by another event. +const insertPreviousEventSQL = ` + INSERT OR REPLACE INTO roomserver_previous_events + (previous_event_id, previous_reference_sha256, event_nids) + VALUES ($1, $2, $3) +` + +const selectPreviousEventNIDsSQL = ` + SELECT event_nids FROM roomserver_previous_events + WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 +` + +// Check if the event is referenced by another event in the table. +// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. +const selectPreviousEventExistsSQL = ` + SELECT 1 FROM roomserver_previous_events + WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 +` + +type previousEventStatements struct { + db *sql.DB + insertPreviousEventStmt *sql.Stmt + selectPreviousEventNIDsStmt *sql.Stmt + selectPreviousEventExistsStmt *sql.Stmt +} + +func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { + s := &previousEventStatements{ + db: db, + } + _, err := db.Exec(previousEventSchema) + if err != nil { + return nil, err + } + + return s, shared.StatementList{ + {&s.insertPreviousEventStmt, insertPreviousEventSQL}, + {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL}, + {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, + }.Prepare(db) +} + +func (s *previousEventStatements) InsertPreviousEvent( + ctx context.Context, + txn *sql.Tx, + previousEventID string, + previousEventReferenceSHA256 []byte, + eventNID types.EventNID, +) error { + var eventNIDs string + eventNIDAsString := fmt.Sprintf("%d", eventNID) + selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) + err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs) + if err != nil && err != sql.ErrNoRows { + return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err) + } + var nids []string + if eventNIDs != "" { + nids = strings.Split(eventNIDs, ",") + for _, nid := range nids { + if nid == eventNIDAsString { + return nil + } + } + eventNIDs = strings.Join(append(nids, eventNIDAsString), ",") + } else { + eventNIDs = eventNIDAsString + } + insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) + _, err = insertStmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, eventNIDs, + ) + return err +} + +// Check if the event reference exists +// Returns sql.ErrNoRows if the event reference doesn't exist. +func (s *previousEventStatements) SelectPreviousEventExists( + ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, +) error { + var ok int64 + stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) + return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) +} diff --git a/roomserver/storage/cosmosdb/published_table.go b/roomserver/storage/cosmosdb/published_table.go new file mode 100644 index 000000000..0de948aa6 --- /dev/null +++ b/roomserver/storage/cosmosdb/published_table.go @@ -0,0 +1,105 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" +) + +const publishedSchema = ` +-- Stores which rooms are published in the room directory +CREATE TABLE IF NOT EXISTS roomserver_published ( + -- The room ID of the room + room_id TEXT NOT NULL PRIMARY KEY, + -- Whether it is published or not + published BOOLEAN NOT NULL DEFAULT false +); +` + +const upsertPublishedSQL = "" + + "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)" + +const selectAllPublishedSQL = "" + + "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" + +const selectPublishedSQL = "" + + "SELECT published FROM roomserver_published WHERE room_id = $1" + +type publishedStatements struct { + db *sql.DB + upsertPublishedStmt *sql.Stmt + selectAllPublishedStmt *sql.Stmt + selectPublishedStmt *sql.Stmt +} + +func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { + s := &publishedStatements{ + db: db, + } + _, err := db.Exec(publishedSchema) + if err != nil { + return nil, err + } + return s, shared.StatementList{ + {&s.upsertPublishedStmt, upsertPublishedSQL}, + {&s.selectAllPublishedStmt, selectAllPublishedSQL}, + {&s.selectPublishedStmt, selectPublishedSQL}, + }.Prepare(db) +} + +func (s *publishedStatements) UpsertRoomPublished( + ctx context.Context, txn *sql.Tx, roomID string, published bool, +) error { + stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) + _, err := stmt.ExecContext(ctx, roomID, published) + return err +} + +func (s *publishedStatements) SelectPublishedFromRoomID( + ctx context.Context, roomID string, +) (published bool, err error) { + err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + if err == sql.ErrNoRows { + return false, nil + } + return +} + +func (s *publishedStatements) SelectAllPublishedRooms( + ctx context.Context, published bool, +) ([]string, error) { + rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectAllPublishedStmt: rows.close() failed") + + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + + roomIDs = append(roomIDs, roomID) + } + return roomIDs, rows.Err() +} diff --git a/roomserver/storage/cosmosdb/redactions_table.go b/roomserver/storage/cosmosdb/redactions_table.go new file mode 100644 index 000000000..0d2ee27eb --- /dev/null +++ b/roomserver/storage/cosmosdb/redactions_table.go @@ -0,0 +1,123 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" +) + +const redactionsSchema = ` +-- Stores information about the redacted state of events. +-- We need to track redactions rather than blindly updating the event JSON table on receipt of a redaction +-- because we might receive the redaction BEFORE we receive the event which it redacts (think backfill). +CREATE TABLE IF NOT EXISTS roomserver_redactions ( + redaction_event_id TEXT PRIMARY KEY, + redacts_event_id TEXT NOT NULL, + -- Initially FALSE, set to TRUE when the redaction has been validated according to rooms v3+ spec + -- https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events + validated BOOLEAN NOT NULL +); +` + +const insertRedactionSQL = "" + + "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" + + " VALUES ($1, $2, $3)" + +const selectRedactionInfoByRedactionEventIDSQL = "" + + "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + + " WHERE redaction_event_id = $1" + +const selectRedactionInfoByEventBeingRedactedSQL = "" + + "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + + " WHERE redacts_event_id = $1" + +const markRedactionValidatedSQL = "" + + " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1" + +type redactionStatements struct { + db *sql.DB + insertRedactionStmt *sql.Stmt + selectRedactionInfoByRedactionEventIDStmt *sql.Stmt + selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt + markRedactionValidatedStmt *sql.Stmt +} + +func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { + s := &redactionStatements{ + db: db, + } + _, err := db.Exec(redactionsSchema) + if err != nil { + return nil, err + } + + return s, shared.StatementList{ + {&s.insertRedactionStmt, insertRedactionSQL}, + {&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL}, + {&s.selectRedactionInfoByEventBeingRedactedStmt, selectRedactionInfoByEventBeingRedactedSQL}, + {&s.markRedactionValidatedStmt, markRedactionValidatedSQL}, + }.Prepare(db) +} + +func (s *redactionStatements) InsertRedaction( + ctx context.Context, txn *sql.Tx, info tables.RedactionInfo, +) error { + stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) + _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) + return err +} + +func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( + ctx context.Context, txn *sql.Tx, redactionEventID string, +) (info *tables.RedactionInfo, err error) { + info = &tables.RedactionInfo{} + stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByRedactionEventIDStmt) + err = stmt.QueryRowContext(ctx, redactionEventID).Scan( + &info.RedactionEventID, &info.RedactsEventID, &info.Validated, + ) + if err == sql.ErrNoRows { + info = nil + err = nil + } + return +} + +func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted( + ctx context.Context, txn *sql.Tx, eventID string, +) (info *tables.RedactionInfo, err error) { + info = &tables.RedactionInfo{} + stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByEventBeingRedactedStmt) + err = stmt.QueryRowContext(ctx, eventID).Scan( + &info.RedactionEventID, &info.RedactsEventID, &info.Validated, + ) + if err == sql.ErrNoRows { + info = nil + err = nil + } + return +} + +func (s *redactionStatements) MarkRedactionValidated( + ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool, +) error { + stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) + _, err := stmt.ExecContext(ctx, redactionEventID, validated) + return err +} diff --git a/roomserver/storage/cosmosdb/room_aliases_table.go b/roomserver/storage/cosmosdb/room_aliases_table.go new file mode 100644 index 000000000..3592257ba --- /dev/null +++ b/roomserver/storage/cosmosdb/room_aliases_table.go @@ -0,0 +1,141 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" +) + +const roomAliasesSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_room_aliases ( + alias TEXT NOT NULL PRIMARY KEY, + room_id TEXT NOT NULL, + creator_id TEXT NOT NULL + ); + + CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id); +` + +const insertRoomAliasSQL = ` + INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3) +` + +const selectRoomIDFromAliasSQL = ` + SELECT room_id FROM roomserver_room_aliases WHERE alias = $1 +` + +const selectAliasesFromRoomIDSQL = ` + SELECT alias FROM roomserver_room_aliases WHERE room_id = $1 +` + +const selectCreatorIDFromAliasSQL = ` + SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1 +` + +const deleteRoomAliasSQL = ` + DELETE FROM roomserver_room_aliases WHERE alias = $1 +` + +type roomAliasesStatements struct { + db *sql.DB + insertRoomAliasStmt *sql.Stmt + selectRoomIDFromAliasStmt *sql.Stmt + selectAliasesFromRoomIDStmt *sql.Stmt + selectCreatorIDFromAliasStmt *sql.Stmt + deleteRoomAliasStmt *sql.Stmt +} + +func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { + s := &roomAliasesStatements{ + db: db, + } + _, err := db.Exec(roomAliasesSchema) + if err != nil { + return nil, err + } + return s, shared.StatementList{ + {&s.insertRoomAliasStmt, insertRoomAliasSQL}, + {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, + {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, + {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, + {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, + }.Prepare(db) +} + +func (s *roomAliasesStatements) InsertRoomAlias( + ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, +) error { + stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt) + _, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID) + return err +} + +func (s *roomAliasesStatements) SelectRoomIDFromAlias( + ctx context.Context, alias string, +) (roomID string, err error) { + err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *roomAliasesStatements) SelectAliasesFromRoomID( + ctx context.Context, roomID string, +) (aliases []string, err error) { + aliases = []string{} + rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + if err != nil { + return + } + + defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") + + for rows.Next() { + var alias string + if err = rows.Scan(&alias); err != nil { + return + } + + aliases = append(aliases, alias) + } + + return +} + +func (s *roomAliasesStatements) SelectCreatorIDFromAlias( + ctx context.Context, alias string, +) (creatorID string, err error) { + err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *roomAliasesStatements) DeleteRoomAlias( + ctx context.Context, txn *sql.Tx, alias string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt) + _, err := stmt.ExecContext(ctx, alias) + return err +} diff --git a/roomserver/storage/cosmosdb/rooms_table.go b/roomserver/storage/cosmosdb/rooms_table.go new file mode 100644 index 000000000..8570a9723 --- /dev/null +++ b/roomserver/storage/cosmosdb/rooms_table.go @@ -0,0 +1,296 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const roomsSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_rooms ( + room_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_id TEXT NOT NULL UNIQUE, + latest_event_nids TEXT NOT NULL DEFAULT '[]', + last_event_sent_nid INTEGER NOT NULL DEFAULT 0, + state_snapshot_nid INTEGER NOT NULL DEFAULT 0, + room_version TEXT NOT NULL + ); +` + +// Same as insertEventTypeNIDSQL +const insertRoomNIDSQL = ` + INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2) + ON CONFLICT DO NOTHING; +` + +const selectRoomNIDSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" + +const selectLatestEventNIDsSQL = "" + + "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" + +const selectLatestEventNIDsForUpdateSQL = "" + + "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" + +const updateLatestEventNIDsSQL = "" + + "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" + +const selectRoomVersionsForRoomNIDsSQL = "" + + "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)" + +const selectRoomInfoSQL = "" + + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" + +const selectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms" + +const bulkSelectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + +const bulkSelectRoomNIDsSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" + +type roomStatements struct { + db *sql.DB + insertRoomNIDStmt *sql.Stmt + selectRoomNIDStmt *sql.Stmt + selectLatestEventNIDsStmt *sql.Stmt + selectLatestEventNIDsForUpdateStmt *sql.Stmt + updateLatestEventNIDsStmt *sql.Stmt + //selectRoomVersionForRoomNIDStmt *sql.Stmt + selectRoomInfoStmt *sql.Stmt + selectRoomIDsStmt *sql.Stmt +} + +func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { + s := &roomStatements{ + db: db, + } + _, err := db.Exec(roomsSchema) + if err != nil { + return nil, err + } + return s, shared.StatementList{ + {&s.insertRoomNIDStmt, insertRoomNIDSQL}, + {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, + {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, + {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, + //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL}, + {&s.selectRoomInfoStmt, selectRoomInfoSQL}, + {&s.selectRoomIDsStmt, selectRoomIDsSQL}, + }.Prepare(db) +} + +func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { + rows, err := s.selectRoomIDsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} + +func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { + var info types.RoomInfo + var latestNIDsJSON string + err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( + &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + var latestNIDs []int64 + if err = json.Unmarshal([]byte(latestNIDsJSON), &latestNIDs); err != nil { + return nil, err + } + info.IsStub = len(latestNIDs) == 0 + return &info, err +} + +func (s *roomStatements) InsertRoomNID( + ctx context.Context, txn *sql.Tx, + roomID string, roomVersion gomatrixserverlib.RoomVersion, +) (roomNID types.RoomNID, err error) { + insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) + _, err = insertStmt.ExecContext(ctx, roomID, roomVersion) + if err != nil { + return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) + } + roomNID, err = s.SelectRoomNID(ctx, txn, roomID) + if err != nil { + return 0, fmt.Errorf("s.SelectRoomNID: %w", err) + } + return +} + +func (s *roomStatements) SelectRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + +func (s *roomStatements) SelectLatestEventNIDs( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.EventNID, types.StateSnapshotNID, error) { + var eventNIDs []types.EventNID + var nidsJSON string + var stateSnapshotNID int64 + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt) + err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &stateSnapshotNID) + if err != nil { + return nil, 0, err + } + if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil { + return nil, 0, err + } + return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil +} + +func (s *roomStatements) SelectLatestEventsNIDsForUpdate( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { + var eventNIDs []types.EventNID + var nidsJSON string + var lastEventSentNID int64 + var stateSnapshotNID int64 + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) + err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &lastEventSentNID, &stateSnapshotNID) + if err != nil { + return nil, 0, 0, err + } + if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil { + return nil, 0, 0, err + } + return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil +} + +func (s *roomStatements) UpdateLatestEventNIDs( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + eventNIDs []types.EventNID, + lastEventSentNID types.EventNID, + stateSnapshotNID types.StateSnapshotNID, +) error { + stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) + _, err := stmt.ExecContext( + ctx, + eventNIDsAsArray(eventNIDs), + int64(lastEventSentNID), + int64(stateSnapshotNID), + roomNID, + ) + return err +} + +func (s *roomStatements) SelectRoomVersionsForRoomNIDs( + ctx context.Context, roomNIDs []types.RoomNID, +) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { + sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + sqlPrep, err := s.db.Prepare(sqlStr) + if err != nil { + return nil, err + } + iRoomNIDs := make([]interface{}, len(roomNIDs)) + for i, v := range roomNIDs { + iRoomNIDs[i] = v + } + rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed") + result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion) + for rows.Next() { + var roomNID types.RoomNID + var roomVersion gomatrixserverlib.RoomVersion + if err = rows.Scan(&roomNID, &roomVersion); err != nil { + return nil, err + } + result[roomNID] = roomVersion + } + return result, nil +} + +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { + iRoomNIDs := make([]interface{}, len(roomNIDs)) + for i, v := range roomNIDs { + iRoomNIDs[i] = v + } + sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} + +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { + iRoomIDs := make([]interface{}, len(roomIDs)) + for i, v := range roomIDs { + iRoomIDs[i] = v + } + sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) + rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err = rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/cosmosdb/state_block_table.go b/roomserver/storage/cosmosdb/state_block_table.go new file mode 100644 index 000000000..f0c8169dd --- /dev/null +++ b/roomserver/storage/cosmosdb/state_block_table.go @@ -0,0 +1,289 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "sort" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" +) + +const stateDataSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_state_block ( + state_block_nid INTEGER NOT NULL, + event_type_nid INTEGER NOT NULL, + event_state_key_nid INTEGER NOT NULL, + event_nid INTEGER NOT NULL, + UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + ); +` + +const insertStateDataSQL = "" + + "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + + " VALUES ($1, $2, $3, $4)" + +const selectNextStateBlockNIDSQL = ` +SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block +` + +// Bulk state lookup by numeric state block ID. +// Sort by the state_block_nid, event_type_nid, event_state_key_nid +// This means that all the entries for a given state_block_nid will appear +// together in the list and those entries will sorted by event_type_nid +// and event_state_key_nid. This property makes it easier to merge two +// state data blocks together. +const bulkSelectStateBlockEntriesSQL = "" + + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + +// Bulk state lookup by numeric state block ID. +// Filters the rows in each block to the requested types and state keys. +// We would like to restrict to particular type state key pairs but we are +// restricted by the query language to pull the cross product of a list +// of types and a list state_keys. So we have to filter the result in the +// application to restrict it to the list of event types and state keys we +// actually wanted. +const bulkSelectFilteredStateBlockEntriesSQL = "" + + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + + " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + +type stateBlockStatements struct { + db *sql.DB + insertStateDataStmt *sql.Stmt + selectNextStateBlockNIDStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt + bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt +} + +func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { + s := &stateBlockStatements{ + db: db, + } + _, err := db.Exec(stateDataSchema) + if err != nil { + return nil, err + } + + return s, shared.StatementList{ + {&s.insertStateDataStmt, insertStateDataSQL}, + {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, + {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, + {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, + }.Prepare(db) +} + +func (s *stateBlockStatements) BulkInsertStateData( + ctx context.Context, txn *sql.Tx, + entries []types.StateEntry, +) (types.StateBlockNID, error) { + if len(entries) == 0 { + return 0, nil + } + var stateBlockNID types.StateBlockNID + err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + if err != nil { + return 0, err + } + for _, entry := range entries { + _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext( + ctx, + int64(stateBlockNID), + int64(entry.EventTypeNID), + int64(entry.EventStateKeyNID), + int64(entry.EventNID), + ) + if err != nil { + return 0, err + } + } + return stateBlockNID, err +} + +func (s *stateBlockStatements) BulkSelectStateBlockEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + nids := make([]interface{}, len(stateBlockNIDs)) + for k, v := range stateBlockNIDs { + nids[k] = v + } + selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + rows, err := selectStmt.QueryContext(ctx, nids...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") + + results := make([]types.StateEntryList, len(stateBlockNIDs)) + // current is a pointer to the StateEntryList to append the state entries to. + var current *types.StateEntryList + i := 0 + for rows.Next() { + var ( + stateBlockNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) + if err := rows.Scan( + &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, + ); err != nil { + return nil, err + } + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + // The state entry row is for a different state data block to the current one. + // So we start appending to the next entry in the list. + current = &results[i] + current.StateBlockNID = types.StateBlockNID(stateBlockNID) + i++ + } + current.StateEntries = append(current.StateEntries, entry) + } + if i != len(nids) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids)) + } + return results, nil +} + +func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + tuples := stateKeyTupleSorter(stateKeyTuples) + // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. + sort.Sort(tuples) + + eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) + sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) + sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) + + var params []interface{} + for _, val := range stateBlockNIDs { + params = append(params, int64(val)) + } + for _, val := range eventTypeNIDArray { + params = append(params, val) + } + for _, val := range eventStateKeyNIDArray { + params = append(params, val) + } + + rows, err := s.db.QueryContext( + ctx, + sqlStatement, + params..., + ) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") + + var results []types.StateEntryList + var current types.StateEntryList + for rows.Next() { + var ( + stateBlockNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) + if err := rows.Scan( + &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, + ); err != nil { + return nil, err + } + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + + // We can use binary search here because we sorted the tuples earlier + if !tuples.contains(entry.StateKeyTuple) { + // The select will return the cross product of types and state keys. + // So we need to check if type of the entry is in the list. + continue + } + + if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + // The state entry row is for a different state data block to the current one. + // So we append the current entry to the results and start adding to a new one. + // The first time through the loop current will be empty. + if current.StateEntries != nil { + results = append(results, current) + } + current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} + } + current.StateEntries = append(current.StateEntries, entry) + } + // Add the last entry to the list if it is not empty. + if current.StateEntries != nil { + results = append(results, current) + } + return results, nil +} + +type stateKeyTupleSorter []types.StateKeyTuple + +func (s stateKeyTupleSorter) Len() int { return len(s) } +func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Check whether a tuple is in the list. Assumes that the list is sorted. +func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { + i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) + return i < len(s) && s[i] == value +} + +// List the unique eventTypeNIDs and eventStateKeyNIDs. +// Assumes that the list is sorted. +func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) { + eventTypeNIDs = make([]int64, len(s)) + eventStateKeyNIDs = make([]int64, len(s)) + for i := range s { + eventTypeNIDs[i] = int64(s[i].EventTypeNID) + eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) + } + eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] + eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] + return +} + +type int64Sorter []int64 + +func (s int64Sorter) Len() int { return len(s) } +func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } +func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/cosmosdb/state_block_table_test.go b/roomserver/storage/cosmosdb/state_block_table_test.go new file mode 100644 index 000000000..da59f138c --- /dev/null +++ b/roomserver/storage/cosmosdb/state_block_table_test.go @@ -0,0 +1,86 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "sort" + "testing" + + "github.com/matrix-org/dendrite/roomserver/types" +) + +func TestStateKeyTupleSorter(t *testing.T) { + input := stateKeyTupleSorter{ + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 1}, + } + want := []types.StateKeyTuple{ + {EventTypeNID: 1, EventStateKeyNID: 1}, + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + } + doNotWant := []types.StateKeyTuple{ + {EventTypeNID: 0, EventStateKeyNID: 0}, + {EventTypeNID: 1, EventStateKeyNID: 3}, + {EventTypeNID: 2, EventStateKeyNID: 1}, + {EventTypeNID: 3, EventStateKeyNID: 1}, + } + wantTypeNIDs := []int64{1, 2} + wantStateKeyNIDs := []int64{1, 2, 4} + + // Sort the input and check it's in the right order. + sort.Sort(input) + gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() + + for i := range want { + if input[i] != want[i] { + t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) + } + + if !input.contains(want[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) + } + } + + for i := range doNotWant { + if input.contains(doNotWant[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) + } + } + + if len(wantTypeNIDs) != len(gotTypeNIDs) { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + + for i := range wantTypeNIDs { + if wantTypeNIDs[i] != gotTypeNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } + + if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { + t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) + } + + for i := range wantStateKeyNIDs { + if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } +} diff --git a/roomserver/storage/cosmosdb/state_snapshot_table.go b/roomserver/storage/cosmosdb/state_snapshot_table.go new file mode 100644 index 000000000..f75b18755 --- /dev/null +++ b/roomserver/storage/cosmosdb/state_snapshot_table.go @@ -0,0 +1,126 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const stateSnapshotSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( + state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_nid INTEGER NOT NULL, + state_block_nids TEXT NOT NULL DEFAULT '[]' + ); +` + +const insertStateSQL = ` + INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) + VALUES ($1, $2);` + +// Bulk state data NID lookup. +// Sorting by state_snapshot_nid means we can use binary search over the result +// to lookup the state data NIDs for a state snapshot NID. +const bulkSelectStateBlockNIDsSQL = "" + + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" + +type stateSnapshotStatements struct { + db *sql.DB + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt +} + +func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { + s := &stateSnapshotStatements{ + db: db, + } + _, err := db.Exec(stateSnapshotSchema) + if err != nil { + return nil, err + } + + return s, shared.StatementList{ + {&s.insertStateStmt, insertStateSQL}, + {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + }.Prepare(db) +} + +func (s *stateSnapshotStatements) InsertState( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, +) (stateNID types.StateSnapshotNID, err error) { + stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) + if err != nil { + return + } + insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) + res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) + if err != nil { + return 0, err + } + lastRowID, err := res.LastInsertId() + if err != nil { + return 0, err + } + stateNID = types.StateSnapshotNID(lastRowID) + return +} + +func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + nids := make([]interface{}, len(stateNIDs)) + for k, v := range stateNIDs { + nids[k] = v + } + selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + + rows, err := selectStmt.QueryContext(ctx, nids...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed") + results := make([]types.StateBlockNIDList, len(stateNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + var stateBlockNIDsJSON string + if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil { + return nil, err + } + if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil { + return nil, err + } + } + if i != len(stateNIDs) { + return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) + } + return results, nil +} diff --git a/roomserver/storage/cosmosdb/storage.go b/roomserver/storage/cosmosdb/storage.go new file mode 100644 index 000000000..bb3f6af2e --- /dev/null +++ b/roomserver/storage/cosmosdb/storage.go @@ -0,0 +1,187 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + _ "github.com/mattn/go-sqlite3" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// A Database is used to store room events and stream offsets. +type Database struct { + shared.Database +} + +// Open a sqlite database. +func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { + var d Database + var db *sql.DB + var err error + if db, err = sqlutil.Open(dbProperties); err != nil { + return nil, err + } + + //db.Exec("PRAGMA journal_mode=WAL;") + //db.Exec("PRAGMA read_uncommitted = true;") + + // FIXME: We are leaking connections somewhere. Setting this to 2 will eventually + // cause the roomserver to be unresponsive to new events because something will + // acquire the global mutex and never unlock it because it is waiting for a connection + // which it will never obtain. + db.SetMaxOpenConns(20) + + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + ms := membershipStatements{} + if err := ms.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() + deltas.LoadAddForgottenColumn(m) + if err := m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + if err := d.prepare(db, cache); err != nil { + return nil, err + } + + return &d, nil +} + +// nolint: gocyclo +func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { + var err error + eventStateKeys, err := NewSqliteEventStateKeysTable(db) + if err != nil { + return err + } + eventTypes, err := NewSqliteEventTypesTable(db) + if err != nil { + return err + } + eventJSON, err := NewSqliteEventJSONTable(db) + if err != nil { + return err + } + events, err := NewSqliteEventsTable(db) + if err != nil { + return err + } + rooms, err := NewSqliteRoomsTable(db) + if err != nil { + return err + } + transactions, err := NewSqliteTransactionsTable(db) + if err != nil { + return err + } + stateBlock, err := NewSqliteStateBlockTable(db) + if err != nil { + return err + } + stateSnapshot, err := NewSqliteStateSnapshotTable(db) + if err != nil { + return err + } + prevEvents, err := NewSqlitePrevEventsTable(db) + if err != nil { + return err + } + roomAliases, err := NewSqliteRoomAliasesTable(db) + if err != nil { + return err + } + invites, err := NewSqliteInvitesTable(db) + if err != nil { + return err + } + membership, err := NewSqliteMembershipTable(db) + if err != nil { + return err + } + published, err := NewSqlitePublishedTable(db) + if err != nil { + return err + } + redactions, err := NewSqliteRedactionsTable(db) + if err != nil { + return err + } + d.Database = shared.Database{ + DB: db, + Cache: cache, + Writer: sqlutil.NewExclusiveWriter(), + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + RoomsTable: rooms, + TransactionsTable: transactions, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: prevEvents, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + RedactionsTable: redactions, + GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, + } + return nil +} + +func (d *Database) SupportsConcurrentRoomInputs() bool { + // This isn't supported in SQLite mode yet because of issues with + // database locks. + // TODO: Look at this again - the problem is probably to do with + // the membership updaters and latest events updaters. + return false +} + +func (d *Database) GetLatestEventsForUpdate( + ctx context.Context, roomInfo types.RoomInfo, +) (*shared.LatestEventsUpdater, error) { + // TODO: Do not use transactions. We should be holding open this transaction but we cannot have + // multiple write transactions on sqlite. The code will perform additional + // write transactions independent of this one which will consistently cause + // 'database is locked' errors. As sqlite doesn't support multi-process on the + // same DB anyway, and we only execute updates sequentially, the only worries + // are for rolling back when things go wrong. (atomicity) + return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) +} + +func (d *Database) MembershipUpdater( + ctx context.Context, roomID, targetUserID string, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, +) (*shared.MembershipUpdater, error) { + // TODO: Do not use transactions. We should be holding open this transaction but we cannot have + // multiple write transactions on sqlite. The code will perform additional + // write transactions independent of this one which will consistently cause + // 'database is locked' errors. As sqlite doesn't support multi-process on the + // same DB anyway, and we only execute updates sequentially, the only worries + // are for rolling back when things go wrong. (atomicity) + return shared.NewMembershipUpdater(ctx, &d.Database, nil, roomID, targetUserID, targetLocal, roomVersion) +} diff --git a/roomserver/storage/cosmosdb/transactions_table.go b/roomserver/storage/cosmosdb/transactions_table.go new file mode 100644 index 000000000..9be93ed34 --- /dev/null +++ b/roomserver/storage/cosmosdb/transactions_table.go @@ -0,0 +1,91 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" +) + +const transactionsSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_transactions ( + transaction_id TEXT NOT NULL, + session_id INTEGER NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + PRIMARY KEY (transaction_id, session_id, user_id) + ); +` +const insertTransactionSQL = ` + INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id) + VALUES ($1, $2, $3, $4) +` + +const selectTransactionEventIDSQL = ` + SELECT event_id FROM roomserver_transactions + WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3 +` + +type transactionStatements struct { + db *sql.DB + insertTransactionStmt *sql.Stmt + selectTransactionEventIDStmt *sql.Stmt +} + +func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{ + db: db, + } + _, err := db.Exec(transactionsSchema) + if err != nil { + return nil, err + } + + return s, shared.StatementList{ + {&s.insertTransactionStmt, insertTransactionSQL}, + {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, + }.Prepare(db) +} + +func (s *transactionStatements) InsertTransaction( + ctx context.Context, txn *sql.Tx, + transactionID string, + sessionID int64, + userID string, + eventID string, +) error { + stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) + _, err := stmt.ExecContext( + ctx, transactionID, sessionID, userID, eventID, + ) + return err +} + +func (s *transactionStatements) SelectTransactionEventID( + ctx context.Context, + transactionID string, + sessionID int64, + userID string, +) (eventID string, err error) { + err = s.selectTransactionEventIDStmt.QueryRowContext( + ctx, transactionID, sessionID, userID, + ).Scan(&eventID) + return +} diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 9359312db..08e54ff70 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/storage/cosmosdb" "github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" "github.com/matrix-org/dendrite/setup/config" @@ -28,6 +29,8 @@ import ( // Open opens a database connection. func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return cosmosdb.Open(dbProperties, cache) case dbProperties.ConnectionString.IsSQLite(): return sqlite3.Open(dbProperties, cache) case dbProperties.ConnectionString.IsPostgres(): diff --git a/roomserver/storage/storage_wasm.go b/roomserver/storage/storage_wasm.go index dfc374e6e..333ac9bf6 100644 --- a/roomserver/storage/storage_wasm.go +++ b/roomserver/storage/storage_wasm.go @@ -25,6 +25,8 @@ import ( // NewPublicRoomsServerDatabase opens a database connection. func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return nil, fmt.Errorf("can't use CosmosDB implementation") case dbProperties.ConnectionString.IsSQLite(): return sqlite3.Open(dbProperties, cache) case dbProperties.ConnectionString.IsPostgres(): diff --git a/setup/config/config.go b/setup/config/config.go index b91144078..c92d1fecd 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -139,6 +139,11 @@ func (d DataSource) IsPostgres() bool { return !d.IsSQLite() } +func (d DataSource) IsCosmosDB() bool { + // commented line may not always be true? + return strings.HasPrefix(string(d), "cosmosdb:") +} + // A Topic in kafka. type Topic string diff --git a/setup/mscs/msc2836/storage.go b/setup/mscs/msc2836/storage.go index 72523916b..61208d58d 100644 --- a/setup/mscs/msc2836/storage.go +++ b/setup/mscs/msc2836/storage.go @@ -59,6 +59,9 @@ type DB struct { // NewDatabase loads the database for msc2836 func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + if dbOpts.ConnectionString.IsCosmosDB() { + return newCosmosDBDatabase(dbOpts) + } if dbOpts.ConnectionString.IsPostgres() { return newPostgresDatabase(dbOpts) } @@ -225,6 +228,86 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { return &d, nil } +func newCosmosDBDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + d := DB{ + writer: sqlutil.NewExclusiveWriter(), + } + var err error + if d.db, err = sqlutil.Open(dbOpts); err != nil { + return nil, err + } + _, err = d.db.Exec(` + CREATE TABLE IF NOT EXISTS msc2836_edges ( + parent_event_id TEXT NOT NULL, + child_event_id TEXT NOT NULL, + rel_type TEXT NOT NULL, + parent_room_id TEXT NOT NULL, + parent_servers TEXT NOT NULL, + UNIQUE (parent_event_id, child_event_id, rel_type) + ); + + CREATE TABLE IF NOT EXISTS msc2836_nodes ( + event_id TEXT PRIMARY KEY NOT NULL, + origin_server_ts BIGINT NOT NULL, + room_id TEXT NOT NULL, + unsigned_children_count BIGINT NOT NULL, + unsigned_children_hash TEXT NOT NULL, + explored SMALLINT NOT NULL + ); + `) + if err != nil { + return nil, err + } + if d.insertEdgeStmt, err = d.db.Prepare(` + INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) + VALUES($1, $2, $3, $4, $5) + ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING + `); err != nil { + return nil, err + } + if d.insertNodeStmt, err = d.db.Prepare(` + INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored) + VALUES($1, $2, $3, $4, $5, $6) + ON CONFLICT DO NOTHING + `); err != nil { + return nil, err + } + selectChildrenQuery := ` + SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges + LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id + WHERE parent_event_id = $1 AND rel_type = $2 + ORDER BY origin_server_ts + ` + if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil { + return nil, err + } + if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { + return nil, err + } + if d.selectParentForChildStmt, err = d.db.Prepare(` + SELECT parent_event_id, parent_room_id FROM msc2836_edges + WHERE child_event_id = $1 AND rel_type = $2 + `); err != nil { + return nil, err + } + if d.updateChildMetadataStmt, err = d.db.Prepare(` + UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4 + `); err != nil { + return nil, err + } + if d.selectChildMetadataStmt, err = d.db.Prepare(` + SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1 + `); err != nil { + return nil, err + } + if d.updateChildMetadataExploredStmt, err = d.db.Prepare(` + UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2 + `); err != nil { + return nil, err + } + return &d, nil +} + func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { parent, child, relType := parentChildEventIDs(ev) if parent == "" || child == "" { diff --git a/setup/mscs/msc2946/storage.go b/setup/mscs/msc2946/storage.go index 20db18594..b80cbe2cb 100644 --- a/setup/mscs/msc2946/storage.go +++ b/setup/mscs/msc2946/storage.go @@ -47,6 +47,9 @@ type DB struct { // NewDatabase loads the database for msc2836 func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + if dbOpts.ConnectionString.IsCosmosDB() { + return newCosmosDBDatabase(dbOpts) + } if dbOpts.ConnectionString.IsPostgres() { return newPostgresDatabase(dbOpts) } @@ -133,6 +136,46 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { return &d, err } +func newCosmosDBDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + d := DB{ + writer: sqlutil.NewExclusiveWriter(), + } + var err error + if d.db, err = sqlutil.Open(dbOpts); err != nil { + return nil, err + } + _, err = d.db.Exec(` + CREATE TABLE IF NOT EXISTS msc2946_edges ( + room_version TEXT NOT NULL, + -- the room ID of the event, the source of the arrow + source_room_id TEXT NOT NULL, + -- the target room ID, the arrow destination + dest_room_id TEXT NOT NULL, + -- the kind of relation, either child or parent (1,2) + rel_type SMALLINT NOT NULL, + event_json TEXT NOT NULL, + UNIQUE (source_room_id, dest_room_id, rel_type) + ); + `) + if err != nil { + return nil, err + } + if d.insertEdgeStmt, err = d.db.Prepare(` + INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) + VALUES($1, $2, $3, $4, $5) + ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5 + `); err != nil { + return nil, err + } + if d.selectEdgesStmt, err = d.db.Prepare(` + SELECT room_version, event_json FROM msc2946_edges + WHERE source_room_id = $1 OR dest_room_id = $2 + `); err != nil { + return nil, err + } + return &d, err +} + func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error { target := SpaceTarget(he) if target == "" { diff --git a/signingkeyserver/storage/cosmosdb/keydb.go b/signingkeyserver/storage/cosmosdb/keydb.go new file mode 100644 index 000000000..0f4371bce --- /dev/null +++ b/signingkeyserver/storage/cosmosdb/keydb.go @@ -0,0 +1,99 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + + "golang.org/x/crypto/ed25519" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + + _ "github.com/mattn/go-sqlite3" +) + +// A Database implements gomatrixserverlib.KeyDatabase and is used to store +// the public keys for other matrix servers. +type Database struct { + writer sqlutil.Writer + statements serverKeyStatements +} + +// NewDatabase prepares a new key database. +// It creates the necessary tables if they don't already exist. +// It prepares all the SQL statements that it will use. +// Returns an error if there was a problem talking to the database. +func NewDatabase( + dbProperties *config.DatabaseOptions, + serverName gomatrixserverlib.ServerName, + serverKey ed25519.PublicKey, + serverKeyID gomatrixserverlib.KeyID, +) (*Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + d := &Database{ + writer: sqlutil.NewExclusiveWriter(), + } + err = d.statements.prepare(db, d.writer) + if err != nil { + return nil, err + } + if err != nil { + return nil, err + } + return d, nil +} + +// FetcherName implements KeyFetcher +func (d Database) FetcherName() string { + return "SqliteKeyDatabase" +} + +// FetchKeys implements gomatrixserverlib.KeyDatabase +func (d *Database) FetchKeys( + ctx context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + return d.statements.bulkSelectServerKeys(ctx, requests) +} + +// StoreKeys implements gomatrixserverlib.KeyDatabase +func (d *Database) StoreKeys( + ctx context.Context, + keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + // TODO: Inserting all the keys within a single transaction may + // be more efficient since the transaction overhead can be quite + // high for a single insert statement. + var lastErr error + for request, keys := range keyMap { + if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil { + // Rather than returning immediately on error we try to insert the + // remaining keys. + // Since we are inserting the keys outside of a transaction it is + // possible for some of the inserts to succeed even though some + // of the inserts have failed. + // Ensuring that we always insert all the keys we can means that + // this behaviour won't depend on the iteration order of the map. + lastErr = err + } + } + return lastErr +} diff --git a/signingkeyserver/storage/cosmosdb/server_key_table.go b/signingkeyserver/storage/cosmosdb/server_key_table.go new file mode 100644 index 000000000..e30de0a12 --- /dev/null +++ b/signingkeyserver/storage/cosmosdb/server_key_table.go @@ -0,0 +1,159 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const serverKeysSchema = ` +-- A cache of signing keys downloaded from remote servers. +CREATE TABLE IF NOT EXISTS keydb_server_keys ( + -- The name of the matrix server the key is for. + server_name TEXT NOT NULL, + -- The ID of the server key. + server_key_id TEXT NOT NULL, + -- Combined server name and key ID separated by the ASCII unit separator + -- to make it easier to run bulk queries. + server_name_and_key_id TEXT NOT NULL, + -- When the key is valid until as a millisecond timestamp. + -- 0 if this is an expired key (in which case expired_ts will be non-zero) + valid_until_ts BIGINT NOT NULL, + -- When the key expired as a millisecond timestamp. + -- 0 if this is an active key (in which case valid_until_ts will be non-zero) + expired_ts BIGINT NOT NULL, + -- The base64-encoded public key. + server_key TEXT NOT NULL, + UNIQUE (server_name, server_key_id) +); + +CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id); +` + +const bulkSelectServerKeysSQL = "" + + "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " + + " server_key FROM keydb_server_keys" + + " WHERE server_name_and_key_id IN ($1)" + +const upsertServerKeysSQL = "" + + "INSERT INTO keydb_server_keys (server_name, server_key_id," + + " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (server_name, server_key_id)" + + " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6" + +type serverKeyStatements struct { + db *sql.DB + writer sqlutil.Writer + bulkSelectServerKeysStmt *sql.Stmt + upsertServerKeysStmt *sql.Stmt +} + +func (s *serverKeyStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { + s.db = db + s.writer = writer + _, err = db.Exec(serverKeysSchema) + if err != nil { + return + } + if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerKeysSQL); err != nil { + return + } + if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil { + return + } + return +} + +func (s *serverKeyStatements) bulkSelectServerKeys( + ctx context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + nameAndKeyIDs := make([]string, 0, len(requests)) + for request := range requests { + nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) + } + results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests)) + iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) + for i, v := range nameAndKeyIDs { + iKeyIDs[i] = v + } + + err := sqlutil.RunLimitedVariablesQuery( + ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, + func(rows *sql.Rows) error { + for rows.Next() { + var serverName string + var keyID string + var key string + var validUntilTS int64 + var expiredTS int64 + if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { + return fmt.Errorf("bulkSelectServerKeys: %v", err) + } + r := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: gomatrixserverlib.ServerName(serverName), + KeyID: gomatrixserverlib.KeyID(keyID), + } + vk := gomatrixserverlib.VerifyKey{} + err := vk.Key.Decode(key) + if err != nil { + return fmt.Errorf("bulkSelectServerKeys: %v", err) + } + results[r] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: vk, + ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), + ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), + } + } + return nil + }, + ) + + if err != nil { + return nil, err + } + return results, nil +} + +func (s *serverKeyStatements) upsertServerKeys( + ctx context.Context, + request gomatrixserverlib.PublicKeyLookupRequest, + key gomatrixserverlib.PublicKeyLookupResult, +) error { + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt) + _, err := stmt.ExecContext( + ctx, + string(request.ServerName), + string(request.KeyID), + nameAndKeyID(request), + key.ValidUntilTS, + key.ExpiredTS, + key.Key.Encode(), + ) + return err + }) +} + +func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string { + return string(request.ServerName) + "\x1F" + string(request.KeyID) +} diff --git a/signingkeyserver/storage/keydb.go b/signingkeyserver/storage/keydb.go index aa247f1d8..2f63427ac 100644 --- a/signingkeyserver/storage/keydb.go +++ b/signingkeyserver/storage/keydb.go @@ -22,6 +22,7 @@ import ( "golang.org/x/crypto/ed25519" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/signingkeyserver/storage/cosmosdb" "github.com/matrix-org/dendrite/signingkeyserver/storage/postgres" "github.com/matrix-org/dendrite/signingkeyserver/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" @@ -35,6 +36,8 @@ func NewDatabase( serverKeyID gomatrixserverlib.KeyID, ) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return cosmosdb.NewDatabase(dbProperties, serverName, serverKey, serverKeyID) case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties, serverName, serverKey, serverKeyID) case dbProperties.ConnectionString.IsPostgres(): diff --git a/syncapi/storage/cosmosdb/account_data_table.go b/syncapi/storage/cosmosdb/account_data_table.go new file mode 100644 index 000000000..308d3bd1f --- /dev/null +++ b/syncapi/storage/cosmosdb/account_data_table.go @@ -0,0 +1,156 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const accountDataSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_account_data_type ( + id INTEGER PRIMARY KEY, + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + UNIQUE (user_id, room_id, type) +); +` + +const insertAccountDataSQL = "" + + "INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id, room_id, type) DO UPDATE" + + " SET id = $5" + +const selectAccountDataInRangeSQL = "" + + "SELECT room_id, type FROM syncapi_account_data_type" + + " WHERE user_id = $1 AND id > $2 AND id <= $3" + + " ORDER BY id ASC" + +const selectMaxAccountDataIDSQL = "" + + "SELECT MAX(id) FROM syncapi_account_data_type" + +type accountDataStatements struct { + db *sql.DB + streamIDStatements *streamIDStatements + insertAccountDataStmt *sql.Stmt + selectMaxAccountDataIDStmt *sql.Stmt + selectAccountDataInRangeStmt *sql.Stmt +} + +func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { + s := &accountDataStatements{ + db: db, + streamIDStatements: streamID, + } + _, err := db.Exec(accountDataSchema) + if err != nil { + return nil, err + } + if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { + return nil, err + } + if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { + return nil, err + } + if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *accountDataStatements) InsertAccountData( + ctx context.Context, txn *sql.Tx, + userID, roomID, dataType string, +) (pos types.StreamPosition, err error) { + pos, err = s.streamIDStatements.nextAccountDataID(ctx, txn) + if err != nil { + return + } + _, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType, pos) + return +} + +func (s *accountDataStatements) SelectAccountDataInRange( + ctx context.Context, + userID string, + r types.Range, + accountDataFilterPart *gomatrixserverlib.EventFilter, +) (data map[string][]string, err error) { + data = make(map[string][]string) + + rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High()) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") + + var entries int + + for rows.Next() { + var dataType string + var roomID string + + if err = rows.Scan(&roomID, &dataType); err != nil { + return + } + + // check if we should add this by looking at the filter. + // It would be nice if we could do this in SQL-land, but the mix of variadic + // and positional parameters makes the query annoyingly hard to do, it's easier + // and clearer to do it in Go-land. If there are no filters for [not]types then + // this gets skipped. + for _, includeType := range accountDataFilterPart.Types { + if includeType != dataType { // TODO: wildcard support + continue + } + } + for _, excludeType := range accountDataFilterPart.NotTypes { + if excludeType == dataType { // TODO: wildcard support + continue + } + } + + if len(data[roomID]) > 0 { + data[roomID] = append(data[roomID], dataType) + } else { + data[roomID] = []string{dataType} + } + entries++ + if entries >= accountDataFilterPart.Limit { + break + } + } + + return data, nil +} + +func (s *accountDataStatements) SelectMaxAccountDataID( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + var nullableID sql.NullInt64 + err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) + if nullableID.Valid { + id = nullableID.Int64 + } + return +} diff --git a/syncapi/storage/cosmosdb/backwards_extremities_table.go b/syncapi/storage/cosmosdb/backwards_extremities_table.go new file mode 100644 index 000000000..5bc7d723b --- /dev/null +++ b/syncapi/storage/cosmosdb/backwards_extremities_table.go @@ -0,0 +1,125 @@ +// Copyright 2018 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" +) + +const backwardExtremitiesSchema = ` +-- Stores output room events received from the roomserver. +CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( + -- The 'room_id' key for the event. + room_id TEXT NOT NULL, + -- The event ID for the last known event. This is the backwards extremity. + event_id TEXT NOT NULL, + -- The prev_events for the last known event. This is used to update extremities. + prev_event_id TEXT NOT NULL, + PRIMARY KEY(room_id, event_id, prev_event_id) +); +` + +const insertBackwardExtremitySQL = "" + + "INSERT INTO syncapi_backward_extremities (room_id, event_id, prev_event_id)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING" + +const selectBackwardExtremitiesForRoomSQL = "" + + "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1" + +const deleteBackwardExtremitySQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" + +const deleteBackwardExtremitiesForRoomSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + +type backwardExtremitiesStatements struct { + db *sql.DB + insertBackwardExtremityStmt *sql.Stmt + selectBackwardExtremitiesForRoomStmt *sql.Stmt + deleteBackwardExtremityStmt *sql.Stmt + deleteBackwardExtremitiesForRoomStmt *sql.Stmt +} + +func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { + s := &backwardExtremitiesStatements{ + db: db, + } + _, err := db.Exec(backwardExtremitiesSchema) + if err != nil { + return nil, err + } + if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { + return nil, err + } + if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { + return nil, err + } + if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { + return nil, err + } + if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( + ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) + return err +} + +func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( + ctx context.Context, roomID string, +) (bwExtrems map[string][]string, err error) { + rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed") + + bwExtrems = make(map[string][]string) + for rows.Next() { + var eID string + var prevEventID string + if err = rows.Scan(&eID, &prevEventID); err != nil { + return + } + bwExtrems[eID] = append(bwExtrems[eID], prevEventID) + } + + return bwExtrems, rows.Err() +} + +func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( + ctx context.Context, txn *sql.Tx, roomID, knownEventID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) + return err +} + +func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/cosmosdb/current_room_state_table.go b/syncapi/storage/cosmosdb/current_room_state_table.go new file mode 100644 index 000000000..a3a5a4a4a --- /dev/null +++ b/syncapi/storage/cosmosdb/current_room_state_table.go @@ -0,0 +1,324 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const currentRoomStateSchema = ` +-- Stores the current room state for every room. +CREATE TABLE IF NOT EXISTS syncapi_current_room_state ( + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + type TEXT NOT NULL, + sender TEXT NOT NULL, + contains_url BOOL NOT NULL DEFAULT false, + state_key TEXT NOT NULL, + headered_event_json TEXT NOT NULL, + membership TEXT, + added_at BIGINT, + UNIQUE (room_id, type, state_key) +); +-- for event deletion +CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); +-- for querying membership states of users +-- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; +-- for querying state by event IDs +CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); +` + +const upsertRoomStateSQL = "" + + "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + + " ON CONFLICT (room_id, type, state_key)" + + " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9" + +const deleteRoomStateByEventIDSQL = "" + + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" + +const DeleteRoomStateForRoomSQL = "" + + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" + +const selectRoomIDsWithMembershipSQL = "" + + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" + +const selectCurrentStateSQL = "" + + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + // WHEN, ORDER BY and LIMIT will be added by prepareWithFilter + +const selectJoinedUsersSQL = "" + + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" + +const selectStateEventSQL = "" + + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" + +const selectEventsWithEventIDsSQL = "" + + // TODO: The session_id and transaction_id blanks are here because otherwise + // the rowsToStreamEvents expects there to be exactly six columns. We need to + // figure out if these really need to be in the DB, and if so, we need a + // better permanent fix for this. - neilalexander, 2 Jan 2020 + "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + + " FROM syncapi_current_room_state WHERE event_id IN ($1)" + +type currentRoomStateStatements struct { + db *sql.DB + streamIDStatements *streamIDStatements + upsertRoomStateStmt *sql.Stmt + deleteRoomStateByEventIDStmt *sql.Stmt + DeleteRoomStateForRoomStmt *sql.Stmt + selectRoomIDsWithMembershipStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt + selectStateEventStmt *sql.Stmt +} + +func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { + s := ¤tRoomStateStatements{ + db: db, + streamIDStatements: streamID, + } + _, err := db.Exec(currentRoomStateSchema) + if err != nil { + return nil, err + } + if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { + return nil, err + } + if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { + return nil, err + } + if s.DeleteRoomStateForRoomStmt, err = db.Prepare(DeleteRoomStateForRoomSQL); err != nil { + return nil, err + } + if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { + return nil, err + } + if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { + return nil, err + } + if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { + return nil, err + } + return s, nil +} + +// JoinedMemberLists returns a map of room ID to a list of joined user IDs. +func (s *currentRoomStateStatements) SelectJoinedUsers( + ctx context.Context, +) (map[string][]string, error) { + rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") + + result := make(map[string][]string) + for rows.Next() { + var roomID string + var userID string + if err := rows.Scan(&roomID, &userID); err != nil { + return nil, err + } + users := result[roomID] + users = append(users, userID) + result[roomID] = users + } + return result, nil +} + +// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. +func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( + ctx context.Context, + txn *sql.Tx, + userID string, + membership string, // nolint: unparam +) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) + rows, err := stmt.QueryContext(ctx, userID, membership) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithMembership: rows.close() failed") + + var result []string + for rows.Next() { + var roomID string + if err := rows.Scan(&roomID); err != nil { + return nil, err + } + result = append(result, roomID) + } + return result, nil +} + +// CurrentState returns all the current state events for the given room. +func (s *currentRoomStateStatements) SelectCurrentState( + ctx context.Context, txn *sql.Tx, roomID string, + stateFilter *gomatrixserverlib.StateFilter, + excludeEventIDs []string, +) ([]*gomatrixserverlib.HeaderedEvent, error) { + stmt, params, err := prepareWithFilters( + s.db, txn, selectCurrentStateSQL, + []interface{}{ + roomID, + }, + stateFilter.Senders, stateFilter.NotSenders, + stateFilter.Types, stateFilter.NotTypes, + excludeEventIDs, stateFilter.Limit, FilterOrderNone, + ) + if err != nil { + return nil, fmt.Errorf("s.prepareWithFilters: %w", err) + } + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCurrentState: rows.close() failed") + + return rowsToEvents(rows) +} + +func (s *currentRoomStateStatements) DeleteRoomStateByEventID( + ctx context.Context, txn *sql.Tx, eventID string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + _, err := stmt.ExecContext(ctx, eventID) + return err +} + +func (s *currentRoomStateStatements) DeleteRoomStateForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt) + _, err := stmt.ExecContext(ctx, roomID) + return err +} + +func (s *currentRoomStateStatements) UpsertRoomState( + ctx context.Context, txn *sql.Tx, + event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, +) error { + // Parse content as JSON and search for an "url" key + containsURL := false + var content map[string]interface{} + if json.Unmarshal(event.Content(), &content) != nil { + // Set containsURL to true if url is present + _, containsURL = content["url"] + } + + headeredJSON, err := json.Marshal(event) + if err != nil { + return err + } + + // upsert state event + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) + _, err = stmt.ExecContext( + ctx, + event.RoomID(), + event.EventID(), + event.Type(), + event.Sender(), + containsURL, + *event.StateKey(), + headeredJSON, + membership, + addedAt, + ) + return err +} + +func minOfInts(a, b int) int { + if a <= b { + return a + } + return b +} + +func (s *currentRoomStateStatements) SelectEventsWithEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) ([]types.StreamEvent, error) { + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + res := make([]types.StreamEvent, 0, len(eventIDs)) + var start int + for start < len(eventIDs) { + n := minOfInts(len(eventIDs)-start, 999) + query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1) + rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...) + if err != nil { + return nil, err + } + start = start + n + events, err := rowsToStreamEvents(rows) + internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") + if err != nil { + return nil, err + } + res = append(res, events...) + } + return res, nil +} + +func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { + result := []*gomatrixserverlib.HeaderedEvent{} + for rows.Next() { + var eventID string + var eventBytes []byte + if err := rows.Scan(&eventID, &eventBytes); err != nil { + return nil, err + } + // TODO: Handle redacted events + var ev gomatrixserverlib.HeaderedEvent + if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + result = append(result, &ev) + } + return result, nil +} + +func (s *currentRoomStateStatements) SelectStateEvent( + ctx context.Context, roomID, evType, stateKey string, +) (*gomatrixserverlib.HeaderedEvent, error) { + stmt := s.selectStateEventStmt + var res []byte + err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + var ev gomatrixserverlib.HeaderedEvent + if err = json.Unmarshal(res, &ev); err != nil { + return nil, err + } + return &ev, err +} diff --git a/syncapi/storage/cosmosdb/filter_table.go b/syncapi/storage/cosmosdb/filter_table.go new file mode 100644 index 000000000..9447ddd82 --- /dev/null +++ b/syncapi/storage/cosmosdb/filter_table.go @@ -0,0 +1,140 @@ +// Copyright 2017 Jan Christian Grünhage +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +const filterSchema = ` +-- Stores data about filters +CREATE TABLE IF NOT EXISTS syncapi_filter ( + -- The filter + filter TEXT NOT NULL, + -- The ID + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The localpart of the Matrix user ID associated to this filter + localpart TEXT NOT NULL, + + UNIQUE (id, localpart) +); + +CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart); +` + +const selectFilterSQL = "" + + "SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2" + +const selectFilterIDByContentSQL = "" + + "SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2" + +const insertFilterSQL = "" + + "INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)" + +type filterStatements struct { + db *sql.DB + selectFilterStmt *sql.Stmt + selectFilterIDByContentStmt *sql.Stmt + insertFilterStmt *sql.Stmt +} + +func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { + _, err := db.Exec(filterSchema) + if err != nil { + return nil, err + } + s := &filterStatements{ + db: db, + } + if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { + return nil, err + } + if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { + return nil, err + } + if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *filterStatements) SelectFilter( + ctx context.Context, localpart string, filterID string, +) (*gomatrixserverlib.Filter, error) { + // Retrieve filter from database (stored as canonical JSON) + var filterData []byte + err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + if err != nil { + return nil, err + } + + // Unmarshal JSON into Filter struct + filter := gomatrixserverlib.DefaultFilter() + if err = json.Unmarshal(filterData, &filter); err != nil { + return nil, err + } + return &filter, nil +} + +func (s *filterStatements) InsertFilter( + ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, +) (filterID string, err error) { + var existingFilterID string + + // Serialise json + filterJSON, err := json.Marshal(filter) + if err != nil { + return "", err + } + // Remove whitespaces and sort JSON data + // needed to prevent from inserting the same filter multiple times + filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON) + if err != nil { + return "", err + } + + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID + err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + localpart, filterJSON).Scan(&existingFilterID) + if err != nil && err != sql.ErrNoRows { + return "", err + } + // If it does, return the existing ID + if existingFilterID != "" { + return existingFilterID, nil + } + + // Otherwise insert the filter and return the new ID + res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + if err != nil { + return "", err + } + rowid, err := res.LastInsertId() + if err != nil { + return "", err + } + filterID = fmt.Sprintf("%d", rowid) + return +} diff --git a/syncapi/storage/cosmosdb/filtering.go b/syncapi/storage/cosmosdb/filtering.go new file mode 100644 index 000000000..62d2434f6 --- /dev/null +++ b/syncapi/storage/cosmosdb/filtering.go @@ -0,0 +1,82 @@ +package cosmosdb + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +type FilterOrder int + +const ( + FilterOrderNone = iota + FilterOrderAsc + FilterOrderDesc +) + +// prepareWithFilters returns a prepared statement with the +// relevant filters included. It also includes an []interface{} +// list of all the relevant parameters to pass straight to +// QueryContext, QueryRowContext etc. +// We don't take the filter object directly here because the +// fields might come from either a StateFilter or an EventFilter, +// and it's easier just to have the caller extract the relevant +// parts. +func prepareWithFilters( + db *sql.DB, txn *sql.Tx, query string, params []interface{}, + senders, notsenders, types, nottypes []string, excludeEventIDs []string, + limit int, order FilterOrder, +) (*sql.Stmt, []interface{}, error) { + offset := len(params) + if count := len(senders); count > 0 { + query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range senders { + params, offset = append(params, v), offset+1 + } + } + if count := len(notsenders); count > 0 { + query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range notsenders { + params, offset = append(params, v), offset+1 + } + } + if count := len(types); count > 0 { + query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range types { + params, offset = append(params, v), offset+1 + } + } + if count := len(nottypes); count > 0 { + query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range nottypes { + params, offset = append(params, v), offset+1 + } + } + if count := len(excludeEventIDs); count > 0 { + query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range excludeEventIDs { + params, offset = append(params, v), offset+1 + } + } + switch order { + case FilterOrderAsc: + query += " ORDER BY id ASC" + case FilterOrderDesc: + query += " ORDER BY id DESC" + } + query += fmt.Sprintf(" LIMIT $%d", offset+1) + params = append(params, limit) + + var stmt *sql.Stmt + var err error + if txn != nil { + stmt, err = txn.Prepare(query) + } else { + stmt, err = db.Prepare(query) + } + if err != nil { + return nil, nil, fmt.Errorf("s.db.Prepare: %w", err) + } + return stmt, params, nil +} diff --git a/syncapi/storage/cosmosdb/invites_table.go b/syncapi/storage/cosmosdb/invites_table.go new file mode 100644 index 000000000..ea5d0bd85 --- /dev/null +++ b/syncapi/storage/cosmosdb/invites_table.go @@ -0,0 +1,185 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const inviteEventsSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_invite_events ( + id INTEGER PRIMARY KEY, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + headered_event_json TEXT NOT NULL, + deleted BOOL NOT NULL +); + +CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id); +CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id); +` + +const insertInviteEventSQL = "" + + "INSERT INTO syncapi_invite_events" + + " (id, room_id, event_id, target_user_id, headered_event_json, deleted)" + + " VALUES ($1, $2, $3, $4, $5, false)" + +const deleteInviteEventSQL = "" + + "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2" + +const selectInviteEventsInRangeSQL = "" + + "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + + " WHERE target_user_id = $1 AND id > $2 AND id <= $3" + + " ORDER BY id DESC" + +const selectMaxInviteIDSQL = "" + + "SELECT MAX(id) FROM syncapi_invite_events" + +type inviteEventsStatements struct { + db *sql.DB + streamIDStatements *streamIDStatements + insertInviteEventStmt *sql.Stmt + selectInviteEventsInRangeStmt *sql.Stmt + deleteInviteEventStmt *sql.Stmt + selectMaxInviteIDStmt *sql.Stmt +} + +func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { + s := &inviteEventsStatements{ + db: db, + streamIDStatements: streamID, + } + _, err := db.Exec(inviteEventsSchema) + if err != nil { + return nil, err + } + if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { + return nil, err + } + if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { + return nil, err + } + if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { + return nil, err + } + if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *inviteEventsStatements) InsertInviteEvent( + ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent, +) (streamPos types.StreamPosition, err error) { + streamPos, err = s.streamIDStatements.nextInviteID(ctx, txn) + if err != nil { + return + } + + var headeredJSON []byte + headeredJSON, err = json.Marshal(inviteEvent) + if err != nil { + return + } + + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) + _, err = stmt.ExecContext( + ctx, + streamPos, + inviteEvent.RoomID(), + inviteEvent.EventID(), + *inviteEvent.StateKey(), + headeredJSON, + ) + return +} + +func (s *inviteEventsStatements) DeleteInviteEvent( + ctx context.Context, txn *sql.Tx, inviteEventID string, +) (types.StreamPosition, error) { + streamPos, err := s.streamIDStatements.nextInviteID(ctx, txn) + if err != nil { + return streamPos, err + } + stmt := sqlutil.TxStmt(txn, s.deleteInviteEventStmt) + _, err = stmt.ExecContext(ctx, streamPos, inviteEventID) + return streamPos, err +} + +// selectInviteEventsInRange returns a map of room ID to invite event for the +// active invites for the target user ID in the supplied range. +func (s *inviteEventsStatements) SelectInviteEventsInRange( + ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, +) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { + stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) + rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) + if err != nil { + return nil, nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") + result := map[string]*gomatrixserverlib.HeaderedEvent{} + retired := map[string]*gomatrixserverlib.HeaderedEvent{} + for rows.Next() { + var ( + roomID string + eventJSON []byte + deleted bool + ) + if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil { + return nil, nil, err + } + + // if we have seen this room before, it has a higher stream position and hence takes priority + // because the query is ORDER BY id DESC so drop them + _, isRetired := retired[roomID] + _, isInvited := result[roomID] + if isRetired || isInvited { + continue + } + + var event *gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal(eventJSON, &event); err != nil { + return nil, nil, err + } + if deleted { + retired[roomID] = event + } else { + result[roomID] = event + } + } + return result, retired, nil +} + +func (s *inviteEventsStatements) SelectMaxInviteID( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + var nullableID sql.NullInt64 + stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt) + err = stmt.QueryRowContext(ctx).Scan(&nullableID) + if nullableID.Valid { + id = nullableID.Int64 + } + return +} diff --git a/syncapi/storage/cosmosdb/memberships_table.go b/syncapi/storage/cosmosdb/memberships_table.go new file mode 100644 index 000000000..9b660509b --- /dev/null +++ b/syncapi/storage/cosmosdb/memberships_table.go @@ -0,0 +1,119 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// The memberships table is designed to track the last time that +// the user was a given state. This allows us to find out the +// most recent time that a user was invited to, joined or left +// a room, either by choice or otherwise. This is important for +// building history visibility. + +const membershipsSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_memberships ( + -- The 'room_id' key for the state event. + room_id TEXT NOT NULL, + -- The state event ID + user_id TEXT NOT NULL, + -- The status of the membership + membership TEXT NOT NULL, + -- The event ID that last changed the membership + event_id TEXT NOT NULL, + -- The stream position of the change + stream_pos BIGINT NOT NULL, + -- The topological position of the change in the room + topological_pos BIGINT NOT NULL, + -- Unique index + UNIQUE (room_id, user_id, membership) +); +` + +const upsertMembershipSQL = "" + + "INSERT INTO syncapi_memberships (room_id, user_id, membership, event_id, stream_pos, topological_pos)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (room_id, user_id, membership)" + + " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" + +const selectMembershipSQL = "" + + "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + + " WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" + + " ORDER BY stream_pos DESC" + + " LIMIT 1" + +type membershipsStatements struct { + db *sql.DB + upsertMembershipStmt *sql.Stmt +} + +func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { + s := &membershipsStatements{ + db: db, + } + _, err := db.Exec(membershipsSchema) + if err != nil { + return nil, err + } + if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *membershipsStatements) UpsertMembership( + ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, + streamPos, topologicalPos types.StreamPosition, +) error { + membership, err := event.Membership() + if err != nil { + return fmt.Errorf("event.Membership: %w", err) + } + _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( + ctx, + event.RoomID(), + *event.StateKey(), + membership, + event.EventID(), + streamPos, + topologicalPos, + ) + return err +} + +func (s *membershipsStatements) SelectMembership( + ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string, +) (eventID string, streamPos, topologyPos types.StreamPosition, err error) { + params := []interface{}{roomID, userID} + for _, membership := range memberships { + params = append(params, membership) + } + orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) + stmt, err := s.db.Prepare(orig) + if err != nil { + return "", 0, 0, err + } + err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) + return +} diff --git a/syncapi/storage/cosmosdb/output_room_events_table.go b/syncapi/storage/cosmosdb/output_room_events_table.go new file mode 100644 index 000000000..7ce485d8f --- /dev/null +++ b/syncapi/storage/cosmosdb/output_room_events_table.go @@ -0,0 +1,477 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "sort" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +const outputRoomEventsSchema = ` +-- Stores output room events received from the roomserver. +CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + event_id TEXT NOT NULL UNIQUE, + room_id TEXT NOT NULL, + headered_event_json TEXT NOT NULL, + type TEXT NOT NULL, + sender TEXT NOT NULL, + contains_url BOOL NOT NULL, + add_state_ids TEXT, -- JSON encoded string array + remove_state_ids TEXT, -- JSON encoded string array + session_id BIGINT, + transaction_id TEXT, + exclude_from_sync BOOL NOT NULL DEFAULT FALSE +); +` + +const insertEventSQL = "" + + "INSERT INTO syncapi_output_room_events (" + + "id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + + "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)" + +const selectEventsSQL = "" + + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" + +const selectRecentEventsSQL = "" + + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + " WHERE room_id = $1 AND id > $2 AND id <= $3" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters + +const selectRecentEventsForSyncSQL = "" + + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters + +const selectEarlyEventsSQL = "" + + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + " WHERE room_id = $1 AND id > $2 AND id <= $3" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters + +const selectMaxEventIDSQL = "" + + "SELECT MAX(id) FROM syncapi_output_room_events" + +const updateEventJSONSQL = "" + + "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" + +const selectStateInRangeSQL = "" + + "SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + + " FROM syncapi_output_room_events" + + " WHERE (id > $1 AND id <= $2)" + + " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters + +const deleteEventsForRoomSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + +type outputRoomEventsStatements struct { + db *sql.DB + streamIDStatements *streamIDStatements + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectMaxEventIDStmt *sql.Stmt + updateEventJSONStmt *sql.Stmt + deleteEventsForRoomStmt *sql.Stmt +} + +func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { + s := &outputRoomEventsStatements{ + db: db, + streamIDStatements: streamID, + } + _, err := db.Exec(outputRoomEventsSchema) + if err != nil { + return nil, err + } + if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { + return nil, err + } + if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil { + return nil, err + } + if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil { + return nil, err + } + if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil { + return nil, err + } + if s.deleteEventsForRoomStmt, err = db.Prepare(deleteEventsForRoomSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { + headeredJSON, err := json.Marshal(event) + if err != nil { + return err + } + _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + return err +} + +// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. +// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the +// two positions, only the most recent state is returned. +func (s *outputRoomEventsStatements) SelectStateInRange( + ctx context.Context, txn *sql.Tx, r types.Range, + stateFilter *gomatrixserverlib.StateFilter, +) (map[string]map[string]bool, map[string]types.StreamEvent, error) { + stmt, params, err := prepareWithFilters( + s.db, txn, selectStateInRangeSQL, + []interface{}{ + r.Low(), r.High(), + }, + stateFilter.Senders, stateFilter.NotSenders, + stateFilter.Types, stateFilter.NotTypes, + nil, stateFilter.Limit, FilterOrderAsc, + ) + if err != nil { + return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) + } + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, nil, err + } + defer rows.Close() // nolint: errcheck + // Fetch all the state change events for all rooms between the two positions then loop each event and: + // - Keep a cache of the event by ID (99% of state change events are for the event itself) + // - For each room ID, build up an array of event IDs which represents cumulative adds/removes + // For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID + // if they aren't in the event ID cache. We don't handle state deletion yet. + eventIDToEvent := make(map[string]types.StreamEvent) + + // RoomID => A set (map[string]bool) of state event IDs which are between the two positions + stateNeeded := make(map[string]map[string]bool) + + for rows.Next() { + var ( + streamPos types.StreamPosition + eventBytes []byte + excludeFromSync bool + addIDsJSON string + delIDsJSON string + ) + if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDsJSON, &delIDsJSON); err != nil { + return nil, nil, err + } + + addIDs, delIDs, err := unmarshalStateIDs(addIDsJSON, delIDsJSON) + if err != nil { + return nil, nil, err + } + + // Sanity check for deleted state and whine if we see it. We don't need to do anything + // since it'll just mark the event as not being needed. + if len(addIDs) < len(delIDs) { + log.WithFields(log.Fields{ + "since": r.From, + "current": r.To, + "adds": addIDsJSON, + "dels": delIDsJSON, + }).Warn("StateBetween: ignoring deleted state") + } + + // TODO: Handle redacted events + var ev gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal(eventBytes, &ev); err != nil { + return nil, nil, err + } + needSet := stateNeeded[ev.RoomID()] + if needSet == nil { // make set if required + needSet = make(map[string]bool) + } + for _, id := range delIDs { + needSet[id] = false + } + for _, id := range addIDs { + needSet[id] = true + } + stateNeeded[ev.RoomID()] = needSet + + eventIDToEvent[ev.EventID()] = types.StreamEvent{ + HeaderedEvent: &ev, + StreamPosition: streamPos, + ExcludeFromSync: excludeFromSync, + } + } + + return stateNeeded, eventIDToEvent, nil +} + +// MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied, +// then this function should only ever be used at startup, as it will race with inserting events if it is +// done afterwards. If there are no inserted events, 0 is returned. +func (s *outputRoomEventsStatements) SelectMaxEventID( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + var nullableID sql.NullInt64 + stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) + err = stmt.QueryRowContext(ctx).Scan(&nullableID) + if nullableID.Valid { + id = nullableID.Int64 + } + return +} + +// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position +// of the inserted event. +func (s *outputRoomEventsStatements) InsertEvent( + ctx context.Context, txn *sql.Tx, + event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, + transactionID *api.TransactionID, excludeFromSync bool, +) (types.StreamPosition, error) { + var txnID *string + var sessionID *int64 + if transactionID != nil { + sessionID = &transactionID.SessionID + txnID = &transactionID.TransactionID + } + + // Parse content as JSON and search for an "url" key + containsURL := false + var content map[string]interface{} + if json.Unmarshal(event.Content(), &content) != nil { + // Set containsURL to true if url is present + _, containsURL = content["url"] + } + + var headeredJSON []byte + headeredJSON, err := json.Marshal(event) + if err != nil { + return 0, err + } + + var addStateJSON, removeStateJSON []byte + if len(addState) > 0 { + addStateJSON, err = json.Marshal(addState) + } + if err != nil { + return 0, fmt.Errorf("json.Marshal(addState): %w", err) + } + if len(removeState) > 0 { + removeStateJSON, err = json.Marshal(removeState) + } + if err != nil { + return 0, fmt.Errorf("json.Marshal(removeState): %w", err) + } + + streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn) + if err != nil { + return 0, err + } + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + _, err = insertStmt.ExecContext( + ctx, + streamPos, + event.RoomID(), + event.EventID(), + headeredJSON, + event.Type(), + event.Sender(), + containsURL, + string(addStateJSON), + string(removeStateJSON), + sessionID, + txnID, + excludeFromSync, + excludeFromSync, + ) + return streamPos, err +} + +func (s *outputRoomEventsStatements) SelectRecentEvents( + ctx context.Context, txn *sql.Tx, + roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + chronologicalOrder bool, onlySyncEvents bool, +) ([]types.StreamEvent, bool, error) { + var query string + if onlySyncEvents { + query = selectRecentEventsForSyncSQL + } else { + query = selectRecentEventsSQL + } + + stmt, params, err := prepareWithFilters( + s.db, txn, query, + []interface{}{ + roomID, r.Low(), r.High(), + }, + eventFilter.Senders, eventFilter.NotSenders, + eventFilter.Types, eventFilter.NotTypes, + nil, eventFilter.Limit+1, FilterOrderDesc, + ) + if err != nil { + return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) + } + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, false, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") + events, err := rowsToStreamEvents(rows) + if err != nil { + return nil, false, err + } + if chronologicalOrder { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(events, func(i int, j int) bool { + return events[i].StreamPosition < events[j].StreamPosition + }) + } + // we queried for 1 more than the limit, so if we returned one more mark limited=true + limited := false + if len(events) > eventFilter.Limit { + limited = true + // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. + if chronologicalOrder { + events = events[1:] + } else { + events = events[:len(events)-1] + } + } + return events, limited, nil +} + +func (s *outputRoomEventsStatements) SelectEarlyEvents( + ctx context.Context, txn *sql.Tx, + roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, +) ([]types.StreamEvent, error) { + stmt, params, err := prepareWithFilters( + s.db, txn, selectEarlyEventsSQL, + []interface{}{ + roomID, r.Low(), r.High(), + }, + eventFilter.Senders, eventFilter.NotSenders, + eventFilter.Types, eventFilter.NotTypes, + nil, eventFilter.Limit, FilterOrderAsc, + ) + if err != nil { + return nil, fmt.Errorf("s.prepareWithFilters: %w", err) + } + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed") + events, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + // The events need to be returned from oldest to latest, which isn't + // necessarily the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(events, func(i int, j int) bool { + return events[i].StreamPosition < events[j].StreamPosition + }) + return events, nil +} + +// selectEvents returns the events for the given event IDs. If an event is +// missing from the database, it will be omitted. +func (s *outputRoomEventsStatements) SelectEvents( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) ([]types.StreamEvent, error) { + var returnEvents []types.StreamEvent + stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) + for _, eventID := range eventIDs { + rows, err := stmt.QueryContext(ctx, eventID) + if err != nil { + return nil, err + } + if streamEvents, err := rowsToStreamEvents(rows); err == nil { + returnEvents = append(returnEvents, streamEvents...) + } + internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") + } + return returnEvents, nil +} + +func (s *outputRoomEventsStatements) DeleteEventsForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID) + return err +} + +func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { + var result []types.StreamEvent + for rows.Next() { + var ( + eventID string + streamPos types.StreamPosition + eventBytes []byte + excludeFromSync bool + sessionID *int64 + txnID *string + transactionID *api.TransactionID + ) + if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { + return nil, err + } + // TODO: Handle redacted events + var ev gomatrixserverlib.HeaderedEvent + if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + + if sessionID != nil && txnID != nil { + transactionID = &api.TransactionID{ + SessionID: *sessionID, + TransactionID: *txnID, + } + } + + result = append(result, types.StreamEvent{ + HeaderedEvent: &ev, + StreamPosition: streamPos, + TransactionID: transactionID, + ExcludeFromSync: excludeFromSync, + }) + } + return result, nil +} + +func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs []string, err error) { + if len(addIDsJSON) > 0 { + if err = json.Unmarshal([]byte(addIDsJSON), &addIDs); err != nil { + return + } + } + if len(delIDsJSON) > 0 { + if err = json.Unmarshal([]byte(delIDsJSON), &delIDs); err != nil { + return + } + } + return +} diff --git a/syncapi/storage/cosmosdb/output_room_events_topology_table.go b/syncapi/storage/cosmosdb/output_room_events_topology_table.go new file mode 100644 index 000000000..1a52b76b8 --- /dev/null +++ b/syncapi/storage/cosmosdb/output_room_events_topology_table.go @@ -0,0 +1,179 @@ +// Copyright 2018 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const outputRoomEventsTopologySchema = ` +-- Stores output room events received from the roomserver. +CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology ( + event_id TEXT PRIMARY KEY, + topological_position BIGINT NOT NULL, + stream_position BIGINT NOT NULL, + room_id TEXT NOT NULL, + + UNIQUE(topological_position, room_id, stream_position) +); +-- The topological order will be used in events selection and ordering +-- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id); +` + +const insertEventInTopologySQL = "" + + "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT DO NOTHING" + +const selectEventIDsInRangeASCSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND (" + + "(topological_position > $2 AND topological_position < $3) OR" + + "(topological_position = $4 AND stream_position <= $5)" + + ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6" + +const selectEventIDsInRangeDESCSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND (" + + "(topological_position > $2 AND topological_position < $3) OR" + + "(topological_position = $4 AND stream_position <= $5)" + + ") ORDER BY topological_position DESC, stream_position DESC LIMIT $6" + +const selectPositionInTopologySQL = "" + + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + + " WHERE event_id = $1" + +const selectMaxPositionInTopologySQL = "" + + "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 ORDER BY stream_position DESC" + +const deleteTopologyForRoomSQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + +type outputRoomEventsTopologyStatements struct { + db *sql.DB + insertEventInTopologyStmt *sql.Stmt + selectEventIDsInRangeASCStmt *sql.Stmt + selectEventIDsInRangeDESCStmt *sql.Stmt + selectPositionInTopologyStmt *sql.Stmt + selectMaxPositionInTopologyStmt *sql.Stmt + deleteTopologyForRoomStmt *sql.Stmt +} + +func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { + s := &outputRoomEventsTopologyStatements{ + db: db, + } + _, err := db.Exec(outputRoomEventsTopologySchema) + if err != nil { + return nil, err + } + if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { + return nil, err + } + if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { + return nil, err + } + if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { + return nil, err + } + if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { + return nil, err + } + if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { + return nil, err + } + if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil { + return nil, err + } + return s, nil +} + +// insertEventInTopology inserts the given event in the room's topology, based +// on the event's depth. +func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( + ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, +) (types.StreamPosition, error) { + _, err := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt).ExecContext( + ctx, event.EventID(), event.Depth(), event.RoomID(), pos, + ) + return types.StreamPosition(event.Depth()), err +} + +func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( + ctx context.Context, txn *sql.Tx, roomID string, + minDepth, maxDepth, maxStreamPos types.StreamPosition, + limit int, chronologicalOrder bool, +) (eventIDs []string, err error) { + // Decide on the selection's order according to whether chronological order + // is requested or not. + var stmt *sql.Stmt + if chronologicalOrder { + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt) + } else { + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt) + } + + // Query the event IDs. + rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit) + if err == sql.ErrNoRows { + // If no event matched the request, return an empty slice. + return []string{}, nil + } else if err != nil { + return + } + + // Return the IDs. + var eventID string + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + + return +} + +// selectPositionInTopology returns the position of a given event in the +// topology of the room it belongs to. +func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( + ctx context.Context, txn *sql.Tx, eventID string, +) (pos types.StreamPosition, spos types.StreamPosition, err error) { + stmt := sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt) + err = stmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) + return +} + +func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( + ctx context.Context, txn *sql.Tx, roomID string, +) (pos types.StreamPosition, spos types.StreamPosition, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) + return +} + +func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/cosmosdb/peeks_table.go b/syncapi/storage/cosmosdb/peeks_table.go new file mode 100644 index 000000000..e6d9b8a3c --- /dev/null +++ b/syncapi/storage/cosmosdb/peeks_table.go @@ -0,0 +1,206 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +const peeksSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_peeks ( + id INTEGER, + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + deleted BOOL NOT NULL DEFAULT false, + -- When the peek was created in UNIX epoch ms. + creation_ts INTEGER NOT NULL, + UNIQUE(room_id, user_id, device_id) +); + +CREATE INDEX IF NOT EXISTS syncapi_peeks_room_id_idx ON syncapi_peeks(room_id); +CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_id, device_id); +` + +const insertPeekSQL = "" + + "INSERT OR REPLACE INTO syncapi_peeks" + + " (id, room_id, user_id, device_id, creation_ts, deleted)" + + " VALUES ($1, $2, $3, $4, $5, false)" + +const deletePeekSQL = "" + + "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4" + +const deletePeeksSQL = "" + + "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3" + +// we care about all the peeks which were created in this range, deleted in this range, +// or were created before this range but haven't been deleted yet. +// BEWARE: sqlite chokes on out of order substitution strings. +const selectPeeksInRangeSQL = "" + + "SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))" + +const selectPeekingDevicesSQL = "" + + "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false" + +const selectMaxPeekIDSQL = "" + + "SELECT MAX(id) FROM syncapi_peeks" + +type peekStatements struct { + db *sql.DB + streamIDStatements *streamIDStatements + insertPeekStmt *sql.Stmt + deletePeekStmt *sql.Stmt + deletePeeksStmt *sql.Stmt + selectPeeksInRangeStmt *sql.Stmt + selectPeekingDevicesStmt *sql.Stmt + selectMaxPeekIDStmt *sql.Stmt +} + +func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) { + _, err := db.Exec(peeksSchema) + if err != nil { + return nil, err + } + s := &peekStatements{ + db: db, + streamIDStatements: streamID, + } + if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { + return nil, err + } + if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { + return nil, err + } + if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { + return nil, err + } + if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { + return nil, err + } + if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { + return nil, err + } + if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *peekStatements) InsertPeek( + ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string, +) (streamPos types.StreamPosition, err error) { + streamPos, err = s.streamIDStatements.nextPDUID(ctx, txn) + if err != nil { + return + } + nowMilli := time.Now().UnixNano() / int64(time.Millisecond) + _, err = sqlutil.TxStmt(txn, s.insertPeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID, nowMilli) + return +} + +func (s *peekStatements) DeletePeek( + ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string, +) (streamPos types.StreamPosition, err error) { + streamPos, err = s.streamIDStatements.nextPDUID(ctx, txn) + if err != nil { + return + } + _, err = sqlutil.TxStmt(txn, s.deletePeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID) + return +} + +func (s *peekStatements) DeletePeeks( + ctx context.Context, txn *sql.Tx, roomID, userID string, +) (types.StreamPosition, error) { + streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn) + if err != nil { + return 0, err + } + result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID) + if err != nil { + return 0, err + } + numAffected, err := result.RowsAffected() + if err != nil { + return 0, err + } + if numAffected == 0 { + return 0, sql.ErrNoRows + } + return streamPos, nil +} + +func (s *peekStatements) SelectPeeksInRange( + ctx context.Context, txn *sql.Tx, userID, deviceID string, r types.Range, +) (peeks []types.Peek, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High()) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeksInRange: rows.close() failed") + + for rows.Next() { + peek := types.Peek{} + var id types.StreamPosition + if err = rows.Scan(&id, &peek.RoomID, &peek.Deleted); err != nil { + return + } + peek.New = (id > r.Low() && id <= r.High()) && !peek.Deleted + peeks = append(peeks, peek) + } + + return peeks, rows.Err() +} + +func (s *peekStatements) SelectPeekingDevices( + ctx context.Context, +) (peekingDevices map[string][]types.PeekingDevice, err error) { + rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPeekingDevices: rows.close() failed") + + result := make(map[string][]types.PeekingDevice) + for rows.Next() { + var roomID, userID, deviceID string + if err := rows.Scan(&roomID, &userID, &deviceID); err != nil { + return nil, err + } + devices := result[roomID] + devices = append(devices, types.PeekingDevice{UserID: userID, DeviceID: deviceID}) + result[roomID] = devices + } + return result, nil +} + +func (s *peekStatements) SelectMaxPeekID( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + var nullableID sql.NullInt64 + stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt) + err = stmt.QueryRowContext(ctx).Scan(&nullableID) + if nullableID.Valid { + id = nullableID.Int64 + } + return +} diff --git a/syncapi/storage/cosmosdb/receipt_table.go b/syncapi/storage/cosmosdb/receipt_table.go new file mode 100644 index 000000000..de3983c5b --- /dev/null +++ b/syncapi/storage/cosmosdb/receipt_table.go @@ -0,0 +1,141 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const receiptsSchema = ` +-- Stores data about receipts +CREATE TABLE IF NOT EXISTS syncapi_receipts ( + -- The ID + id BIGINT, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + receipt_ts BIGINT NOT NULL, + CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id) +); +CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id_idx ON syncapi_receipts(room_id); +` + +const upsertReceipt = "" + + "INSERT INTO syncapi_receipts" + + " (id, room_id, receipt_type, user_id, event_id, receipt_ts)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (room_id, receipt_type, user_id)" + + " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9" + +const selectRoomReceipts = "" + + "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" + + " FROM syncapi_receipts" + + " WHERE id > $1 and room_id in ($2)" + +const selectMaxReceiptIDSQL = "" + + "SELECT MAX(id) FROM syncapi_receipts" + +type receiptStatements struct { + db *sql.DB + streamIDStatements *streamIDStatements + upsertReceipt *sql.Stmt + selectRoomReceipts *sql.Stmt + selectMaxReceiptID *sql.Stmt +} + +func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { + _, err := db.Exec(receiptsSchema) + if err != nil { + return nil, err + } + r := &receiptStatements{ + db: db, + streamIDStatements: streamID, + } + if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { + return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) + } + if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { + return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) + } + if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { + return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) + } + return r, nil +} + +// UpsertReceipt creates new user receipts +func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { + pos, err = r.streamIDStatements.nextReceiptID(ctx, txn) + if err != nil { + return + } + stmt := sqlutil.TxStmt(txn, r.upsertReceipt) + _, err = stmt.ExecContext(ctx, pos, roomId, receiptType, userId, eventId, timestamp, pos, eventId, timestamp) + return +} + +// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) { + selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) + lastPos := streamPos + params := make([]interface{}, len(roomIDs)+1) + params[0] = streamPos + for k, v := range roomIDs { + params[k+1] = v + } + rows, err := r.db.QueryContext(ctx, selectSQL, params...) + if err != nil { + return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") + var res []api.OutputReceiptEvent + for rows.Next() { + r := api.OutputReceiptEvent{} + var id types.StreamPosition + err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) + if err != nil { + return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) + } + res = append(res, r) + if id > lastPos { + lastPos = id + } + } + return lastPos, res, rows.Err() +} + +func (s *receiptStatements) SelectMaxReceiptID( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + var nullableID sql.NullInt64 + stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID) + err = stmt.QueryRowContext(ctx).Scan(&nullableID) + if nullableID.Valid { + id = nullableID.Int64 + } + return +} diff --git a/syncapi/storage/cosmosdb/send_to_device_table.go b/syncapi/storage/cosmosdb/send_to_device_table.go new file mode 100644 index 000000000..3a985c8d4 --- /dev/null +++ b/syncapi/storage/cosmosdb/send_to_device_table.go @@ -0,0 +1,160 @@ +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/sirupsen/logrus" +) + +const sendToDeviceSchema = ` +-- Stores send-to-device messages. +CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( + -- The ID that uniquely identifies this message. + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The user ID to send the message to. + user_id TEXT NOT NULL, + -- The device ID to send the message to. + device_id TEXT NOT NULL, + -- The event content JSON. + content TEXT NOT NULL +); +` + +const insertSendToDeviceMessageSQL = ` + INSERT INTO syncapi_send_to_device (user_id, device_id, content) + VALUES ($1, $2, $3) +` + +const selectSendToDeviceMessagesSQL = ` + SELECT id, user_id, device_id, content + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4 + ORDER BY id DESC +` + +const deleteSendToDeviceMessagesSQL = ` + DELETE FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 AND id < $3 +` + +const selectMaxSendToDeviceIDSQL = "" + + "SELECT MAX(id) FROM syncapi_send_to_device" + +type sendToDeviceStatements struct { + db *sql.DB + insertSendToDeviceMessageStmt *sql.Stmt + selectSendToDeviceMessagesStmt *sql.Stmt + deleteSendToDeviceMessagesStmt *sql.Stmt + selectMaxSendToDeviceIDStmt *sql.Stmt +} + +func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { + s := &sendToDeviceStatements{ + db: db, + } + _, err := db.Exec(sendToDeviceSchema) + if err != nil { + return nil, err + } + if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { + return nil, err + } + if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *sendToDeviceStatements) InsertSendToDeviceMessage( + ctx context.Context, txn *sql.Tx, userID, deviceID, content string, +) (pos types.StreamPosition, err error) { + var result sql.Result + result, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + if p, err := result.LastInsertId(); err != nil { + return 0, err + } else { + pos = types.StreamPosition(p) + } + return +} + +func (s *sendToDeviceStatements) SelectSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition, +) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") + + for rows.Next() { + var id types.StreamPosition + var userID, deviceID, content string + if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil { + logrus.WithError(err).Errorf("Failed to retrieve send-to-device message") + return + } + if id > lastPos { + lastPos = id + } + event := types.SendToDeviceEvent{ + ID: id, + UserID: userID, + DeviceID: deviceID, + } + if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { + logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message") + continue + } + events = append(events, event) + } + if lastPos == 0 { + lastPos = to + } + return lastPos, events, rows.Err() +} + +func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos) + return +} + +func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + var nullableID sql.NullInt64 + stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt) + err = stmt.QueryRowContext(ctx).Scan(&nullableID) + if nullableID.Valid { + id = nullableID.Int64 + } + return +} diff --git a/syncapi/storage/cosmosdb/stream_id_table.go b/syncapi/storage/cosmosdb/stream_id_table.go new file mode 100644 index 000000000..a599a9e65 --- /dev/null +++ b/syncapi/storage/cosmosdb/stream_id_table.go @@ -0,0 +1,94 @@ +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/types" +) + +const streamIDTableSchema = ` +-- Global stream ID counter, used by other tables. +CREATE TABLE IF NOT EXISTS syncapi_stream_id ( + stream_name TEXT NOT NULL PRIMARY KEY, + stream_id INT DEFAULT 0, + + UNIQUE(stream_name) +); +INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0) + ON CONFLICT DO NOTHING; +INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0) + ON CONFLICT DO NOTHING; +INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("accountdata", 0) + ON CONFLICT DO NOTHING; +INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0) + ON CONFLICT DO NOTHING; +` + +const increaseStreamIDStmt = "" + + "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" + +const selectStreamIDStmt = "" + + "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1" + +type streamIDStatements struct { + db *sql.DB + increaseStreamIDStmt *sql.Stmt + selectStreamIDStmt *sql.Stmt +} + +func (s *streamIDStatements) prepare(db *sql.DB) (err error) { + s.db = db + _, err = db.Exec(streamIDTableSchema) + if err != nil { + return + } + if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil { + return + } + if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil { + return + } + return +} + +func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { + increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) + if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { + return + } + err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos) + return +} + +func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { + increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) + if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil { + return + } + err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos) + return +} + +func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { + increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) + if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil { + return + } + err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos) + return +} + +func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { + increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) + if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil { + return + } + err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) + return +} diff --git a/syncapi/storage/cosmosdb/syncserver.go b/syncapi/storage/cosmosdb/syncserver.go new file mode 100644 index 000000000..7bf1a1387 --- /dev/null +++ b/syncapi/storage/cosmosdb/syncserver.go @@ -0,0 +1,128 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "database/sql" + + // Import the sqlite3 package + _ "github.com/mattn/go-sqlite3" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/shared" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" +) + +// SyncServerDatasource represents a sync server datasource which manages +// both the database for PDUs and caches for EDUs. +type SyncServerDatasource struct { + shared.Database + db *sql.DB + writer sqlutil.Writer + sqlutil.PartitionOffsetStatements + streamID streamIDStatements +} + +// NewDatabase creates a new sync server database +// nolint: gocyclo +func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { + var d SyncServerDatasource + var err error + if d.db, err = sqlutil.Open(dbProperties); err != nil { + return nil, err + } + d.writer = sqlutil.NewExclusiveWriter() + if err = d.prepare(dbProperties); err != nil { + return nil, err + } + return &d, nil +} + +func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) { + if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { + return err + } + if err = d.streamID.prepare(d.db); err != nil { + return err + } + accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID) + if err != nil { + return err + } + events, err := NewSqliteEventsTable(d.db, &d.streamID) + if err != nil { + return err + } + roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID) + if err != nil { + return err + } + invites, err := NewSqliteInvitesTable(d.db, &d.streamID) + if err != nil { + return err + } + peeks, err := NewSqlitePeeksTable(d.db, &d.streamID) + if err != nil { + return err + } + topology, err := NewSqliteTopologyTable(d.db) + if err != nil { + return err + } + bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db) + if err != nil { + return err + } + sendToDevice, err := NewSqliteSendToDeviceTable(d.db) + if err != nil { + return err + } + filter, err := NewSqliteFilterTable(d.db) + if err != nil { + return err + } + receipts, err := NewSqliteReceiptsTable(d.db, &d.streamID) + if err != nil { + return err + } + memberships, err := NewSqliteMembershipsTable(d.db) + if err != nil { + return err + } + m := sqlutil.NewMigrations() + deltas.LoadFixSequences(m) + deltas.LoadRemoveSendToDeviceSentColumn(m) + if err = m.RunDeltas(d.db, dbProperties); err != nil { + return err + } + d.Database = shared.Database{ + DB: d.db, + Writer: d.writer, + Invites: invites, + Peeks: peeks, + AccountData: accountData, + OutputEvents: events, + BackwardExtremities: bwExtrem, + CurrentRoomState: roomState, + Topology: topology, + Filter: filter, + SendToDevice: sendToDevice, + Receipts: receipts, + Memberships: memberships, + } + return nil +} diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go index 15386c338..d37b632b0 100644 --- a/syncapi/storage/storage.go +++ b/syncapi/storage/storage.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/cosmosdb" "github.com/matrix-org/dendrite/syncapi/storage/postgres" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" ) @@ -27,6 +28,8 @@ import ( // NewSyncServerDatasource opens a database connection. func NewSyncServerDatasource(dbProperties *config.DatabaseOptions) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return cosmosdb.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsPostgres(): diff --git a/syncapi/storage/storage_wasm.go b/syncapi/storage/storage_wasm.go index f7fef962b..6f59e3c71 100644 --- a/syncapi/storage/storage_wasm.go +++ b/syncapi/storage/storage_wasm.go @@ -24,6 +24,8 @@ import ( // NewPublicRoomsServerDatabase opens a database connection. func NewSyncServerDatasource(dbProperties *config.DatabaseOptions) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return nil, fmt.Errorf("can't use CosmosDB implementation") case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsPostgres(): diff --git a/userapi/storage/accounts/cosmosdb/account_data_table.go b/userapi/storage/accounts/cosmosdb/account_data_table.go new file mode 100644 index 000000000..916d28735 --- /dev/null +++ b/userapi/storage/accounts/cosmosdb/account_data_table.go @@ -0,0 +1,134 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const accountDataSchema = ` +-- Stores data about accounts data. +CREATE TABLE IF NOT EXISTS account_data ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL, + -- The room ID for this data (empty string if not specific to a room) + room_id TEXT, + -- The account data type + type TEXT NOT NULL, + -- The account data content + content TEXT NOT NULL, + + PRIMARY KEY(localpart, room_id, type) +); +` + +const insertAccountDataSQL = ` + INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 +` + +const selectAccountDataSQL = "" + + "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + +const selectAccountDataByTypeSQL = "" + + "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + +type accountDataStatements struct { + db *sql.DB + insertAccountDataStmt *sql.Stmt + selectAccountDataStmt *sql.Stmt + selectAccountDataByTypeStmt *sql.Stmt +} + +func (s *accountDataStatements) prepare(db *sql.DB) (err error) { + s.db = db + _, err = db.Exec(accountDataSchema) + if err != nil { + return + } + if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { + return + } + if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { + return + } + if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { + return + } + return +} + +func (s *accountDataStatements) insertAccountData( + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, +) error { + _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + return err +} + +func (s *accountDataStatements) selectAccountData( + ctx context.Context, localpart string, +) ( + /* global */ map[string]json.RawMessage, + /* rooms */ map[string]map[string]json.RawMessage, + error, +) { + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + if err != nil { + return nil, nil, err + } + + global := map[string]json.RawMessage{} + rooms := map[string]map[string]json.RawMessage{} + + for rows.Next() { + var roomID string + var dataType string + var content []byte + + if err = rows.Scan(&roomID, &dataType, &content); err != nil { + return nil, nil, err + } + + if roomID != "" { + if _, ok := rooms[roomID]; !ok { + rooms[roomID] = map[string]json.RawMessage{} + } + rooms[roomID][dataType] = content + } else { + global[dataType] = content + } + } + + return global, rooms, nil +} + +func (s *accountDataStatements) selectAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data json.RawMessage, err error) { + var bytes []byte + stmt := s.selectAccountDataByTypeStmt + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return + } + data = json.RawMessage(bytes) + return +} diff --git a/userapi/storage/accounts/cosmosdb/accounts_table.go b/userapi/storage/accounts/cosmosdb/accounts_table.go new file mode 100644 index 000000000..f871c1830 --- /dev/null +++ b/userapi/storage/accounts/cosmosdb/accounts_table.go @@ -0,0 +1,187 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + + log "github.com/sirupsen/logrus" +) + +const accountsSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS account_accounts ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL PRIMARY KEY, + -- When this account was first created, as a unix timestamp (ms resolution). + created_ts BIGINT NOT NULL, + -- The password hash for this account. Can be NULL if this is a passwordless account. + password_hash TEXT, + -- Identifies which application service this account belongs to, if any. + appservice_id TEXT, + -- If the account is currently active + is_deactivated BOOLEAN DEFAULT 0 + -- TODO: + -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? +); +` + +const insertAccountSQL = "" + + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + +const updatePasswordSQL = "" + + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + +const deactivateAccountSQL = "" + + "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" + +const selectAccountByLocalpartSQL = "" + + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + +const selectPasswordHashSQL = "" + + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" + +const selectNewNumericLocalpartSQL = "" + + "SELECT COUNT(localpart) FROM account_accounts" + +type accountsStatements struct { + db *sql.DB + insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt + deactivateAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *accountsStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(accountsSchema) + return err +} + +func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db + if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { + return + } + if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { + return + } + if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil { + return + } + if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { + return + } + if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { + return + } + if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, +// this account will be passwordless. Returns an error if this account already exists. Returns the account +// on success. +func (s *accountsStatements) insertAccount( + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, +) (*api.Account, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + stmt := s.insertAccountStmt + + var err error + if appserviceID == "" { + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) + } else { + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + } + if err != nil { + return nil, err + } + + return &api.Account{ + Localpart: localpart, + UserID: userutil.MakeUserID(localpart, s.serverName), + ServerName: s.serverName, + AppServiceID: appserviceID, + }, nil +} + +func (s *accountsStatements) updatePassword( + ctx context.Context, localpart, passwordHash string, +) (err error) { + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + return +} + +func (s *accountsStatements) deactivateAccount( + ctx context.Context, localpart string, +) (err error) { + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + return +} + +func (s *accountsStatements) selectPasswordHash( + ctx context.Context, localpart string, +) (hash string, err error) { + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + return +} + +func (s *accountsStatements) selectAccountByLocalpart( + ctx context.Context, localpart string, +) (*api.Account, error) { + var appserviceIDPtr sql.NullString + var acc api.Account + + stmt := s.selectAccountByLocalpartStmt + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve user from the db") + } + return nil, err + } + if appserviceIDPtr.Valid { + acc.AppServiceID = appserviceIDPtr.String + } + + acc.UserID = userutil.MakeUserID(localpart, s.serverName) + acc.ServerName = s.serverName + + return &acc, nil +} + +func (s *accountsStatements) selectNewNumericLocalpart( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + stmt := s.selectNewNumericLocalpartStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + err = stmt.QueryRowContext(ctx).Scan(&id) + return +} diff --git a/userapi/storage/accounts/cosmosdb/constraint_wasm.go b/userapi/storage/accounts/cosmosdb/constraint_wasm.go new file mode 100644 index 000000000..9f15d2012 --- /dev/null +++ b/userapi/storage/accounts/cosmosdb/constraint_wasm.go @@ -0,0 +1,21 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build wasm + +package cosmosdb + +func isConstraintError(err error) bool { + return false +} diff --git a/userapi/storage/accounts/cosmosdb/openid_table.go b/userapi/storage/accounts/cosmosdb/openid_table.go new file mode 100644 index 000000000..c5bab0308 --- /dev/null +++ b/userapi/storage/accounts/cosmosdb/openid_table.go @@ -0,0 +1,86 @@ +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +const openIDTokenSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS open_id_tokens ( + -- The value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- When the token expires, as a unix timestamp (ms resolution). + token_expires_at_ms BIGINT NOT NULL +); +` + +const insertTokenSQL = "" + + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + +const selectTokenSQL = "" + + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + +type tokenStatements struct { + db *sql.DB + insertTokenStmt *sql.Stmt + selectTokenStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db + _, err = db.Exec(openIDTokenSchema) + if err != nil { + return err + } + if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { + return + } + if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertToken inserts a new OpenID Connect token to the DB. +// Returns new token, otherwise returns error if the token already exists. +func (s *tokenStatements) insertToken( + ctx context.Context, + txn *sql.Tx, + token, localpart string, + expiresAtMS int64, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) + _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + return +} + +// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB +// Returns the existing token's attributes, or err if no token is found +func (s *tokenStatements) selectOpenIDTokenAtrributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + var openIDTokenAttrs api.OpenIDTokenAttributes + err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( + &openIDTokenAttrs.UserID, + &openIDTokenAttrs.ExpiresAtMS, + ) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve token from the db") + } + return nil, err + } + + return &openIDTokenAttrs, nil +} diff --git a/userapi/storage/accounts/cosmosdb/profile_table.go b/userapi/storage/accounts/cosmosdb/profile_table.go new file mode 100644 index 000000000..73ffec031 --- /dev/null +++ b/userapi/storage/accounts/cosmosdb/profile_table.go @@ -0,0 +1,143 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const profilesSchema = ` +-- Stores data about accounts profiles. +CREATE TABLE IF NOT EXISTS account_profiles ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL PRIMARY KEY, + -- The display name for this account + display_name TEXT, + -- The URL of the avatar for this account + avatar_url TEXT +); +` + +const insertProfileSQL = "" + + "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + +const selectProfileByLocalpartSQL = "" + + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + +const setAvatarURLSQL = "" + + "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + +const setDisplayNameSQL = "" + + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + +const selectProfilesBySearchSQL = "" + + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + +type profilesStatements struct { + db *sql.DB + insertProfileStmt *sql.Stmt + selectProfileByLocalpartStmt *sql.Stmt + setAvatarURLStmt *sql.Stmt + setDisplayNameStmt *sql.Stmt + selectProfilesBySearchStmt *sql.Stmt +} + +func (s *profilesStatements) prepare(db *sql.DB) (err error) { + s.db = db + _, err = db.Exec(profilesSchema) + if err != nil { + return + } + if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { + return + } + if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { + return + } + if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { + return + } + if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { + return + } + if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { + return + } + return +} + +func (s *profilesStatements) insertProfile( + ctx context.Context, txn *sql.Tx, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + return err +} + +func (s *profilesStatements) selectProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + var profile authtypes.Profile + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( + &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + ) + if err != nil { + return nil, err + } + return &profile, nil +} + +func (s *profilesStatements) setAvatarURL( + ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) + _, err = stmt.ExecContext(ctx, avatarURL, localpart) + return +} + +func (s *profilesStatements) setDisplayName( + ctx context.Context, txn *sql.Tx, localpart string, displayName string, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) + _, err = stmt.ExecContext(ctx, displayName, localpart) + return +} + +func (s *profilesStatements) selectProfilesBySearch( + ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + var profiles []authtypes.Profile + // The fmt.Sprintf directive below is building a parameter for the + // "LIKE" condition in the SQL query. %% escapes the % char, so the + // statement in the end will look like "LIKE %searchString%". + rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") + for rows.Next() { + var profile authtypes.Profile + if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + return nil, err + } + profiles = append(profiles, profile) + } + return profiles, nil +} diff --git a/userapi/storage/accounts/cosmosdb/storage.go b/userapi/storage/accounts/cosmosdb/storage.go new file mode 100644 index 000000000..0524d499b --- /dev/null +++ b/userapi/storage/accounts/cosmosdb/storage.go @@ -0,0 +1,408 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "strconv" + "sync" + "time" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" +) + +// Database represents an account database +type Database struct { + db *sql.DB + writer sqlutil.Writer + + sqlutil.PartitionOffsetStatements + accounts accountsStatements + profiles profilesStatements + accountDatas accountDataStatements + threepids threepidStatements + openIDTokens tokenStatements + serverName gomatrixserverlib.ServerName + bcryptCost int + openIDTokenLifetimeMS int64 + + accountsMu sync.Mutex + profilesMu sync.Mutex + accountDatasMu sync.Mutex + threepidsMu sync.Mutex +} + +// NewDatabase creates a new accounts and profiles database +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + d := &Database{ + serverName: serverName, + db: db, + writer: sqlutil.NewExclusiveWriter(), + bcryptCost: bcryptCost, + openIDTokenLifetimeMS: openIDTokenLifetimeMS, + } + + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + if err = d.accounts.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() + deltas.LoadIsActive(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + + partitions := sqlutil.PartitionOffsetStatements{} + if err = partitions.Prepare(db, d.writer, "account"); err != nil { + return nil, err + } + if err = d.accounts.prepare(db, serverName); err != nil { + return nil, err + } + if err = d.profiles.prepare(db); err != nil { + return nil, err + } + if err = d.accountDatas.prepare(db); err != nil { + return nil, err + } + if err = d.threepids.prepare(db); err != nil { + return nil, err + } + if err = d.openIDTokens.prepare(db, serverName); err != nil { + return nil, err + } + + return d, nil +} + +// GetAccountByPassword returns the account associated with the given localpart and password. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByPassword( + ctx context.Context, localpart, plaintextPassword string, +) (*api.Account, error) { + hash, err := d.accounts.selectPasswordHash(ctx, localpart) + if err != nil { + return nil, err + } + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { + return nil, err + } + return d.accounts.selectAccountByLocalpart(ctx, localpart) +} + +// GetProfileByLocalpart returns the profile associated with the given localpart. +// Returns sql.ErrNoRows if no profile exists which matches the given localpart. +func (d *Database) GetProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + return d.profiles.selectProfileByLocalpart(ctx, localpart) +} + +// SetAvatarURL updates the avatar URL of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) error { + d.profilesMu.Lock() + defer d.profilesMu.Unlock() + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL) + }) +} + +// SetDisplayName updates the display name of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetDisplayName( + ctx context.Context, localpart string, displayName string, +) error { + d.profilesMu.Lock() + defer d.profilesMu.Unlock() + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.profiles.setDisplayName(ctx, txn, localpart, displayName) + }) +} + +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := d.hashPassword(plaintextPassword) + if err != nil { + return err + } + err = d.accounts.updatePassword(ctx, localpart, hash) + return err +} + +// CreateGuestAccount makes a new guest account and creates an empty profile +// for this account. +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { + // We need to lock so we sequentially create numeric localparts. If we don't, two calls to + // this function will cause the same number to be selected and one will fail with 'database is locked' + // when the first txn upgrades to a write txn. We also need to lock the account creation else we can + // race with CreateAccount + // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. + d.profilesMu.Lock() + d.accountDatasMu.Lock() + d.accountsMu.Lock() + defer d.profilesMu.Unlock() + defer d.accountDatasMu.Unlock() + defer d.accountsMu.Unlock() + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart := strconv.FormatInt(numLocalpart, 10) + acc, err = d.createAccount(ctx, txn, localpart, "", "") + return err + }) + return acc, err +} + +// CreateAccount makes a new account with the given login name and password, and creates an empty profile +// for this account. If no password is supplied, the account will be a passwordless account. If the +// account already exists, it will return nil, ErrUserExists. +func (d *Database) CreateAccount( + ctx context.Context, localpart, plaintextPassword, appserviceID string, +) (acc *api.Account, err error) { + // Create one account at a time else we can get 'database is locked'. + d.profilesMu.Lock() + d.accountDatasMu.Lock() + d.accountsMu.Lock() + defer d.profilesMu.Unlock() + defer d.accountDatasMu.Unlock() + defer d.accountsMu.Unlock() + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + return err + }) + return +} + +// WARNING! This function assumes that the relevant mutexes have already +// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). +func (d *Database) createAccount( + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, +) (*api.Account, error) { + var err error + var account *api.Account + // Generate a password hash if this is not a password-less user + hash := "" + if plaintextPassword != "" { + hash, err = d.hashPassword(plaintextPassword) + if err != nil { + return nil, err + } + } + if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { + return nil, sqlutil.ErrUserExists + } + if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { + return nil, err + } + if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ + "global": { + "content": [], + "override": [], + "room": [], + "sender": [], + "underride": [] + } + }`)); err != nil { + return nil, err + } + return account, nil +} + +// SaveAccountData saves new account data for a given user and a given room. +// If the account data is not specific to a room, the room ID should be an empty string +// If an account data already exists for a given set (user, room, data type), it will +// update the corresponding row with the new content +// Returns a SQL error if there was an issue with the insertion/update +func (d *Database) SaveAccountData( + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, +) error { + d.accountDatasMu.Lock() + defer d.accountDatasMu.Unlock() + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) + }) +} + +// GetAccountData returns account data related to a given localpart +// If no account data could be found, returns an empty arrays +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountData(ctx context.Context, localpart string) ( + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, + err error, +) { + return d.accountDatas.selectAccountData(ctx, localpart) +} + +// GetAccountDataByType returns account data matching a given +// localpart, room ID and type. +// If no account data could be found, returns nil +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data json.RawMessage, err error) { + return d.accountDatas.selectAccountDataByType( + ctx, localpart, roomID, dataType, + ) +} + +// GetNewNumericLocalpart generates and returns a new unused numeric localpart +func (d *Database) GetNewNumericLocalpart( + ctx context.Context, +) (int64, error) { + return d.accounts.selectNewNumericLocalpart(ctx, nil) +} + +func (d *Database) hashPassword(plaintext string) (hash string, err error) { + hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost) + return string(hashBytes), err +} + +// Err3PIDInUse is the error returned when trying to save an association involving +// a third-party identifier which is already associated to a local user. +var Err3PIDInUse = errors.New("This third-party identifier is already in use") + +// SaveThreePIDAssociation saves the association between a third party identifier +// and a local Matrix user (identified by the user's ID's local part). +// If the third-party identifier is already part of an association, returns Err3PIDInUse. +// Returns an error if there was a problem talking to the database. +func (d *Database) SaveThreePIDAssociation( + ctx context.Context, threepid, localpart, medium string, +) (err error) { + d.threepidsMu.Lock() + defer d.threepidsMu.Unlock() + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + user, err := d.threepids.selectLocalpartForThreePID( + ctx, txn, threepid, medium, + ) + if err != nil { + return err + } + + if len(user) > 0 { + return Err3PIDInUse + } + + return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) + }) +} + +// RemoveThreePIDAssociation removes the association involving a given third-party +// identifier. +// If no association exists involving this third-party identifier, returns nothing. +// If there was a problem talking to the database, returns an error. +func (d *Database) RemoveThreePIDAssociation( + ctx context.Context, threepid string, medium string, +) (err error) { + d.threepidsMu.Lock() + defer d.threepidsMu.Unlock() + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.threepids.deleteThreePID(ctx, txn, threepid, medium) + }) +} + +// GetLocalpartForThreePID looks up the localpart associated with a given third-party +// identifier. +// If no association involves the given third-party idenfitier, returns an empty +// string. +// Returns an error if there was a problem talking to the database. +func (d *Database) GetLocalpartForThreePID( + ctx context.Context, threepid string, medium string, +) (localpart string, err error) { + return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) +} + +// GetThreePIDsForLocalpart looks up the third-party identifiers associated with +// a given local user. +// If no association is known for this user, returns an empty slice. +// Returns an error if there was an issue talking to the database. +func (d *Database) GetThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) +} + +// CheckAccountAvailability checks if the username/localpart is already present +// in the database. +// If the DB returns sql.ErrNoRows the Localpart isn't taken. +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { + _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) + if err == sql.ErrNoRows { + return true, nil + } + return false, err +} + +// GetAccountByLocalpart returns the account associated with the given localpart. +// This function assumes the request is authenticated or the account data is used only internally. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +) (*api.Account, error) { + return d.accounts.selectAccountByLocalpart(ctx, localpart) +} + +// SearchProfiles returns all profiles where the provided localpart or display name +// match any part of the profiles in the database. +func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + return d.profiles.selectProfilesBySearch(ctx, searchString, limit) +} + +// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. +func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { + return d.accounts.deactivateAccount(ctx, localpart) +} + +// CreateOpenIDToken persists a new token that was issued for OpenID Connect +func (d *Database) CreateOpenIDToken( + ctx context.Context, + token, localpart string, +) (int64, error) { + expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS + err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) + }) + return expiresAtMS, err +} + +// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token +func (d *Database) GetOpenIDTokenAttributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) +} diff --git a/userapi/storage/accounts/cosmosdb/threepid_table.go b/userapi/storage/accounts/cosmosdb/threepid_table.go new file mode 100644 index 000000000..0d37dda0e --- /dev/null +++ b/userapi/storage/accounts/cosmosdb/threepid_table.go @@ -0,0 +1,133 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" +) + +const threepidSchema = ` +-- Stores data about third party identifiers +CREATE TABLE IF NOT EXISTS account_threepid ( + -- The third party identifier + threepid TEXT NOT NULL, + -- The 3PID medium + medium TEXT NOT NULL DEFAULT 'email', + -- The localpart of the Matrix user ID associated to this 3PID + localpart TEXT NOT NULL, + + PRIMARY KEY(threepid, medium) +); + +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); +` + +const selectLocalpartForThreePIDSQL = "" + + "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" + +const selectThreePIDsForLocalpartSQL = "" + + "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + +const insertThreePIDSQL = "" + + "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + +const deleteThreePIDSQL = "" + + "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + +type threepidStatements struct { + db *sql.DB + selectLocalpartForThreePIDStmt *sql.Stmt + selectThreePIDsForLocalpartStmt *sql.Stmt + insertThreePIDStmt *sql.Stmt + deleteThreePIDStmt *sql.Stmt +} + +func (s *threepidStatements) prepare(db *sql.DB) (err error) { + s.db = db + _, err = db.Exec(threepidSchema) + if err != nil { + return + } + if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { + return + } + if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { + return + } + if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { + return + } + if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { + return + } + + return +} + +func (s *threepidStatements) selectLocalpartForThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string, +) (localpart string, err error) { + stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *threepidStatements) selectThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed") + + threepids = []authtypes.ThreePID{} + for rows.Next() { + var threepid string + var medium string + if err = rows.Scan(&threepid, &medium); err != nil { + return + } + threepids = append(threepids, authtypes.ThreePID{ + Address: threepid, + Medium: medium, + }) + } + return threepids, rows.Err() +} + +func (s *threepidStatements) insertThreePID( + ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + return err +} + +func (s *threepidStatements) deleteThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { + stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium) + return err +} diff --git a/userapi/storage/accounts/storage.go b/userapi/storage/accounts/storage.go index 3489c9d07..ff17a97ff 100644 --- a/userapi/storage/accounts/storage.go +++ b/userapi/storage/accounts/storage.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/storage/accounts/cosmosdb" "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" "github.com/matrix-org/gomatrixserverlib" @@ -29,6 +30,8 @@ import ( // and sets postgres connection parameters func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return cosmosdb.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) case dbProperties.ConnectionString.IsPostgres(): diff --git a/userapi/storage/accounts/storage_wasm.go b/userapi/storage/accounts/storage_wasm.go index 11a88a20a..f8bf3e322 100644 --- a/userapi/storage/accounts/storage_wasm.go +++ b/userapi/storage/accounts/storage_wasm.go @@ -29,6 +29,8 @@ func NewDatabase( openIDTokenLifetimeMS int64, ) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return nil, fmt.Errorf("can't use CosmosDB implementation") case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) case dbProperties.ConnectionString.IsPostgres(): diff --git a/userapi/storage/devices/cosmosdb/devices_table.go b/userapi/storage/devices/cosmosdb/devices_table.go new file mode 100644 index 000000000..f52e76507 --- /dev/null +++ b/userapi/storage/devices/cosmosdb/devices_table.go @@ -0,0 +1,322 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const devicesSchema = ` +-- This sequence is used for automatic allocation of session_id. +-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; + +-- Stores data about devices. +CREATE TABLE IF NOT EXISTS device_devices ( + access_token TEXT PRIMARY KEY, + session_id INTEGER, + device_id TEXT , + localpart TEXT , + created_ts BIGINT, + display_name TEXT, + last_seen_ts BIGINT, + ip TEXT, + user_agent TEXT, + + UNIQUE (localpart, device_id) +); +` + +const insertDeviceSQL = "" + + "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + +const selectDevicesCountSQL = "" + + "SELECT COUNT(access_token) FROM device_devices" + +const selectDeviceByTokenSQL = "" + + "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + +const selectDeviceByIDSQL = "" + + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" + +const selectDevicesByLocalpartSQL = "" + + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" + +const updateDeviceNameSQL = "" + + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + +const deleteDeviceSQL = "" + + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + +const deleteDevicesByLocalpartSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" + +const deleteDevicesSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" + +const selectDevicesByIDSQL = "" + + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" + +const updateDeviceLastSeen = "" + + "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" + +type devicesStatements struct { + db *sql.DB + writer sqlutil.Writer + insertDeviceStmt *sql.Stmt + selectDevicesCountStmt *sql.Stmt + selectDeviceByTokenStmt *sql.Stmt + selectDeviceByIDStmt *sql.Stmt + selectDevicesByIDStmt *sql.Stmt + selectDevicesByLocalpartStmt *sql.Stmt + updateDeviceNameStmt *sql.Stmt + updateDeviceLastSeenStmt *sql.Stmt + deleteDeviceStmt *sql.Stmt + deleteDevicesByLocalpartStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *devicesStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(devicesSchema) + return err +} + +func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { + s.db = db + s.writer = writer + if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { + return + } + if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { + return + } + if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { + return + } + if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { + return + } + if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { + return + } + if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { + return + } + if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { + return + } + if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { + return + } + if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { + return + } + if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { + return + } + s.serverName = server + return +} + +// insertDevice creates a new device. Returns an error if any device with the same access token already exists. +// Returns an error if the user already has a device with the given device ID. +// Returns the device on success. +func (s *devicesStatements) insertDevice( + ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, + displayName *string, ipAddr, userAgent string, +) (*api.Device, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + var sessionID int64 + countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { + return nil, err + } + sessionID++ + if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { + return nil, err + } + return &api.Device{ + ID: id, + UserID: userutil.MakeUserID(localpart, s.serverName), + AccessToken: accessToken, + SessionID: sessionID, + LastSeenTS: createdTimeMS, + LastSeenIP: ipAddr, + UserAgent: userAgent, + }, nil +} + +func (s *devicesStatements) deleteDevice( + ctx context.Context, txn *sql.Tx, id, localpart string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) + _, err := stmt.ExecContext(ctx, id, localpart) + return err +} + +func (s *devicesStatements) deleteDevices( + ctx context.Context, txn *sql.Tx, localpart string, devices []string, +) error { + orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) + prep, err := s.db.Prepare(orig) + if err != nil { + return err + } + stmt := sqlutil.TxStmt(txn, prep) + params := make([]interface{}, len(devices)+1) + params[0] = localpart + for i, v := range devices { + params[i+1] = v + } + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *devicesStatements) deleteDevicesByLocalpart( + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + return err +} + +func (s *devicesStatements) updateDeviceName( + ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, +) error { + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + return err +} + +func (s *devicesStatements) selectDeviceByToken( + ctx context.Context, accessToken string, +) (*api.Device, error) { + var dev api.Device + var localpart string + stmt := s.selectDeviceByTokenStmt + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + if err == nil { + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.AccessToken = accessToken + } + return &dev, err +} + +// selectDeviceByID retrieves a device from the database with the given user +// localpart and deviceID +func (s *devicesStatements) selectDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + var dev api.Device + var displayName sql.NullString + stmt := s.selectDeviceByIDStmt + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) + if err == nil { + dev.ID = deviceID + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + if displayName.Valid { + dev.DisplayName = displayName.String + } + } + return &dev, err +} + +func (s *devicesStatements) selectDevicesByLocalpart( + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, +) ([]api.Device, error) { + devices := []api.Device{} + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) + + if err != nil { + return devices, err + } + + for rows.Next() { + var dev api.Device + var lastseents sql.NullInt64 + var id, displayname, ip, useragent sql.NullString + err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) + if err != nil { + return devices, err + } + if id.Valid { + dev.ID = id.String + } + if displayname.Valid { + dev.DisplayName = displayname.String + } + if lastseents.Valid { + dev.LastSeenTS = lastseents.Int64 + } + if ip.Valid { + dev.LastSeenIP = ip.String + } + if useragent.Valid { + dev.UserAgent = useragent.String + } + + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + + return devices, nil +} + +func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) + iDeviceIDs := make([]interface{}, len(deviceIDs)) + for i := range deviceIDs { + iDeviceIDs[i] = deviceIDs[i] + } + + rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") + var devices []api.Device + for rows.Next() { + var dev api.Device + var localpart string + var displayName sql.NullString + if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dev.DisplayName = displayName.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + return devices, rows.Err() +} + +func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { + lastSeenTs := time.Now().UnixNano() / 1000000 + stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) + return err +} diff --git a/userapi/storage/devices/cosmosdb/storage.go b/userapi/storage/devices/cosmosdb/storage.go new file mode 100644 index 000000000..338210148 --- /dev/null +++ b/userapi/storage/devices/cosmosdb/storage.go @@ -0,0 +1,214 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas" + "github.com/matrix-org/gomatrixserverlib" + + _ "github.com/mattn/go-sqlite3" +) + +// The length of generated device IDs +var deviceIDByteLength = 6 + +// Database represents a device database. +type Database struct { + db *sql.DB + writer sqlutil.Writer + devices devicesStatements +} + +// NewDatabase creates a new device database +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + writer := sqlutil.NewExclusiveWriter() + d := devicesStatements{} + + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + if err = d.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() + deltas.LoadLastSeenTSIP(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + if err = d.prepare(db, writer, serverName); err != nil { + return nil, err + } + return &Database{db, writer, d}, nil +} + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.devices.selectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.devices.selectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") +} + +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ipAddr, userAgent string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + var err error + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart, exceptDeviceID string, +) (devices []api.Device, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + if err != nil { + return err + } + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + return err + } + return nil + }) + return +} + +// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) + }) +} diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go index bfce924d9..6baaa5957 100644 --- a/userapi/storage/devices/storage.go +++ b/userapi/storage/devices/storage.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/storage/devices/cosmosdb" "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" "github.com/matrix-org/gomatrixserverlib" @@ -29,6 +30,8 @@ import ( // and sets postgres connection parameters func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return cosmosdb.NewDatabase(dbProperties, serverName) case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties, serverName) case dbProperties.ConnectionString.IsPostgres(): diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go index f360f9857..dc462957e 100644 --- a/userapi/storage/devices/storage_wasm.go +++ b/userapi/storage/devices/storage_wasm.go @@ -27,6 +27,8 @@ func NewDatabase( serverName gomatrixserverlib.ServerName, ) (Database, error) { switch { + case dbProperties.ConnectionString.IsCosmosDB(): + return nil, fmt.Errorf("can't use CosmosDB implementation") case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties, serverName) case dbProperties.ConnectionString.IsPostgres(): From 5ded872da9c5d343ae744ef2eea968bdfee4b06b Mon Sep 17 00:00:00 2001 From: Alex Flatow Date: Thu, 6 May 2021 15:09:44 +1000 Subject: [PATCH 2/2] - Add CosmosDB as a Datasource type - Use the SQLLite as a base for the CosmosDB package(s) - Update the ConnString to use file: from cosmosdb: so it still works - Add a yaml file for the config to use CosmosDB --- appservice/storage/cosmosdb/storage.go | 2 ++ federationsender/storage/cosmosdb/storage.go | 2 ++ keyserver/storage/cosmosdb/storage.go | 2 ++ mediaapi/storage/cosmosdb/storage.go | 2 ++ roomserver/storage/cosmosdb/storage.go | 2 ++ setup/kafka/kafka.go | 5 +++++ signingkeyserver/storage/cosmosdb/keydb.go | 2 ++ syncapi/storage/cosmosdb/syncserver.go | 2 ++ userapi/storage/accounts/cosmosdb/storage.go | 2 ++ userapi/storage/devices/cosmosdb/storage.go | 2 ++ 10 files changed, 23 insertions(+) diff --git a/appservice/storage/cosmosdb/storage.go b/appservice/storage/cosmosdb/storage.go index 3639010e1..2f07167b9 100644 --- a/appservice/storage/cosmosdb/storage.go +++ b/appservice/storage/cosmosdb/storage.go @@ -16,6 +16,7 @@ package cosmosdb import ( + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "database/sql" @@ -37,6 +38,7 @@ type Database struct { // NewDatabase opens a new database func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { + dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) var result Database var err error if result.db, err = sqlutil.Open(dbProperties); err != nil { diff --git a/federationsender/storage/cosmosdb/storage.go b/federationsender/storage/cosmosdb/storage.go index da429046b..fb38d6e6d 100644 --- a/federationsender/storage/cosmosdb/storage.go +++ b/federationsender/storage/cosmosdb/storage.go @@ -16,6 +16,7 @@ package cosmosdb import ( + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "database/sql" _ "github.com/mattn/go-sqlite3" @@ -37,6 +38,7 @@ type Database struct { // NewDatabase opens a new database func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) { + dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) var d Database var err error if d.db, err = sqlutil.Open(dbProperties); err != nil { diff --git a/keyserver/storage/cosmosdb/storage.go b/keyserver/storage/cosmosdb/storage.go index ba000cb24..c4a0c0c97 100644 --- a/keyserver/storage/cosmosdb/storage.go +++ b/keyserver/storage/cosmosdb/storage.go @@ -15,12 +15,14 @@ package cosmosdb import ( + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/setup/config" ) func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { + dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err diff --git a/mediaapi/storage/cosmosdb/storage.go b/mediaapi/storage/cosmosdb/storage.go index b05373868..43b2879df 100644 --- a/mediaapi/storage/cosmosdb/storage.go +++ b/mediaapi/storage/cosmosdb/storage.go @@ -16,6 +16,7 @@ package cosmosdb import ( + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "database/sql" @@ -36,6 +37,7 @@ type Database struct { // Open opens a postgres database. func Open(dbProperties *config.DatabaseOptions) (*Database, error) { + dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) d := Database{ writer: sqlutil.NewExclusiveWriter(), } diff --git a/roomserver/storage/cosmosdb/storage.go b/roomserver/storage/cosmosdb/storage.go index bb3f6af2e..aa712d07d 100644 --- a/roomserver/storage/cosmosdb/storage.go +++ b/roomserver/storage/cosmosdb/storage.go @@ -16,6 +16,7 @@ package cosmosdb import ( + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "database/sql" @@ -37,6 +38,7 @@ type Database struct { // Open a sqlite database. func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { + dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) var d Database var db *sql.DB var err error diff --git a/setup/kafka/kafka.go b/setup/kafka/kafka.go index a2902c962..431da23b6 100644 --- a/setup/kafka/kafka.go +++ b/setup/kafka/kafka.go @@ -1,6 +1,7 @@ package kafka import ( + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/naffka" @@ -46,6 +47,10 @@ func setupNaffka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) { if naffkaInstance != nil { return naffkaInstance, naffkaInstance } + if(cfg.Database.ConnectionString.IsCosmosDB()) { + cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString) + } + naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString)) if err != nil { logrus.WithError(err).Panic("Failed to setup naffka database") diff --git a/signingkeyserver/storage/cosmosdb/keydb.go b/signingkeyserver/storage/cosmosdb/keydb.go index 0f4371bce..46c95d88a 100644 --- a/signingkeyserver/storage/cosmosdb/keydb.go +++ b/signingkeyserver/storage/cosmosdb/keydb.go @@ -16,6 +16,7 @@ package cosmosdb import ( + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "golang.org/x/crypto/ed25519" @@ -44,6 +45,7 @@ func NewDatabase( serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, ) (*Database, error) { + dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err diff --git a/syncapi/storage/cosmosdb/syncserver.go b/syncapi/storage/cosmosdb/syncserver.go index 7bf1a1387..719c8fdad 100644 --- a/syncapi/storage/cosmosdb/syncserver.go +++ b/syncapi/storage/cosmosdb/syncserver.go @@ -16,6 +16,7 @@ package cosmosdb import ( + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "database/sql" // Import the sqlite3 package @@ -40,6 +41,7 @@ type SyncServerDatasource struct { // NewDatabase creates a new sync server database // nolint: gocyclo func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { + dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) var d SyncServerDatasource var err error if d.db, err = sqlutil.Open(dbProperties); err != nil { diff --git a/userapi/storage/accounts/cosmosdb/storage.go b/userapi/storage/accounts/cosmosdb/storage.go index 0524d499b..2e9f2888d 100644 --- a/userapi/storage/accounts/cosmosdb/storage.go +++ b/userapi/storage/accounts/cosmosdb/storage.go @@ -15,6 +15,7 @@ package cosmosdb import ( + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "database/sql" "encoding/json" @@ -55,6 +56,7 @@ type Database struct { // NewDatabase creates a new accounts and profiles database func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { + dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err diff --git a/userapi/storage/devices/cosmosdb/storage.go b/userapi/storage/devices/cosmosdb/storage.go index 338210148..d414e9026 100644 --- a/userapi/storage/devices/cosmosdb/storage.go +++ b/userapi/storage/devices/cosmosdb/storage.go @@ -15,6 +15,7 @@ package cosmosdb import ( + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "crypto/rand" "database/sql" @@ -41,6 +42,7 @@ type Database struct { // NewDatabase creates a new device database func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { + dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err