From 5d68daef8014aca413384ca6fa38c2c56bb4cf91 Mon Sep 17 00:00:00 2001 From: alexfca <75228224+alexfca@users.noreply.github.com> Date: Thu, 20 May 2021 14:42:33 +1000 Subject: [PATCH] Implement Cosmos DB for the RoomServer Service (#5) * - Implement Cosmos for the devices_table - Use the ConnectionString in the YAML to include the Tenant - Revert all other non implemented tables back to use SQLLite3 * - Change the Config to use "test.criticicalarc.com" Container - Add generic function GetDocumentOrNil to standardize GetDocument - Add func to return CrossPartition queries for Aggregates - Add func GetNextSequence() as generic seq generator for AutoIncrement - Add cosmosdbutil.ErrNoRows to return (emulate) sql.ErrNoRows - Add a "fake" ExclusiveWriterFake - Add standard "getXX", "setXX" and "queryXX" to all TABLE class files - Add specific Table SEQ for the Events table - Add specific Table SEQ for the Rooms table - Add specific Table SEQ for the StateSnapshot table --- dendrite-config-cosmosdb.yaml | 6 +- internal/cosmosdbapi/document.go | 25 +- internal/cosmosdbapi/documentoperations.go | 18 +- internal/cosmosdbutil/document_seq.go | 76 ++ internal/cosmosdbutil/errors.go | 12 + internal/cosmosdbutil/writer_exclusive.go | 77 ++ .../20201028212440_add_forgotten_column.go | 82 ++ .../storage/cosmosdb/event_json_table.go | 162 +++- roomserver/storage/cosmosdb/event_seq.go | 24 + .../cosmosdb/event_state_keys_table.go | 321 +++++-- .../storage/cosmosdb/event_types_table.go | 306 +++++-- roomserver/storage/cosmosdb/events_table.go | 851 ++++++++++++------ roomserver/storage/cosmosdb/invite_table.go | 275 ++++-- .../storage/cosmosdb/membership_table.go | 595 +++++++++--- .../storage/cosmosdb/previous_events_table.go | 198 ++-- .../storage/cosmosdb/published_table.go | 214 +++-- .../storage/cosmosdb/redactions_table.go | 279 ++++-- .../storage/cosmosdb/room_aliases_table.go | 271 ++++-- roomserver/storage/cosmosdb/room_seq.go | 12 + roomserver/storage/cosmosdb/rooms_table.go | 531 +++++++---- .../storage/cosmosdb/state_block_table.go | 322 ++++--- .../storage/cosmosdb/state_snapshot_seq.go | 12 + .../storage/cosmosdb/state_snapshot_table.go | 177 ++-- roomserver/storage/cosmosdb/storage.go | 89 +- .../storage/cosmosdb/transactions_table.go | 160 +++- .../accounts/cosmosdb/account_data_table.go | 63 +- .../accounts/cosmosdb/accounts_table.go | 95 +- .../storage/accounts/cosmosdb/openid_table.go | 45 +- .../accounts/cosmosdb/profile_table.go | 94 +- userapi/storage/accounts/cosmosdb/storage.go | 6 +- .../accounts/cosmosdb/threepid_table.go | 57 +- .../storage/devices/cosmosdb/devices_table.go | 119 ++- userapi/storage/devices/cosmosdb/storage.go | 2 - 33 files changed, 4012 insertions(+), 1564 deletions(-) create mode 100644 internal/cosmosdbutil/document_seq.go create mode 100644 internal/cosmosdbutil/errors.go create mode 100644 internal/cosmosdbutil/writer_exclusive.go create mode 100644 roomserver/storage/cosmosdb/deltas/20201028212440_add_forgotten_column.go create mode 100644 roomserver/storage/cosmosdb/event_seq.go create mode 100644 roomserver/storage/cosmosdb/room_seq.go create mode 100644 roomserver/storage/cosmosdb/state_snapshot_seq.go diff --git a/dendrite-config-cosmosdb.yaml b/dendrite-config-cosmosdb.yaml index 189abe766..ef7883e23 100644 --- a/dendrite-config-cosmosdb.yaml +++ b/dendrite-config-cosmosdb.yaml @@ -291,7 +291,7 @@ room_server: listen: http://localhost:7770 connect: http://localhost:7770 database: - connection_string: file:roomserver.db + connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;" max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 @@ -354,12 +354,12 @@ user_api: listen: http://localhost:7781 connect: http://localhost:7781 account_database: - connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;" + connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;" max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 device_database: - connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;" + connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;" max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 diff --git a/internal/cosmosdbapi/document.go b/internal/cosmosdbapi/document.go index 9e419fc52..54f6499e2 100644 --- a/internal/cosmosdbapi/document.go +++ b/internal/cosmosdbapi/document.go @@ -1,8 +1,8 @@ package cosmosdbapi import ( + "context" "fmt" - ) func GetDocumentId(tenantName string, collectionName string, id string) string { @@ -11,4 +11,25 @@ func GetDocumentId(tenantName string, collectionName string, id string) string { func GetPartitionKey(tenantName string, collectionName string) string { return fmt.Sprintf("%s,%s", collectionName, tenantName) -} \ No newline at end of file +} + +func GetDocumentOrNil(connection CosmosConnection, config CosmosConfig, ctx context.Context, partitionKey string, cosmosDocId string, dbData interface{}) error { + var _, err = GetClient(connection).GetDocument( + ctx, + config.DatabaseName, + config.ContainerName, + cosmosDocId, + GetGetDocumentOptions(partitionKey), + &dbData, + ) + + if err != nil { + if err.Error() == "Resource that no longer exists" { + dbData = nil + return nil + } + return err + } + + return nil +} diff --git a/internal/cosmosdbapi/documentoperations.go b/internal/cosmosdbapi/documentoperations.go index 37e8ea883..ad52c05c7 100644 --- a/internal/cosmosdbapi/documentoperations.go +++ b/internal/cosmosdbapi/documentoperations.go @@ -6,14 +6,14 @@ import ( func GetCreateDocumentOptions(pk string) cosmosapi.CreateDocumentOptions { return cosmosapi.CreateDocumentOptions{ - IsUpsert: false, + IsUpsert: false, PartitionKeyValue: pk, } } func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions { return cosmosapi.CreateDocumentOptions{ - IsUpsert: true, + IsUpsert: true, PartitionKeyValue: pk, } } @@ -21,8 +21,16 @@ func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions { func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions { return cosmosapi.QueryDocumentsOptions{ PartitionKeyValue: pk, - IsQuery: true, - ContentType: cosmosapi.QUERY_CONTENT_TYPE, + IsQuery: true, + ContentType: cosmosapi.QUERY_CONTENT_TYPE, + } +} + +func GetQueryAllPartitionsDocumentsOptions() cosmosapi.QueryDocumentsOptions { + return cosmosapi.QueryDocumentsOptions{ + IsQuery: true, + EnableCrossPartition: true, + ContentType: cosmosapi.QUERY_CONTENT_TYPE, } } @@ -35,7 +43,7 @@ func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions { func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions { return cosmosapi.ReplaceDocumentOptions{ PartitionKeyValue: pk, - IfMatch: etag, + IfMatch: etag, } } diff --git a/internal/cosmosdbutil/document_seq.go b/internal/cosmosdbutil/document_seq.go new file mode 100644 index 000000000..ecce50a3f --- /dev/null +++ b/internal/cosmosdbutil/document_seq.go @@ -0,0 +1,76 @@ +package cosmosdbutil + +import ( + "context" + "fmt" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" +) + +type SequenceCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Value int64 `json:"_value"` +} + +func GetNextSequence( + ctx context.Context, + connection cosmosdbapi.CosmosConnection, + config cosmosdbapi.CosmosConfig, + serviceName string, + tableName string, + seqId string, + initial int64, +) (int64, error) { + collName := fmt.Sprintf("%s_%s", tableName, seqId) + dbCollectionName := cosmosdbapi.GetCollectionName(serviceName, collName) + cosmosDocId := cosmosdbapi.GetDocumentId(config.ContainerName, dbCollectionName, seqId) + pk := cosmosDocId + + dbData := SequenceCosmosData{} + cosmosdbapi.GetDocumentOrNil( + connection, + config, + ctx, + pk, + cosmosDocId, + &dbData, + ) + + if dbData.Id == "" { + dbData = SequenceCosmosData{} + dbData.Id = cosmosDocId + dbData.Pk = pk + dbData.Cn = dbCollectionName + dbData.Value = initial + var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(connection).CreateDocument( + ctx, + config.DatabaseName, + config.ContainerName, + dbData, + optionsCreate, + ) + if err != nil { + return -1, err + } + } else { + dbData.Value++ + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(dbData.Pk, dbData.ETag) + var _, _, err = cosmosdbapi.GetClient(connection).ReplaceDocument( + ctx, + config.DatabaseName, + config.ContainerName, + cosmosDocId, + dbData, + optionsReplace, + ) + if err != nil { + return -1, err + } + } + return dbData.Value, nil +} diff --git a/internal/cosmosdbutil/errors.go b/internal/cosmosdbutil/errors.go new file mode 100644 index 000000000..320669dce --- /dev/null +++ b/internal/cosmosdbutil/errors.go @@ -0,0 +1,12 @@ +package cosmosdbutil + +import ( + "database/sql" + "errors" +) + +// ErrNoRows is returned by Scan when QueryRow doesn't return a +// row. Used to simulate the SQL responses as its used for business logic +var ErrNoRows = sql.ErrNoRows + +var ErrNotImplemented = errors.New("cosmosdb: not implemented") diff --git a/internal/cosmosdbutil/writer_exclusive.go b/internal/cosmosdbutil/writer_exclusive.go new file mode 100644 index 000000000..c4b759cdf --- /dev/null +++ b/internal/cosmosdbutil/writer_exclusive.go @@ -0,0 +1,77 @@ +package cosmosdbutil + +import ( + "database/sql" + "errors" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "go.uber.org/atomic" +) + +// ExclusiveWriter implements sqlutil.Writer. +// Allows non-SQL DBs to still use the same infrastructure +// as Matrix assumes SQL TXN +type ExclusiveWriterFake struct { + running atomic.Bool + todo chan transactionWriterTaskFake +} + +func NewExclusiveWriterFake() sqlutil.Writer { + return &ExclusiveWriterFake{ + todo: make(chan transactionWriterTaskFake), + } +} + +// transactionWriterTask represents a specific task. +type transactionWriterTaskFake struct { + db *sql.DB + txn *sql.Tx + f func(txn *sql.Tx) error + wait chan error +} + +func (w *ExclusiveWriterFake) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { + if w.todo == nil { + return errors.New("not initialised") + } + if !w.running.Load() { + go w.run() + } + task := transactionWriterTaskFake{ + db: db, + txn: txn, + f: f, + wait: make(chan error, 1), + } + w.todo <- task + return <-task.wait +} + +// run processes the tasks for a given transaction writer. Only one +// of these goroutines will run at a time. A transaction will be +// opened using the database object from the task and then this will +// be passed as a parameter to the task function. +func (w *ExclusiveWriterFake) run() { + if !w.running.CAS(false, true) { + return + } + // if tracingEnabled { + // gid := goid() + // goidToWriter.Store(gid, w) + // defer goidToWriter.Delete(gid) + // } + + defer w.running.Store(false) + for task := range w.todo { + if task.db != nil && task.txn != nil { + task.wait <- task.f(task.txn) + // } else if task.db != nil && task.txn == nil { + // task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { + // return task.f(txn) + // }) + } else { + task.wait <- task.f(nil) + } + close(task.wait) + } +} diff --git a/roomserver/storage/cosmosdb/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/cosmosdb/deltas/20201028212440_add_forgotten_column.go new file mode 100644 index 000000000..33fe9e2a9 --- /dev/null +++ b/roomserver/storage/cosmosdb/deltas/20201028212440_add_forgotten_column.go @@ -0,0 +1,82 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/pressly/goose" +) + +func LoadFromGoose() { + goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) +} + +func LoadAddForgottenColumn(m *sqlutil.Migrations) { + m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) +} + +func UpAddForgottenColumn(tx *sql.Tx) error { + _, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp; +CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, + forgotten BOOLEAN NOT NULL DEFAULT false, + UNIQUE (room_nid, target_nid) + ); +INSERT + INTO roomserver_membership ( + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + ) SELECT + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + FROM roomserver_membership_tmp +; +DROP TABLE roomserver_membership_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddForgottenColumn(tx *sql.Tx) error { + _, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp; +CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, + UNIQUE (room_nid, target_nid) + ); +INSERT + INTO roomserver_membership ( + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + ) SELECT + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + FROM roomserver_membership_tmp +; +DROP TABLE roomserver_membership_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/cosmosdb/event_json_table.go b/roomserver/storage/cosmosdb/event_json_table.go index 05b6b1b62..d7585be56 100644 --- a/roomserver/storage/cosmosdb/event_json_table.go +++ b/roomserver/storage/cosmosdb/event_json_table.go @@ -18,76 +18,152 @@ package cosmosdb import ( "context" "database/sql" - "strings" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -const eventJSONSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_event_json ( - event_nid INTEGER NOT NULL PRIMARY KEY, - event_json TEXT NOT NULL - ); -` +// const eventJSONSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_event_json ( +// event_nid INTEGER NOT NULL PRIMARY KEY, +// event_json TEXT NOT NULL +// ); +// ` -const insertEventJSONSQL = ` - INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2) -` +type EventJSONCosmos struct { + EventNID int64 `json:"event_nid"` + EventJSON []byte `json:"event_json"` +} + +type EventJSONCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + EventJSON EventJSONCosmos `json:"mx_roomserver_event_json"` +} + +// const insertEventJSONSQL = ` +// INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2) +// ` // Bulk event JSON lookup by numeric event ID. // Sort by the numeric event ID. // This means that we can use binary search to lookup by numeric event ID. -const bulkSelectEventJSONSQL = ` - SELECT event_nid, event_json FROM roomserver_event_json - WHERE event_nid IN ($1) - ORDER BY event_nid ASC -` +// SELECT event_nid, event_json FROM roomserver_event_json +// WHERE event_nid IN ($1) +// ORDER BY event_nid ASC +const bulkSelectEventJSONSQL = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_json.event_nid) " + + "order by c.mx_roomserver_event_json.event_nid asc" type eventJSONStatements struct { - db *sql.DB - insertEventJSONStmt *sql.Stmt - bulkSelectEventJSONStmt *sql.Stmt + db *Database + // insertEventJSONStmt *sql.Stmt + bulkSelectEventJSONStmt string + tableName string } -func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { - s := &eventJSONStatements{ - db: db, - } - _, err := db.Exec(eventJSONSchema) +func queryEventJSON(s *eventJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventJSONCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []EventJSONCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - return s, shared.StatementList{ - {&s.insertEventJSONStmt, insertEventJSONSQL}, - {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, - }.Prepare(db) + return response, nil +} + +func NewCosmosDBEventJSONTable(db *Database) (tables.EventJSON, error) { + s := &eventJSONStatements{ + db: db, + } + // _, err := db.Exec(eventJSONSchema) + // if err != nil { + // return nil, err + // } + // return s, shared.StatementList{ + // {&s.insertEventJSONStmt, insertEventJSONSQL}, + s.bulkSelectEventJSONStmt = bulkSelectEventJSONSQL + // }.Prepare(db) + s.tableName = "event_json" + return s, nil } func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) + + // _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) + // INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + docId := fmt.Sprintf("%d", eventNID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := EventJSONCosmos{ + EventNID: int64(eventNID), + EventJSON: eventJSON, + } + + var dbData = EventJSONCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + EventJSON: data, + } + + //Insert OR Replace + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return err } func (s *eventJSONStatements) BulkSelectEventJSON( ctx context.Context, eventNIDs []types.EventNID, ) ([]tables.EventJSONPair, error) { - iEventNIDs := make([]interface{}, len(eventNIDs)) - for k, v := range eventNIDs { - iEventNIDs[k] = v - } - selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) + // SELECT event_nid, event_json FROM roomserver_event_json + // WHERE event_nid IN ($1) + // ORDER BY event_nid ASC + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNIDs, + } + + response, err := queryEventJSON(s, ctx, s.bulkSelectEventJSONStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventJSON: rows.close() failed") // We know that we will only get as many results as event NIDs // because of the unique constraint on event NIDs. @@ -95,13 +171,11 @@ func (s *eventJSONStatements) BulkSelectEventJSON( // We might get fewer results than NIDs so we adjust the length of the slice before returning it. results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 - for ; rows.Next(); i++ { + for _, item := range response { result := &results[i] - var eventNID int64 - if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { - return nil, err - } - result.EventNID = types.EventNID(eventNID) + result.EventNID = types.EventNID(item.EventJSON.EventNID) + result.EventJSON = item.EventJSON.EventJSON + i++ } return results[:i], nil } diff --git a/roomserver/storage/cosmosdb/event_seq.go b/roomserver/storage/cosmosdb/event_seq.go new file mode 100644 index 000000000..cf3e07a10 --- /dev/null +++ b/roomserver/storage/cosmosdb/event_seq.go @@ -0,0 +1,24 @@ +package cosmosdb + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" +) + +func GetNextEventStateKeyNID(s *eventStateKeyStatements, ctx context.Context) (int64, error) { + const docId = "eventstatekeynid_seq" + //1 insert start at 2 + return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 2) +} + +func GetNextEventTypeNID(s *eventTypeStatements, ctx context.Context) (int64, error) { + const docId = "eventtypenid_seq" + //7 inserts start at 8 + return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 8) +} + +func GetNextEventNID(s *eventStatements, ctx context.Context) (int64, error) { + const docId = "eventnid_seq" + return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) +} diff --git a/roomserver/storage/cosmosdb/event_state_keys_table.go b/roomserver/storage/cosmosdb/event_state_keys_table.go index a9307f68a..35129f865 100644 --- a/roomserver/storage/cosmosdb/event_state_keys_table.go +++ b/roomserver/storage/cosmosdb/event_state_keys_table.go @@ -18,96 +18,248 @@ package cosmosdb import ( "context" "database/sql" - "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -const eventStateKeysSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_event_state_keys ( - event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT, - event_state_key TEXT NOT NULL UNIQUE - ); - INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key) - VALUES (1, '') - ON CONFLICT DO NOTHING; -` +// const eventStateKeysSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_event_state_keys ( +// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT, +// event_state_key TEXT NOT NULL UNIQUE +// ); +// INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key) +// VALUES (1, '') +// ON CONFLICT DO NOTHING; +// ` + +type EventStateKeysCosmos struct { + EventStateKeyNID int64 `json:"event_state_key_nid"` + EventStateKey string `json:"event_state_key"` +} + +type EventStateKeysCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + EventStateKeys EventStateKeysCosmos `json:"mx_roomserver_event_state_keys"` +} // Same as insertEventTypeNIDSQL -const insertEventStateKeyNIDSQL = ` - INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1) - ON CONFLICT DO NOTHING; -` +// const insertEventStateKeyNIDSQL = ` +// INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1) +// ON CONFLICT DO NOTHING; +// ` -const selectEventStateKeyNIDSQL = ` - SELECT event_state_key_nid FROM roomserver_event_state_keys - WHERE event_state_key = $1 -` +// SELECT event_state_key_nid FROM roomserver_event_state_keys +// WHERE event_state_key = $1 +const selectEventStateKeyNIDSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_roomserver_event_state_keys.event_state_key = @x2" -// Bulk lookup from string state key to numeric ID for that state key. -// Takes an array of strings as the query parameter. -const bulkSelectEventStateKeySQL = ` - SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys - WHERE event_state_key IN ($1) -` +// // Bulk lookup from string state key to numeric ID for that state key. +// // Takes an array of strings as the query parameter. +// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys +// WHERE event_state_key IN ($1) +const bulkSelectEventStateKeySQL = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key_nid)" // Bulk lookup from numeric ID to string state key for that state key. // Takes an array of strings as the query parameter. -const bulkSelectEventStateKeyNIDSQL = ` - SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys - WHERE event_state_key_nid IN ($1) -` +// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys +// WHERE event_state_key_nid IN ($1) +const bulkSelectEventStateKeyNIDSQL = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key)" type eventStateKeyStatements struct { - db *sql.DB - insertEventStateKeyNIDStmt *sql.Stmt - selectEventStateKeyNIDStmt *sql.Stmt - bulkSelectEventStateKeyNIDStmt *sql.Stmt - bulkSelectEventStateKeyStmt *sql.Stmt + db *Database + insertEventStateKeyNIDStmt string + selectEventStateKeyNIDStmt string + bulkSelectEventStateKeyNIDStmt string + bulkSelectEventStateKeyStmt string + tableName string } -func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { - s := &eventStateKeyStatements{ - db: db, - } - _, err := db.Exec(eventStateKeysSchema) +func queryEventStateKeys(s *eventStateKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventStateKeysCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []EventStateKeysCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - return s, shared.StatementList{ - {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, - {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, - {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, - {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, - }.Prepare(db) + return response, nil +} + +func getEventStateKeys(s *eventStateKeyStatements, ctx context.Context, pk string, docId string) (*EventStateKeysCosmosData, error) { + response := EventStateKeysCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, nil + } + + return &response, err +} + +func NewCosmosDBEventStateKeysTable(db *Database) (tables.EventStateKeys, error) { + s := &eventStateKeyStatements{ + db: db, + } + // return s, shared.StatementList{ + // {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, + s.selectEventStateKeyNIDStmt = selectEventStateKeyNIDSQL + s.bulkSelectEventStateKeyNIDStmt = bulkSelectEventStateKeyNIDSQL + s.bulkSelectEventStateKeyStmt = bulkSelectEventStateKeySQL + // }.Prepare(db) + s.tableName = "event_state_keys" + //Add in the initial data + ensureEventStateKeys(s, context.Background()) + return s, nil +} + +func ensureEventStateKeys(s *eventStateKeyStatements, ctx context.Context) { + + // INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key) + // VALUES (1, '') + // ON CONFLICT DO NOTHING; + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // event_state_key TEXT NOT NULL UNIQUE + docId := "" + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := EventStateKeysCosmos{ + EventStateKey: "", + EventStateKeyNID: 1, + } + + // event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT, + dbData := EventStateKeysCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + EventStateKeys: data, + } + + insertEventStateKeyCore(s, ctx, dbData) +} + +func insertEventStateKeyCore(s *eventStateKeyStatements, ctx context.Context, dbData EventStateKeysCosmosData) error { + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + options) + + if err != nil { + return err + } + + return nil } func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { - insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) - res, err := insertStmt.ExecContext(ctx, eventStateKey) - if err != nil { - return 0, err + + // INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1) + // ON CONFLICT DO NOTHING; + if len(eventStateKey) == 0 { + return 0, cosmosdbutil.ErrNoRows } - eventStateKeyNID, err := res.LastInsertId() - if err != nil { - return 0, err + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // event_state_key TEXT NOT NULL UNIQUE + docId := eventStateKey + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + existing, _ := getEventStateKeys(s, ctx, pk, cosmosDocId) + + var dbData EventStateKeysCosmosData + if existing == nil { + //Not exists, we need to create a new one with a SEQ + eventStateKeyNIDSeq, seqErr := GetNextEventStateKeyNID(s, ctx) + if seqErr != nil { + return -1, seqErr + } + + data := EventStateKeysCosmos{ + EventStateKey: eventStateKey, + EventStateKeyNID: eventStateKeyNIDSeq, + } + + // event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT, + dbData = EventStateKeysCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + EventStateKeys: data, + } + } else { + dbData.EventStateKeys = existing.EventStateKeys } - return types.EventStateKeyNID(eventStateKeyNID), err + + err := insertEventStateKeyCore(s, ctx, dbData) + + return types.EventStateKeyNID(dbData.EventStateKeys.EventStateKeyNID), err } func (s *eventStateKeyStatements) SelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { - var eventStateKeyNID int64 - stmt := sqlutil.TxStmt(txn, s.selectEventStateKeyNIDStmt) - err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) - return types.EventStateKeyNID(eventStateKeyNID), err + + // SELECT event_state_key_nid FROM roomserver_event_state_keys + // WHERE event_state_key = $1 + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventStateKey, + } + + response, err := queryEventStateKeys(s, ctx, s.selectEventStateKeyNIDStmt, params) + + if err != nil { + return 0, err + } + //See storage.assignStateKeyNID() + if len(response) == 0 { + return 0, cosmosdbutil.ErrNoRows + } + + return types.EventStateKeyNID(response[0].EventStateKeys.EventStateKeyNID), err } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( @@ -117,21 +269,25 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( for k, v := range eventStateKeys { iEventStateKeys[k] = v } - selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1) - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + // SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys + // WHERE event_state_key IN ($1) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventStateKeys, + } + + response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyNIDStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed") + result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) - for rows.Next() { - var stateKey string - var stateKeyNID int64 - if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { - return nil, err - } - result[stateKey] = types.EventStateKeyNID(stateKeyNID) + for _, item := range response { + result[item.EventStateKeys.EventStateKey] = types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID) } return result, nil } @@ -139,25 +295,24 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( func (s *eventStateKeyStatements) BulkSelectEventStateKey( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { - iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) - for k, v := range eventStateKeyNIDs { - iEventStateKeyNIDs[k] = v - } - selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1) - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + // SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys + // WHERE event_state_key_nid IN ($1) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventStateKeyNIDs, + } + + response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed") result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) - for rows.Next() { - var stateKey string - var stateKeyNID int64 - if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { - return nil, err - } - result[types.EventStateKeyNID(stateKeyNID)] = stateKey + for _, item := range response { + result[types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID)] = item.EventStateKeys.EventStateKey } return result, nil } diff --git a/roomserver/storage/cosmosdb/event_types_table.go b/roomserver/storage/cosmosdb/event_types_table.go index a63b537ce..72ed99419 100644 --- a/roomserver/storage/cosmosdb/event_types_table.go +++ b/roomserver/storage/cosmosdb/event_types_table.go @@ -18,30 +18,44 @@ package cosmosdb import ( "context" "database/sql" - "fmt" - "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -const eventTypesSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_event_types ( - event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT, - event_type TEXT NOT NULL UNIQUE - ); - INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES - (1, 'm.room.create'), - (2, 'm.room.power_levels'), - (3, 'm.room.join_rules'), - (4, 'm.room.third_party_invite'), - (5, 'm.room.member'), - (6, 'm.room.redaction'), - (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; -` +// const eventTypesSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_event_types ( +// event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT, +// event_type TEXT NOT NULL UNIQUE +// ); +// INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES +// (1, 'm.room.create'), +// (2, 'm.room.power_levels'), +// (3, 'm.room.join_rules'), +// (4, 'm.room.third_party_invite'), +// (5, 'm.room.member'), +// (6, 'm.room.redaction'), +// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; +// ` + +type EventTypeCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + EventType EventTypeCosmos `json:"mx_roomserver_event_type"` +} + +type EventTypeCosmos struct { + EventTypeNID int64 `json:"event_type_nid"` + EventType string `json:"event_type"` +} // Assign a new numeric event type ID. // The usual case is that the event type is not in the database. @@ -56,105 +70,243 @@ const eventTypesSchema = ` // return it. Modifying the rows will cause postgres to assign a new tuple for the // row even though the data doesn't change resulting in unncesssary modifications // to the indexes. -const insertEventTypeNIDSQL = ` - INSERT INTO roomserver_event_types (event_type) VALUES ($1) - ON CONFLICT DO NOTHING; -` +// const insertEventTypeNIDSQL = ` +// INSERT INTO roomserver_event_types (event_type) VALUES ($1) +// ON CONFLICT DO NOTHING; +// ` -const insertEventTypeNIDResultSQL = ` - SELECT event_type_nid FROM roomserver_event_types - WHERE rowid = last_insert_rowid(); -` +// const insertEventTypeNIDResultSQL = ` +// SELECT event_type_nid FROM roomserver_event_types +// WHERE rowid = last_insert_rowid(); +// ` -const selectEventTypeNIDSQL = ` - SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1 -` +// const selectEventTypeNIDSQL = ` +// SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1 +// ` // Bulk lookup from string event type to numeric ID for that event type. // Takes an array of strings as the query parameter. -const bulkSelectEventTypeNIDSQL = ` - SELECT event_type, event_type_nid FROM roomserver_event_types - WHERE event_type IN ($1) -` +// SELECT event_type, event_type_nid FROM roomserver_event_types +// WHERE event_type IN ($1) +const bulkSelectEventTypeNIDSQL = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_type.event_type)" type eventTypeStatements struct { - db *sql.DB - insertEventTypeNIDStmt *sql.Stmt - insertEventTypeNIDResultStmt *sql.Stmt - selectEventTypeNIDStmt *sql.Stmt - bulkSelectEventTypeNIDStmt *sql.Stmt + db *Database + // insertEventTypeNIDStmt *sql.Stmt + // insertEventTypeNIDResultStmt *sql.Stmt + // selectEventTypeNIDStmt *sql.Stmt + bulkSelectEventTypeNIDStmt string + tableName string } -func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { +func NewCosmosDBEventTypesTable(db *Database) (tables.EventTypes, error) { s := &eventTypeStatements{ db: db, } - _, err := db.Exec(eventTypesSchema) + + // return s, shared.StatementList{ + // {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, + // {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL}, + // {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, + s.bulkSelectEventTypeNIDStmt = bulkSelectEventTypeNIDSQL + // }.Prepare(db) + s.tableName = "event_types" + ensureEventTypes(s, context.Background()) + return s, nil +} + +func queryEventTypes(s *eventTypeStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventTypeCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []EventTypeCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - - return s, shared.StatementList{ - {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, - {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL}, - {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, - {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, - }.Prepare(db) + return response, nil } func (s *eventTypeStatements) InsertEventTypeNID( ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { - var eventTypeNID int64 - insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt) - resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt) - _, err := insertStmt.ExecContext(ctx, eventType) + //We need to create a new one with a SEQ + eventTypeNIDSeq, seqErr := GetNextEventTypeNID(s, ctx) + if seqErr != nil { + return -1, seqErr + } + + data := EventTypeCosmos{ + EventType: eventType, + EventTypeNID: eventTypeNIDSeq, + } + + dbData, err := insertEventTypeCore(s, ctx, data) + if err != nil { - return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) + return 0, err } - if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil { - return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err) + + return types.EventTypeNID(dbData.EventTypeNID), err +} + +func insertEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType EventTypeCosmos) (*EventTypeCosmos, error) { + // INSERT INTO roomserver_event_types (event_type) VALUES ($1) + // ON CONFLICT DO NOTHING; + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + //Unique on eventType + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, eventType.EventType) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + var dbData = EventTypeCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + EventType: eventType, } - return types.EventTypeNID(eventTypeNID), err + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + + if err != nil { + dbData, errGet := selectEventTypeCore(s, ctx, eventType.EventType) + if errGet != nil { + return nil, errGet + } + return dbData, nil + } + + return &dbData.EventType, err +} + +func ensureEventTypes(s *eventTypeStatements, ctx context.Context) error { + // INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES + // (1, 'm.room.create'), + // (2, 'm.room.power_levels'), + // (3, 'm.room.join_rules'), + // (4, 'm.room.third_party_invite'), + // (5, 'm.room.member'), + // (6, 'm.room.redaction'), + // (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; + + // (1, 'm.room.create'), + _, err := insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 1, EventType: "m.room.create"}) + if err != nil { + return err + } + // (2, 'm.room.power_levels'), + _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 2, EventType: "m.room.power_levels"}) + if err != nil { + return err + } + // (3, 'm.room.join_rules'), + _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 3, EventType: "m.room.join_rules"}) + if err != nil { + return err + } + // (4, 'm.room.third_party_invite'), + _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 4, EventType: "m.room.third_party_invite"}) + if err != nil { + return err + } + // (5, 'm.room.member'), + _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 5, EventType: "m.room.member"}) + if err != nil { + return err + } + // (6, 'm.room.redaction'), + _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 6, EventType: "m.room.redaction"}) + if err != nil { + return err + } + // (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; + _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 7, EventType: "m.room.history_visibility"}) + if err != nil { + return err + } + return nil +} + +func selectEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType string) (*EventTypeCosmos, error) { + var response EventTypeCosmosData + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, eventType) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + cosmosDocId, + &response) + + if err != nil { + return nil, err + } + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response.EventType, nil } func (s *eventTypeStatements) SelectEventTypeNID( ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { - var eventTypeNID int64 - selectStmt := sqlutil.TxStmt(tx, s.selectEventTypeNIDStmt) - err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) - return types.EventTypeNID(eventTypeNID), err + + // SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1 + + dbData, err := selectEventTypeCore(s, ctx, eventType) + if err != nil { + return -1, err + } + return types.EventTypeNID(dbData.EventTypeNID), nil } func (s *eventTypeStatements) BulkSelectEventTypeNID( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { - /////////////// - iEventTypes := make([]interface{}, len(eventTypes)) - for k, v := range eventTypes { - iEventTypes[k] = v - } - selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventTypes)), 1) - selectPrep, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - /////////////// - rows, err := selectPrep.QueryContext(ctx, iEventTypes...) + // SELECT event_type, event_type_nid FROM roomserver_event_types + // WHERE event_type IN ($1) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventTypes, + } + + response, err := queryEventTypes(s, ctx, s.bulkSelectEventTypeNIDStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed") result := make(map[string]types.EventTypeNID, len(eventTypes)) - for rows.Next() { + for _, item := range response { var eventType string var eventTypeNID int64 - if err := rows.Scan(&eventType, &eventTypeNID); err != nil { - return nil, err - } + eventType = item.EventType.EventType + eventTypeNID = item.EventType.EventTypeNID result[eventType] = types.EventTypeNID(eventTypeNID) } return result, nil diff --git a/roomserver/storage/cosmosdb/events_table.go b/roomserver/storage/cosmosdb/events_table.go index d8c83cad1..c05539b5e 100644 --- a/roomserver/storage/cosmosdb/events_table.go +++ b/roomserver/storage/cosmosdb/events_table.go @@ -18,127 +18,307 @@ package cosmosdb import ( "context" "database/sql" - "encoding/json" "fmt" - "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) -const eventsSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_events ( - event_nid INTEGER PRIMARY KEY AUTOINCREMENT, - room_nid INTEGER NOT NULL, - event_type_nid INTEGER NOT NULL, - event_state_key_nid INTEGER NOT NULL, - sent_to_output BOOLEAN NOT NULL DEFAULT FALSE, - state_snapshot_nid INTEGER NOT NULL DEFAULT 0, - depth INTEGER NOT NULL, - event_id TEXT NOT NULL UNIQUE, - reference_sha256 BLOB NOT NULL, - auth_event_nids TEXT NOT NULL DEFAULT '[]', - is_rejected BOOLEAN NOT NULL DEFAULT FALSE - ); -` +// const eventsSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_events ( +// event_nid INTEGER PRIMARY KEY AUTOINCREMENT, +// room_nid INTEGER NOT NULL, +// event_type_nid INTEGER NOT NULL, +// event_state_key_nid INTEGER NOT NULL, +// sent_to_output BOOLEAN NOT NULL DEFAULT FALSE, +// state_snapshot_nid INTEGER NOT NULL DEFAULT 0, +// depth INTEGER NOT NULL, +// event_id TEXT NOT NULL UNIQUE, +// reference_sha256 BLOB NOT NULL, +// auth_event_nids TEXT NOT NULL DEFAULT '[]', +// is_rejected BOOLEAN NOT NULL DEFAULT FALSE +// ); +// ` -const insertEventSQL = ` - INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - ON CONFLICT DO NOTHING; -` +type EventCosmos struct { + EventNID int64 `json:"event_nid"` + RoomNID int64 `json:"room_nid"` + EventTypeNID int64 `json:"event_type_nid"` + EventStateKeyNID int64 `json:"event_state_key_nid"` + SentToOutput bool `json:"sent_to_output"` + StateSnapshotNID int64 `json:"state_snapshot_nid"` + Depth int64 `json:"depth"` + EventId string `json:"event_id"` + ReferenceSha256 []byte `json:"reference_sha256"` + AuthEventNIDs []int64 `json:"auth_event_nids"` + IsRejected bool `json:"is_rejected"` +} -const selectEventSQL = "" + - "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" +type EventCosmosMaxDepth struct { + Max int64 `json:"maxdepth"` +} + +type EventCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Event EventCosmos `json:"mx_roomserver_event"` +} + +// const insertEventSQL = ` +// INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) +// VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +// ON CONFLICT DO NOTHING; +// ` + +// "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" +// const selectEventSQL = "" + +// "select * from c where c._cn = @x1 and c.mx_roomserver_event.event_id = @x2" // Bulk lookup of events by string ID. // Sort by the numeric IDs for event type and state key. // This means we can use binary search to lookup entries by type and state key. +// "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + +// " WHERE event_id IN ($1)" + +// " ORDER BY event_type_nid, event_state_key_nid ASC" const bulkSelectStateEventByIDSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + - " WHERE event_id IN ($1)" + - " ORDER BY event_type_nid, event_state_key_nid ASC" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_id) " + + "order by c.mx_roomserver_event.event_type_nid " + + // Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from + // ", c.mx_roomserver_event.event_state_key_nid " + + "asc" +// "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + +// " WHERE event_id IN ($1)" const bulkSelectStateAtEventByIDSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + - " WHERE event_id IN ($1)" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_id)" +// "UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2" const updateEventStateSQL = "" + - "UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2" + "select * from c where c._cn = @x1 and c.mx_roomserver_event.event_nid = @x2" +// "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" const selectEventSentToOutputSQL = "" + - "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" + "select * from c where c._cn = @x1 and c.mx_roomserver_event.event_nid = @x2" +// "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" const updateEventSentToOutputSQL = "" + - "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" + "select * from c where c._cn = @x1 and c.mx_roomserver_event.event_nid = @x2" +// "SELECT event_id FROM roomserver_events WHERE event_nid = $1" const selectEventIDSQL = "" + - "SELECT event_id FROM roomserver_events WHERE event_nid = $1" + "select * from c where c._cn = @x1 and c.mx_roomserver_event.event_nid = @x2" +// "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + +// " FROM roomserver_events WHERE event_nid IN ($1)" const bulkSelectStateAtEventAndReferenceSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + - " FROM roomserver_events WHERE event_nid IN ($1)" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_nid)" +// "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)" const bulkSelectEventReferenceSQL = "" + - "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_nid)" +// "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" const bulkSelectEventIDSQL = "" + - "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_nid)" +// "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" const bulkSelectEventNIDSQL = "" + - "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_id)" +// "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" const selectMaxEventDepthSQL = "" + - "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" + "select sub.maxinner != null ? sub.maxinner + 1 : 0 as maxdepth from " + + "(select MAX(c.mx_roomserver_event.depth) maxinner from c where c._cn = @x1 " + + " and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_nid)) sub" +// "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)" const selectRoomNIDsForEventNIDsSQL = "" + - "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_nid)" type eventStatements struct { - db *sql.DB - insertEventStmt *sql.Stmt - selectEventStmt *sql.Stmt - bulkSelectStateEventByIDStmt *sql.Stmt - bulkSelectStateAtEventByIDStmt *sql.Stmt - updateEventStateStmt *sql.Stmt - selectEventSentToOutputStmt *sql.Stmt - updateEventSentToOutputStmt *sql.Stmt - selectEventIDStmt *sql.Stmt - bulkSelectStateAtEventAndReferenceStmt *sql.Stmt - bulkSelectEventReferenceStmt *sql.Stmt - bulkSelectEventIDStmt *sql.Stmt - bulkSelectEventNIDStmt *sql.Stmt - //selectRoomNIDsForEventNIDsStmt *sql.Stmt + db *Database + // insertEventStmt *sql.Stmt + // selectEventStmt string + bulkSelectStateEventByIDStmt string + bulkSelectStateAtEventByIDStmt string + updateEventStateStmt string + selectEventSentToOutputStmt string + updateEventSentToOutputStmt string + selectEventIDStmt string + bulkSelectStateAtEventAndReferenceStmt string + bulkSelectEventReferenceStmt string + bulkSelectEventIDStmt string + bulkSelectEventNIDStmt string + // selectRoomNIDsForEventNIDsStmt string + tableName string } -func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { +func NewCosmosDBEventsTable(db *Database) (tables.Events, error) { s := &eventStatements{ db: db, } - _, err := db.Exec(eventsSchema) + // _, err := db.Exec(eventsSchema) + // if err != nil { + // return nil, err + // } + s.tableName = "events" + // return s, shared.StatementList{ + // {&s.insertEventStmt, insertEventSQL}, + // s.selectEventStmt = selectEventSQL + s.bulkSelectStateEventByIDStmt = bulkSelectStateEventByIDSQL + s.bulkSelectStateAtEventByIDStmt = bulkSelectStateAtEventByIDSQL + s.updateEventStateStmt = updateEventStateSQL + s.updateEventSentToOutputStmt = updateEventSentToOutputSQL + s.selectEventSentToOutputStmt = selectEventSentToOutputSQL + s.selectEventIDStmt = selectEventIDSQL + s.bulkSelectStateAtEventAndReferenceStmt = bulkSelectStateAtEventAndReferenceSQL + s.bulkSelectEventReferenceStmt = bulkSelectEventReferenceSQL + s.bulkSelectEventIDStmt = bulkSelectEventIDSQL + s.bulkSelectEventNIDStmt = bulkSelectEventNIDSQL + // }.Prepare(db) + return s, nil +} + +func mapFromEventNIDArray(eventNIDs []types.EventNID) []int64 { + result := []int64{} + for i := 0; i < len(eventNIDs); i++ { + result = append(result, int64(eventNIDs[i])) + } + return result +} + +func queryEvent(s *eventStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []EventCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } + return response, nil +} - return s, shared.StatementList{ - {&s.insertEventStmt, insertEventSQL}, - {&s.selectEventStmt, selectEventSQL}, - {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, - {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, - {&s.updateEventStateStmt, updateEventStateSQL}, - {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, - {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, - {&s.selectEventIDStmt, selectEventIDSQL}, - {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, - {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, - {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, - {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, - //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, - }.Prepare(db) +func getEvent(s *eventStatements, ctx context.Context, pk string, docId string) (*EventCosmosData, error) { + response := EventCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func setEvent(s *eventStatements, ctx context.Context, pk string, event EventCosmosData) (*EventCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, event.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + event.Id, + &event, + optionsReplace) + return &event, ex +} + +func isEventAuthEventNIDsSame( + ids []int64, + authEventNIDs []types.EventNID, +) bool { + if len(ids) != len(authEventNIDs) { + return false + } + for i := 0; i < len(ids); i++ { + if ids[i] != int64(authEventNIDs[i]) { + return false + } + } + return true +} + +func isReferenceSha256Same( + ids []byte, + referenceSHA256 []byte, +) bool { + if len(ids) != len(referenceSHA256) { + return false + } + for i := 0; i < len(ids); i++ { + if ids[i] != referenceSHA256[i] { + return false + } + } + return true +} + +func isEventSame( + event EventCosmos, + roomNID types.RoomNID, + eventTypeNID types.EventTypeNID, + eventStateKeyNID types.EventStateKeyNID, + eventID string, + referenceSHA256 []byte, + authEventNIDs []types.EventNID, + depth int64, + isRejected bool, +) bool { + if event.RoomNID != int64(roomNID) { + return false + } + if event.EventTypeNID != int64(eventTypeNID) { + return false + } + if event.EventStateKeyNID != int64(eventStateKeyNID) { + return false + } + if event.EventId != eventID { + return false + } + if isReferenceSha256Same(event.ReferenceSha256, referenceSHA256) { + return false + } + if !isEventAuthEventNIDsSame(event.AuthEventNIDs, authEventNIDs) { + return false + } + if event.Depth != depth { + return false + } + if event.IsRejected != isRejected { + return false + } + return true } func (s *eventStatements) InsertEvent( @@ -153,32 +333,109 @@ func (s *eventStatements) InsertEvent( depth int64, isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) { - // attempt to insert: the last_row_id is the event NID + + // INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) + // VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + + // event_nid INTEGER PRIMARY KEY AUTOINCREMENT, + // event_id TEXT NOT NULL UNIQUE, + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + docId := eventID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + dbData, errGet := getEvent(s, ctx, pk, cosmosDocId) + + // ON CONFLICT DO NOTHING; + // event_nid INTEGER PRIMARY KEY AUTOINCREMENT, var eventNID int64 - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - result, err := insertStmt.ExecContext( - ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, - ) + if errGet == cosmosdbutil.ErrNoRows { + eventNIDSeq, seqErr := GetNextEventNID(s, ctx) + if seqErr != nil { + return 0, 0, seqErr + } + data := EventCosmos{ + AuthEventNIDs: mapFromEventNIDArray(authEventNIDs), + Depth: depth, + EventId: eventID, + EventNID: eventNIDSeq, + EventStateKeyNID: int64(eventStateKeyNID), + EventTypeNID: int64(eventTypeNID), + IsRejected: isRejected, + ReferenceSha256: referenceSHA256, + RoomNID: int64(roomNID), + } + + dbData = &EventCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Event: data, + } + } else { + modified := !isEventSame( + dbData.Event, + roomNID, + eventTypeNID, + eventStateKeyNID, + eventID, + referenceSHA256, + authEventNIDs, + depth, + isRejected, + ) + if modified == false { + return 0, 0, cosmosdbutil.ErrNoRows + } + dbData.Event.AuthEventNIDs = mapFromEventNIDArray(authEventNIDs) + dbData.Event.Depth = depth + // Dont change the unique keys + // dbData.Event.EventId = eventID + // dbData.Event.EventNID = eventNID + dbData.Event.EventStateKeyNID = int64(eventStateKeyNID) + dbData.Event.EventTypeNID = int64(eventTypeNID) + dbData.Event.IsRejected = isRejected + dbData.Event.ReferenceSha256 = referenceSHA256 + dbData.Event.RoomNID = int64(roomNID) + + dbData.Timestamp = time.Now().Unix() + } + + // ON CONFLICT DO NOTHING; - Do Upsert + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + if err != nil { return 0, 0, err } - modified, err := result.RowsAffected() - if modified == 0 && err == nil { - return 0, 0, sql.ErrNoRows - } - eventNID, err = result.LastInsertId() + + eventNID = dbData.Event.EventNID return types.EventNID(eventNID), 0, err } func (s *eventStatements) SelectEvent( ctx context.Context, txn *sql.Tx, eventID string, ) (types.EventNID, types.StateSnapshotNID, error) { - var eventNID int64 - var stateNID int64 - selectStmt := sqlutil.TxStmt(txn, s.selectEventStmt) - err := selectStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) - return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err + + // "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + docId := eventID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + var response, err = getEvent(s, ctx, pk, cosmosDocId) + if err != nil { + return 0, 0, err + } + + var event = response.Event + + return types.EventNID(event.EventNID), types.StateSnapshotNID(event.StateSnapshotNID), err } // bulkSelectStateEventByID lookups a list of state events by event ID. @@ -186,38 +443,38 @@ func (s *eventStatements) SelectEvent( func (s *eventStatements) BulkSelectStateEventByID( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { - /////////////// - iEventIDs := make([]interface{}, len(eventIDs)) - for k, v := range eventIDs { - iEventIDs[k] = v + if len(eventIDs) == 0 { + return make([]types.StateEntry, len(eventIDs)), nil } - selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) - selectStmt, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - /////////////// - rows, err := selectStmt.QueryContext(ctx, iEventIDs...) + // "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + // " WHERE event_id IN ($1)" + + // " ORDER BY event_type_nid, event_state_key_nid ASC" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventIDs, + } + + response, err := queryEvent(s, ctx, s.bulkSelectStateEventByIDStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event IDs // because of the unique constraint on event IDs. // So we can allocate an array of the correct size now. // We might get fewer results than IDs so we adjust the length of the slice before returning it. - results := make([]types.StateEntry, len(eventIDs)) + results := make([]types.StateEntry, len(response)) i := 0 - for ; rows.Next(); i++ { + for _, item := range response { result := &results[i] - if err = rows.Scan( - &result.EventTypeNID, - &result.EventStateKeyNID, - &result.EventNID, - ); err != nil { - return nil, err - } + result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID) + result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID) + result.EventNID = types.EventNID(item.Event.EventNID) + i++ } if i != len(eventIDs) { // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. @@ -238,40 +495,39 @@ func (s *eventStatements) BulkSelectStateEventByID( func (s *eventStatements) BulkSelectStateAtEventByID( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { - /////////////// - iEventIDs := make([]interface{}, len(eventIDs)) - for k, v := range eventIDs { - iEventIDs[k] = v + if len(eventIDs) == 0 { + return make([]types.StateAtEvent, len(eventIDs)), nil } - selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) - selectStmt, err := s.db.Prepare(selectOrig) + + // "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + + // " WHERE event_id IN ($1)" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventIDs, + } + + response, err := queryEvent(s, ctx, s.bulkSelectStateAtEventByIDStmt, params) + if err != nil { return nil, err } - /////////////// - rows, err := selectStmt.QueryContext(ctx, iEventIDs...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventByID: rows.close() failed") + results := make([]types.StateAtEvent, len(eventIDs)) i := 0 - for ; rows.Next(); i++ { + for _, item := range response { result := &results[i] - if err = rows.Scan( - &result.EventTypeNID, - &result.EventStateKeyNID, - &result.EventNID, - &result.BeforeStateSnapshotNID, - &result.IsRejected, - ); err != nil { - return nil, err - } + result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID) + result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID) + result.EventNID = types.EventNID(item.Event.EventNID) + result.BeforeStateSnapshotNID = types.StateSnapshotNID(item.Event.StateSnapshotNID) + result.IsRejected = item.Event.IsRejected if result.BeforeStateSnapshotNID == 0 { return nil, types.MissingEventError( fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), ) } + i++ } if i != len(eventIDs) { return nil, types.MissingEventError( @@ -284,76 +540,136 @@ func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) UpdateEventState( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt) - _, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) - return err + + // "UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNID, + } + + response, err := queryEvent(s, ctx, s.updateEventStateStmt, params) + + if err != nil { + return err + } + + item := response[0] + item.Event.StateSnapshotNID = int64(stateNID) + + var _, exReplace = setEvent(s, ctx, item.Pk, item) + if exReplace != nil { + return exReplace + } + return exReplace } func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { - selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt) - err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) + + // "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNID, + } + + response, err := queryEvent(s, ctx, s.selectEventSentToOutputStmt, params) + + if err != nil { + return false, err + } + + item := response[0] + sentToOutput = item.Event.SentToOutput return } func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { - updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) - _, err := updateStmt.ExecContext(ctx, int64(eventNID)) - return err + + // "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNID, + } + + response, err := queryEvent(s, ctx, s.updateEventSentToOutputStmt, params) + + if err != nil { + return err + } + + item := response[0] + item.Event.SentToOutput = true + + var _, exReplace = setEvent(s, ctx, item.Pk, item) + if exReplace != nil { + return exReplace + } + return exReplace } func (s *eventStatements) SelectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { - selectStmt := sqlutil.TxStmt(txn, s.selectEventIDStmt) - err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) + + // "SELECT event_id FROM roomserver_events WHERE event_nid = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNID, + } + + response, err := queryEvent(s, ctx, s.selectEventIDStmt, params) + + if err != nil { + return "", err + } + + item := response[0] + eventNID = types.EventNID(item.Event.EventNID) return } func (s *eventStatements) BulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { - /////////////// - iEventNIDs := make([]interface{}, len(eventNIDs)) - for k, v := range eventNIDs { - iEventNIDs[k] = v + if len(eventNIDs) == 0 { + return make([]types.StateAtEventAndReference, len(eventNIDs)), nil } - selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) + + // "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + + // " FROM roomserver_events WHERE event_nid IN ($1)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNIDs, + } + + response, err := queryEvent(s, ctx, s.bulkSelectStateAtEventAndReferenceStmt, params) + if err != nil { return nil, err } - ////////////// - rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) - if err != nil { - return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err) - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") - results := make([]types.StateAtEventAndReference, len(eventNIDs)) + results := make([]types.StateAtEventAndReference, len(response)) i := 0 - for ; rows.Next(); i++ { - var ( - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - stateSnapshotNID int64 - eventID string - eventSHA256 []byte - ) - if err = rows.Scan( - &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, - ); err != nil { - return nil, err - } + for _, item := range response { result := &results[i] - result.EventTypeNID = types.EventTypeNID(eventTypeNID) - result.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - result.EventNID = types.EventNID(eventNID) - result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID) - result.EventID = eventID - result.EventSHA256 = eventSHA256 + result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID) + result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID) + result.EventNID = types.EventNID(item.Event.EventNID) + result.BeforeStateSnapshotNID = types.StateSnapshotNID(item.Event.StateSnapshotNID) + result.EventID = item.Event.EventId + result.EventSHA256 = item.Event.ReferenceSha256 + i++ } if i != len(eventNIDs) { return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) @@ -364,31 +680,30 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( func (s *eventStatements) BulkSelectEventReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.EventReference, error) { - /////////////// - iEventNIDs := make([]interface{}, len(eventNIDs)) - for k, v := range eventNIDs { - iEventNIDs[k] = v + if len(eventNIDs) == 0 { + return make([]gomatrixserverlib.EventReference, len(eventNIDs)), nil } - selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - /////////////// + // "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNIDs, + } + + response, err := queryEvent(s, ctx, s.bulkSelectEventReferenceStmt, params) - selectStmt := sqlutil.TxStmt(txn, selectPrep) - rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed") + results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) i := 0 - for ; rows.Next(); i++ { + for _, item := range response { result := &results[i] - if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil { - return nil, err - } + result.EventID = item.Event.EventId + result.EventSHA256 = item.Event.ReferenceSha256 + i++ } if i != len(eventNIDs) { return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) @@ -398,32 +713,31 @@ func (s *eventStatements) BulkSelectEventReference( // bulkSelectEventID returns a map from numeric event ID to string event ID. func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { - /////////////// - iEventNIDs := make([]interface{}, len(eventNIDs)) - for k, v := range eventNIDs { - iEventNIDs[k] = v - } - selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - selectStmt, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - /////////////// - - rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") results := make(map[types.EventNID]string, len(eventNIDs)) + if len(eventNIDs) == 0 { + return results, nil + } + + // "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNIDs, + } + + response, err := queryEvent(s, ctx, s.bulkSelectEventIDStmt, params) + + if err != nil { + return nil, err + } + i := 0 - for ; rows.Next(); i++ { - var eventNID int64 - var eventID string - if err = rows.Scan(&eventNID, &eventID); err != nil { - return nil, err - } + for _, item := range response { + eventNID := item.Event.EventNID + eventID := item.Event.EventId results[types.EventNID(eventNID)] = eventID + i++ } if i != len(eventNIDs) { return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) @@ -434,82 +748,87 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { - /////////////// - iEventIDs := make([]interface{}, len(eventIDs)) - for k, v := range eventIDs { - iEventIDs[k] = v + if len(eventIDs) == 0 { + return make(map[string]types.EventNID, len(eventIDs)), nil } - selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) - selectStmt, err := s.db.Prepare(selectOrig) + // "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventIDs, + } + + response, err := queryEvent(s, ctx, s.bulkSelectEventNIDStmt, params) + if err != nil { return nil, err } - /////////////// - rows, err := selectStmt.QueryContext(ctx, iEventIDs...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") + results := make(map[string]types.EventNID, len(eventIDs)) - for rows.Next() { - var eventID string - var eventNID int64 - if err = rows.Scan(&eventID, &eventNID); err != nil { - return nil, err - } + for _, item := range response { + eventID := item.Event.EventId + eventNID := item.Event.EventNID results[eventID] = types.EventNID(eventNID) } return results, nil } func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { - var result int64 - iEventIDs := make([]interface{}, len(eventNIDs)) - for i, v := range eventNIDs { - iEventIDs[i] = v + if len(eventNIDs) == 0 { + return 0, nil } - sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) - sqlPrep, err := s.db.Prepare(sqlStr) - if err != nil { - return 0, err + + // "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var response []EventCosmosMaxDepth + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNIDs, } - err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result) + + var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() + var query = cosmosdbapi.GetQuery(selectMaxEventDepthSQL, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err) } - return result, nil + return response[0].Max, nil } func (s *eventStatements) SelectRoomNIDsForEventNIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]types.RoomNID, error) { - sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) - sqlPrep, err := s.db.Prepare(sqlStr) + if len(eventNIDs) == 0 { + return make(map[types.EventNID]types.RoomNID), nil + } + + // "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNIDs, + } + + response, err := queryEvent(s, ctx, selectRoomNIDsForEventNIDsSQL, params) + if err != nil { return nil, err } - iEventNIDs := make([]interface{}, len(eventNIDs)) - for i, v := range eventNIDs { - iEventNIDs[i] = v - } - rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed") + result := make(map[types.EventNID]types.RoomNID) - for rows.Next() { - var eventNID types.EventNID - var roomNID types.RoomNID - if err = rows.Scan(&eventNID, &roomNID); err != nil { - return nil, err - } + for _, item := range response { + roomNID := types.RoomNID(item.Event.RoomNID) + eventNID := types.EventNID(item.Event.EventNID) result[eventNID] = roomNID } return result, nil } - -func eventNIDsAsArray(eventNIDs []types.EventNID) string { - b, _ := json.Marshal(eventNIDs) - return string(b) -} diff --git a/roomserver/storage/cosmosdb/invite_table.go b/roomserver/storage/cosmosdb/invite_table.go index 2e3bf328e..584ec2bed 100644 --- a/roomserver/storage/cosmosdb/invite_table.go +++ b/roomserver/storage/cosmosdb/invite_table.go @@ -18,73 +18,148 @@ package cosmosdb import ( "context" "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -const inviteSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_invites ( - invite_event_id TEXT PRIMARY KEY, - room_nid INTEGER NOT NULL, - target_nid INTEGER NOT NULL, - sender_nid INTEGER NOT NULL DEFAULT 0, - retired BOOLEAN NOT NULL DEFAULT FALSE, - invite_event_json TEXT NOT NULL - ); +// const inviteSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_invites ( +// invite_event_id TEXT PRIMARY KEY, +// room_nid INTEGER NOT NULL, +// target_nid INTEGER NOT NULL, +// sender_nid INTEGER NOT NULL DEFAULT 0, +// retired BOOLEAN NOT NULL DEFAULT FALSE, +// invite_event_json TEXT NOT NULL +// ); - CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid) - WHERE NOT retired; -` -const insertInviteEventSQL = "" + - "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," + - " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" + - " ON CONFLICT DO NOTHING" +// CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid) +// WHERE NOT retired; +// ` +type InviteCosmos struct { + InviteEventID string `json:"invite_event_id"` + RoomNID int64 `json:"room_nid"` + TargetNID int64 `json:"target_nid"` + SenderNID int64 `json:"sender_nid"` + Retired bool `json:"retired"` + InviteEventJSON []byte `json:"invite_event_json"` +} + +type InviteCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Invite InviteCosmos `json:"mx_roomserver_invite"` +} + +// const insertInviteEventSQL = "" + +// "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," + +// " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" + +// " ON CONFLICT DO NOTHING" + +// "SELECT invite_event_id, sender_nid FROM roomserver_invites" + +// " WHERE target_nid = $1 AND room_nid = $2" + +// " AND NOT retired" const selectInviteActiveForUserInRoomSQL = "" + - "SELECT invite_event_id, sender_nid FROM roomserver_invites" + - " WHERE target_nid = $1 AND room_nid = $2" + - " AND NOT retired" + "select * from c where c._cn = @x1 " + + " and c.mx_roomserver_invite.target_nid = @x2" + + " and c.mx_roomserver_invite.room_nid = @x3" + + " and c.mx_roomserver_invite.retired = false" // Retire every active invite for a user in a room. // Ideally we'd know which invite events were retired by a given update so we // wouldn't need to remove every active invite. // However the matrix protocol doesn't give us a way to reliably identify the // invites that were retired, so we are forced to retire all of them. -const updateInviteRetiredSQL = ` - UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired -` +// const updateInviteRetiredSQL = ` +// UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired +// ` -const selectInvitesAboutToRetireSQL = ` -SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired -` +// SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired +const selectInvitesAboutToRetireSQL = "" + + "select * from c where c._cn = @x1 " + + " and c.mx_roomserver_invite.room_nid = @x2" + + " and c.mx_roomserver_invite.target_nid = @x3" + + " and c.mx_roomserver_invite.retired = false" type inviteStatements struct { - db *sql.DB - insertInviteEventStmt *sql.Stmt - selectInviteActiveForUserInRoomStmt *sql.Stmt - updateInviteRetiredStmt *sql.Stmt - selectInvitesAboutToRetireStmt *sql.Stmt + db *Database + // insertInviteEventStmt *sql.Stmt + selectInviteActiveForUserInRoomStmt string + // updateInviteRetiredStmt *sql.Stmt + selectInvitesAboutToRetireStmt string + tableName string } -func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { - s := &inviteStatements{ - db: db, - } - _, err := db.Exec(inviteSchema) +func queryInvite(s *inviteStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []InviteCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } + return response, nil +} - return s, shared.StatementList{ - {&s.insertInviteEventStmt, insertInviteEventSQL}, - {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, - {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, - {&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL}, - }.Prepare(db) +func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*InviteCosmosData, error) { + response := InviteCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func setInvite(s *inviteStatements, ctx context.Context, invite InviteCosmosData) (*InviteCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + invite.Id, + &invite, + optionsReplace) + return &invite, ex +} + +func NewCosmosDBInvitesTable(db *Database) (tables.Invites, error) { + s := &inviteStatements{ + db: db, + } + // return s, shared.StatementList{ + // {&s.insertInviteEventStmt, insertInviteEventSQL}, + s.selectInviteActiveForUserInRoomStmt = selectInviteActiveForUserInRoomSQL + // {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, + s.selectInvitesAboutToRetireStmt = selectInvitesAboutToRetireSQL + // }.Prepare(db) + s.tableName = "invites" + return s, nil } func (s *inviteStatements) InsertInviteEvent( @@ -93,42 +168,84 @@ func (s *inviteStatements) InsertInviteEvent( targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) - result, err := stmt.ExecContext( - ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, - ) + + // "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," + + // " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" + + // " ON CONFLICT DO NOTHING" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + data := InviteCosmos{ + InviteEventID: inviteEventID, + InviteEventJSON: inviteEventJSON, + Retired: false, + RoomNID: int64(roomNID), + SenderNID: int64(senderUserNID), + TargetNID: int64(targetUserNID), + } + + // invite_event_id TEXT PRIMARY KEY, + docId := inviteEventID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + var dbData = InviteCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Invite: data, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + if err != nil { return false, err } - count, err = result.RowsAffected() - if err != nil { - return false, err - } - return count != 0, err + // TODO: Is this important? + // count, err = result.RowsAffected() + // return count != 0, err + return true, nil } func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { + + // "SELECT invite_event_id, sender_nid FROM roomserver_invites" + + // " WHERE target_nid = $1 AND room_nid = $2" + + // " AND NOT retired" + // gather all the event IDs we will retire - stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) - rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": targetUserNID, + "@x3": roomNID, + } + + response, err := queryInvite(s, ctx, s.selectInvitesAboutToRetireStmt, params) + if err != nil { return } - defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed") - for rows.Next() { - var inviteEventID string - if err = rows.Scan(&inviteEventID); err != nil { - return - } - eventIDs = append(eventIDs, inviteEventID) + + for _, item := range response { + eventIDs = append(eventIDs, item.Invite.InviteEventID) + // UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired + + // now retire the invites + item.Invite.Retired = true + _, err = setInvite(s, ctx, item) } - // now retire the invites - stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) - _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) + return } @@ -137,21 +254,27 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom( ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, []string, error) { - rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( - ctx, targetUserNID, roomNID, - ) + + // SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomNID, + "@x3": targetUserNID, + } + + response, err := queryInvite(s, ctx, s.selectInviteActiveForUserInRoomStmt, params) + if err != nil { return nil, nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed") + var result []types.EventStateKeyNID var eventIDs []string - for rows.Next() { - var eventID string - var senderUserNID int64 - if err := rows.Scan(&eventID, &senderUserNID); err != nil { - return nil, nil, err - } + for _, item := range response { + var eventID = item.Invite.InviteEventID + var senderUserNID = item.Invite.SenderNID result = append(result, types.EventStateKeyNID(senderUserNID)) eventIDs = append(eventIDs, eventID) } diff --git a/roomserver/storage/cosmosdb/membership_table.go b/roomserver/storage/cosmosdb/membership_table.go index a318d6caf..308870427 100644 --- a/roomserver/storage/cosmosdb/membership_table.go +++ b/roomserver/storage/cosmosdb/membership_table.go @@ -19,125 +19,233 @@ import ( "context" "database/sql" "fmt" - "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -const membershipSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_membership ( - room_nid INTEGER NOT NULL, - target_nid INTEGER NOT NULL, - sender_nid INTEGER NOT NULL DEFAULT 0, - membership_nid INTEGER NOT NULL DEFAULT 1, - event_nid INTEGER NOT NULL DEFAULT 0, - target_local BOOLEAN NOT NULL DEFAULT false, - forgotten BOOLEAN NOT NULL DEFAULT false, - UNIQUE (room_nid, target_nid) - ); -` +// const membershipSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_membership ( +// room_nid INTEGER NOT NULL, +// target_nid INTEGER NOT NULL, +// sender_nid INTEGER NOT NULL DEFAULT 0, +// membership_nid INTEGER NOT NULL DEFAULT 1, +// event_nid INTEGER NOT NULL DEFAULT 0, +// target_local BOOLEAN NOT NULL DEFAULT false, +// forgotten BOOLEAN NOT NULL DEFAULT false, +// UNIQUE (room_nid, target_nid) +// ); +// ` +type MembershipCosmos struct { + RoomNID int64 `json:"room_nid"` + TargetNID int64 `json:"target_nid"` + SenderNID int64 `json:"sender_nid"` + MembershipNID int64 `json:"membership_nid"` + EventNID int64 `json:"event_nid"` + TargetLocal bool `json:"target_local"` + Forgotten bool `json:"forgotten"` +} + +type MembershipCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Membership MembershipCosmos `json:"mx_roomserver_membership"` +} + +type MembershipJoinedCountCosmosData struct { + TargetNID int64 `json:"target_nid"` + RoomCount int `json:"room_count"` +} + +// "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + +// " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + +// " GROUP BY target_nid" var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + - " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + - " GROUP BY target_nid" + "select c.mx_roomserver_membership.target_nid, count(c.mx_roomserver_membership.room_id) as room_count from c where c._cn = @x1 " + + " and ARRAY_CONTAINS(@x2, c.mx_roomserver_membership.room_id)" + + " and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + " and c.mx_roomserver_membership.forgotten = false" + + " group by c.mx_roomserver_membership.target_nid" // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE -const insertMembershipSQL = "" + - "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + - " VALUES ($1, $2, $3)" + - " ON CONFLICT DO NOTHING" +// const insertMembershipSQL = "" + +// "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + +// " VALUES ($1, $2, $3)" + +// " ON CONFLICT DO NOTHING" -const selectMembershipFromRoomAndTargetSQL = "" + - "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + - " WHERE room_nid = $1 AND target_nid = $2" +// const selectMembershipFromRoomAndTargetSQL = "" + +// "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + +// " WHERE room_nid = $1 AND target_nid = $2" +// "SELECT event_nid FROM roomserver_membership" + +// " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" const selectMembershipsFromRoomAndMembershipSQL = "" + - "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" + "select * from c where c._cn = @x1 " + + " and c.mx_roomserver_membership.room_nid = @x2" + + " and c.mx_roomserver_membership.membership_nid = @x3" + + " and c.mx_roomserver_membership.forgotten = false" +// "SELECT event_nid FROM roomserver_membership" + +// " WHERE room_nid = $1 AND membership_nid = $2" + +// " AND target_local = true and forgotten = false" const selectLocalMembershipsFromRoomAndMembershipSQL = "" + - "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1 AND membership_nid = $2" + - " AND target_local = true and forgotten = false" + "select * from c where c._cn = @x1 " + + " and c.mx_roomserver_membership.room_nid = @x2" + + " and c.mx_roomserver_membership.membership_nid = @x3" + + " and c.mx_roomserver_membership.target_local = true" + + " and c.mx_roomserver_membership.forgotten = false" +// "SELECT event_nid FROM roomserver_membership" + +// " WHERE room_nid = $1 and forgotten = false" const selectMembershipsFromRoomSQL = "" + - "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1 and forgotten = false" + "select * from c where c._cn = @x1 " + + " and c.mx_roomserver_membership.room_nid = @x2" + + " and c.mx_roomserver_membership.forgotten = false" +// "SELECT event_nid FROM roomserver_membership" + +// " WHERE room_nid = $1" + +// " AND target_local = true and forgotten = false" const selectLocalMembershipsFromRoomSQL = "" + - "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1" + - " AND target_local = true and forgotten = false" + "select * from c where c._cn = @x1 " + + " and c.mx_roomserver_membership.room_nid = @x2" + + " and c.mx_roomserver_membership.target_local = true" + + " and c.mx_roomserver_membership.forgotten = false" -const selectMembershipForUpdateSQL = "" + - "SELECT membership_nid FROM roomserver_membership" + - " WHERE room_nid = $1 AND target_nid = $2" +// const selectMembershipForUpdateSQL = "" + +// "SELECT membership_nid FROM roomserver_membership" + +// " WHERE room_nid = $1 AND target_nid = $2" -const updateMembershipSQL = "" + - "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" + - " WHERE room_nid = $5 AND target_nid = $6" +// const updateMembershipSQL = "" + +// "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" + +// " WHERE room_nid = $5 AND target_nid = $6" -const updateMembershipForgetRoom = "" + - "UPDATE roomserver_membership SET forgotten = $1" + - " WHERE room_nid = $2 AND target_nid = $3" +// const updateMembershipForgetRoom = "" + +// "UPDATE roomserver_membership SET forgotten = $1" + +// " WHERE room_nid = $2 AND target_nid = $3" +// "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" const selectRoomsWithMembershipSQL = "" + - "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" + "select * from c where c._cn = @x1 " + + " and c.mx_roomserver_membership.membership_nid = @x2" + + " and c.mx_roomserver_membership.target_nid = true" + + " and c.mx_roomserver_membership.forgotten = false" // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // joined to. Since this information is used to populate the user directory, we will // only return users that the user would ordinarily be able to see anyway. -var selectKnownUsersSQL = "" + - "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + - "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + - " WHERE room_nid IN (" + - " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + - ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" +// var selectKnownUsersSQL = "" + +// "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + +// "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + +// " WHERE room_nid IN (" + +// " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + +// ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" + +var selectKnownUsersSQLRooms = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_membership.room_id)" + +var selectKnownUsersSQLDistinctRoom = "" + + "select distinct top @x4 c.mx_roomserver_membership.room_nid as room_nid from c where c._cn = @x1 " + + "and c.mx_roomserver_membership.target_nid = @x2 " + + "and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " " + + "and contains(c.mx_roomserver_membership.event_state_key, @x3) " type membershipStatements struct { - db *sql.DB - insertMembershipStmt *sql.Stmt - selectMembershipForUpdateStmt *sql.Stmt - selectMembershipFromRoomAndTargetStmt *sql.Stmt - selectMembershipsFromRoomAndMembershipStmt *sql.Stmt - selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt - selectMembershipsFromRoomStmt *sql.Stmt - selectLocalMembershipsFromRoomStmt *sql.Stmt - selectRoomsWithMembershipStmt *sql.Stmt - updateMembershipStmt *sql.Stmt - selectKnownUsersStmt *sql.Stmt - updateMembershipForgetRoomStmt *sql.Stmt + db *Database + // insertMembershipStmt *sql.Stmt + // selectMembershipForUpdateStmt string + // selectMembershipFromRoomAndTargetStmt string + selectMembershipsFromRoomAndMembershipStmt string + selectLocalMembershipsFromRoomAndMembershipStmt string + selectMembershipsFromRoomStmt string + selectLocalMembershipsFromRoomStmt string + selectRoomsWithMembershipStmt string + // updateMembershipStmt *sql.Stmt + // selectKnownUsersStmt string + // updateMembershipForgetRoomStmt *sql.Stmt + tableName string } -func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { +func queryMembership(s *membershipStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MembershipCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []MembershipCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func getMembership(s *membershipStatements, ctx context.Context, pk string, docId string) (*MembershipCosmosData, error) { + response := MembershipCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func setMembership(s *membershipStatements, ctx context.Context, pk string, membership MembershipCosmosData) (*MembershipCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, membership.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + membership.Id, + &membership, + optionsReplace) + return &membership, ex +} + +func NewCosmosDBMembershipTable(db *Database) (tables.Membership, error) { s := &membershipStatements{ db: db, } - return s, shared.StatementList{ - {&s.insertMembershipStmt, insertMembershipSQL}, - {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, - {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, - {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, - {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL}, - {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, - {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, - {&s.updateMembershipStmt, updateMembershipSQL}, - {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, - {&s.selectKnownUsersStmt, selectKnownUsersSQL}, - {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, - }.Prepare(db) -} + // return s, shared.StatementList{ + // {&s.insertMembershipStmt, insertMembershipSQL}, + // s.selectMembershipForUpdateStmt = selectMembershipForUpdateSQL + // s.selectMembershipFromRoomAndTargetStmt = selectMembershipFromRoomAndTargetSQL + s.selectMembershipsFromRoomAndMembershipStmt = selectMembershipsFromRoomAndMembershipSQL + s.selectLocalMembershipsFromRoomAndMembershipStmt = selectLocalMembershipsFromRoomAndMembershipSQL + s.selectMembershipsFromRoomStmt = selectMembershipsFromRoomSQL + s.selectLocalMembershipsFromRoomStmt = selectLocalMembershipsFromRoomSQL + // {&s.updateMembershipStmt, updateMembershipSQL}, + s.selectRoomsWithMembershipStmt = selectRoomsWithMembershipSQL + // {&s.selectKnownUsersStmt, selectKnownUsersSQL}, + // {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, + // }.Prepare(db) -func (s *membershipStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(membershipSchema) - return err + s.tableName = "memberships" + return s, nil } func (s *membershipStatements) InsertMembership( @@ -145,8 +253,45 @@ func (s *membershipStatements) InsertMembership( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { - stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) + + // "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + + // " VALUES ($1, $2, $3)" + + // " ON CONFLICT DO NOTHING" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + // UNIQUE (room_nid, target_nid) + docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := MembershipCosmos{ + EventNID: 0, + Forgotten: false, + MembershipNID: 1, + RoomNID: int64(roomNID), + SenderNID: 0, + TargetLocal: false, + TargetNID: int64(targetUserNID), + } + + var dbData = MembershipCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Membership: data, + } + + // " ON CONFLICT DO NOTHING" + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return err } @@ -154,10 +299,18 @@ func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership tables.MembershipState, err error) { - stmt := sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt) - err = stmt.QueryRowContext( - ctx, roomNID, targetUserNID, - ).Scan(&membership) + + // "SELECT membership_nid FROM roomserver_membership" + + // " WHERE room_nid = $1 AND target_nid = $2" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + response, err := getMembership(s, ctx, pk, cosmosDocId) + if response != nil { + membership = tables.MembershipState(response.Membership.MembershipNID) + } return } @@ -165,9 +318,20 @@ func (s *membershipStatements) SelectMembershipFromRoomAndTarget( ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( - ctx, roomNID, targetUserNID, - ).Scan(&membership, &eventNID, &forgotten) + + // "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + + // " WHERE room_nid = $1 AND target_nid = $2" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + response, err := getMembership(s, ctx, pk, cosmosDocId) + if response != nil { + eventNID = types.EventNID(response.Membership.EventNID) + forgotten = response.Membership.Forgotten + membership = tables.MembershipState(response.Membership.MembershipNID) + } return } @@ -175,24 +339,31 @@ func (s *membershipStatements) SelectMembershipsFromRoom( ctx context.Context, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - var selectStmt *sql.Stmt + var selectStmt string + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomNID, + } if localOnly { + // "SELECT event_nid FROM roomserver_membership" + + // " WHERE room_nid = $1" + + // " AND target_local = true and forgotten = false" selectStmt = s.selectLocalMembershipsFromRoomStmt + } else { + // "SELECT event_nid FROM roomserver_membership" + + // " WHERE room_nid = $1 and forgotten = false" selectStmt = s.selectMembershipsFromRoomStmt } - rows, err := selectStmt.QueryContext(ctx, roomNID) + response, err := queryMembership(s, ctx, selectStmt, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoom: rows.close() failed") - for rows.Next() { - var eNID types.EventNID - if err = rows.Scan(&eNID); err != nil { - return - } - eventNIDs = append(eventNIDs, eNID) + for _, item := range response { + eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID)) } return } @@ -201,24 +372,31 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( ctx context.Context, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - var stmt *sql.Stmt + var stmt string + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomNID, + "@x3": membership, + } if localOnly { + // "SELECT event_nid FROM roomserver_membership" + + // " WHERE room_nid = $1 AND membership_nid = $2" + + // " AND target_local = true and forgotten = false" stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt } else { + // "SELECT event_nid FROM roomserver_membership" + + // " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" stmt = s.selectMembershipsFromRoomAndMembershipStmt } - rows, err := stmt.QueryContext(ctx, roomNID, membership) + response, err := queryMembership(s, ctx, stmt, params) if err != nil { - return + return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoomAndMembership: rows.close() failed") - for rows.Next() { - var eNID types.EventNID - if err = rows.Scan(&eNID); err != nil { - return - } - eventNIDs = append(eventNIDs, eNID) + for _, item := range response { + eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID)) } return } @@ -228,28 +406,48 @@ func (s *membershipStatements) UpdateMembership( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool, ) error { - stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) - _, err := stmt.ExecContext( - ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID, - ) + + // "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" + + // " WHERE room_nid = $5 AND target_nid = $6" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + dbData, err := getMembership(s, ctx, pk, cosmosDocId) + + if err != nil { + return err + } + + dbData.Membership.SenderNID = int64(senderUserNID) + dbData.Membership.MembershipNID = int64(membership) + dbData.Membership.EventNID = int64(eventNID) + dbData.Membership.Forgotten = forgotten + + _, err = setMembership(s, ctx, pk, *dbData) return err } func (s *membershipStatements) SelectRoomsWithMembership( ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + + // "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": membershipState, + "@x3": userID, + } + response, err := queryMembership(s, ctx, s.selectRoomsWithMembershipStmt, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed") var roomNIDs []types.RoomNID - for rows.Next() { - var roomNID types.RoomNID - if err := rows.Scan(&roomNID); err != nil { - return nil, err - } - roomNIDs = append(roomNIDs, roomNID) + for _, item := range response { + roomNIDs = append(roomNIDs, types.RoomNID(item.Membership.RoomNID)) } return roomNIDs, nil } @@ -259,39 +457,136 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, for i, v := range roomNIDs { iRoomNIDs[i] = v } - query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) - rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) + + // "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + + // " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + + // " GROUP BY target_nid" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomNIDs, + } + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []MembershipJoinedCountCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectJoinedUsersSetForRoomsSQL, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + result := make(map[types.EventStateKeyNID]int) - for rows.Next() { - var userID types.EventStateKeyNID - var count int - if err := rows.Scan(&userID, &count); err != nil { - return nil, err - } + for _, item := range response { + userID := types.EventStateKeyNID(item.TargetNID) + count := item.RoomCount result[userID] = count } - return result, rows.Err() + return result, nil } func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + + // " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + // ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": searchString, + "@x4": limit, + } + + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var responseDistinctRoom []MembershipCosmos + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectKnownUsersSQLDistinctRoom, params) // + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &responseDistinctRoom, + optionsQry) + if err != nil { return nil, err } + + rooms := []int64{} + for _, item := range responseDistinctRoom { + rooms = append(rooms, item.RoomNID) + } + + // "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + + // "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + // " WHERE room_nid IN (" + + + params = map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": rooms, + } + + var responseRooms []MembershipCosmos + query = cosmosdbapi.GetQuery(selectKnownUsersSQLRooms, params) + _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &responseRooms, + optionsQry) + + if err != nil { + return nil, err + } + + targetNIDs := []int64{} + for _, item := range responseRooms { + targetNIDs = append(targetNIDs, item.TargetNID) + } + + // HACK: Joined table + var dbCollectionNameEventStateKeys = cosmosdbapi.GetCollectionName(s.db.databaseName, "event_state_keys") + params = map[string]interface{}{ + "@x1": dbCollectionNameEventStateKeys, + "@x2": targetNIDs, + } + + bulkSelectEventStateKeyStmt := "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key_nid)" + + var responseEventStateKeys []EventStateKeysCosmos + query = cosmosdbapi.GetQuery(bulkSelectEventStateKeyStmt, params) + _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &responseEventStateKeys, + optionsQry) + + if err != nil { + return nil, err + } + + // SELECT DISTINCT event_state_key + result := []string{} - defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } + for _, item := range responseEventStateKeys { + userID := item.EventStateKey result = append(result, userID) } - return result, rows.Err() + return result, nil } func (s *membershipStatements) UpdateForgetMembership( @@ -299,8 +594,22 @@ func (s *membershipStatements) UpdateForgetMembership( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool, ) error { - _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( - ctx, forget, roomNID, targetUserNID, - ) + + // "UPDATE roomserver_membership SET forgotten = $1" + + // " WHERE room_nid = $2 AND target_nid = $3" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + dbData, err := getMembership(s, ctx, pk, cosmosDocId) + + if err != nil { + return err + } + + dbData.Membership.Forgotten = forget + + _, err = setMembership(s, ctx, pk, *dbData) return err } diff --git a/roomserver/storage/cosmosdb/previous_events_table.go b/roomserver/storage/cosmosdb/previous_events_table.go index 1062ab1cf..0e92cf99e 100644 --- a/roomserver/storage/cosmosdb/previous_events_table.go +++ b/roomserver/storage/cosmosdb/previous_events_table.go @@ -20,9 +20,12 @@ import ( "database/sql" "fmt" "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -32,59 +35,90 @@ import ( // In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it // seems to care that it's empty and therefore hits a NOT NULL constraint on insert. // We should really work out what the right thing to do here is. -const previousEventSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_previous_events ( - previous_event_id TEXT NOT NULL, - previous_reference_sha256 BLOB, - event_nids TEXT NOT NULL, - UNIQUE (previous_event_id, previous_reference_sha256) - ); -` +// const previousEventSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_previous_events ( +// previous_event_id TEXT NOT NULL, +// previous_reference_sha256 BLOB, +// event_nids TEXT NOT NULL, +// UNIQUE (previous_event_id, previous_reference_sha256) +// ); +// ` + +type PreviousEventCosmos struct { + PreviousEventID string `json:"previous_event_id"` + PreviousReferenceSha256 []byte `json:"previous_reference_sha256"` + EventNIDs string `json:"event_nids"` +} + +type PreviousEventCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + PreviousEvent PreviousEventCosmos `json:"mx_roomserver_previous_event"` +} // Insert an entry into the previous_events table. // If there is already an entry indicating that an event references that previous event then // add the event NID to the list to indicate that this event references that previous event as well. // This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room. // The lock is necessary to avoid data races when checking whether an event is already referenced by another event. -const insertPreviousEventSQL = ` - INSERT OR REPLACE INTO roomserver_previous_events - (previous_event_id, previous_reference_sha256, event_nids) - VALUES ($1, $2, $3) -` +// const insertPreviousEventSQL = ` +// INSERT OR REPLACE INTO roomserver_previous_events +// (previous_event_id, previous_reference_sha256, event_nids) +// VALUES ($1, $2, $3) +// ` -const selectPreviousEventNIDsSQL = ` - SELECT event_nids FROM roomserver_previous_events - WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 -` +// const selectPreviousEventNIDsSQL = ` +// SELECT event_nids FROM roomserver_previous_events +// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 +// ` // Check if the event is referenced by another event in the table. // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. -const selectPreviousEventExistsSQL = ` - SELECT 1 FROM roomserver_previous_events - WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 -` +// const selectPreviousEventExistsSQL = ` +// SELECT 1 FROM roomserver_previous_events +// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 +// ` type previousEventStatements struct { - db *sql.DB - insertPreviousEventStmt *sql.Stmt - selectPreviousEventNIDsStmt *sql.Stmt - selectPreviousEventExistsStmt *sql.Stmt + db *Database + // insertPreviousEventStmt *sql.Stmt + // selectPreviousEventNIDsStmt *sql.Stmt + // selectPreviousEventExistsStmt *sql.Stmt + tableName string } -func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { +func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*PreviousEventCosmosData, error) { + response := PreviousEventCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func NewCosmosDBPrevEventsTable(db *Database) (tables.PreviousEvents, error) { s := &previousEventStatements{ db: db, } - _, err := db.Exec(previousEventSchema) - if err != nil { - return nil, err - } - return s, shared.StatementList{ - {&s.insertPreviousEventStmt, insertPreviousEventSQL}, - {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL}, - {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, - }.Prepare(db) + // return s, shared.StatementList{ + // {&s.insertPreviousEventStmt, insertPreviousEventSQL}, + // {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL}, + // {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, + // }.Prepare(db) + s.tableName = "previous_events" + return s, nil } func (s *previousEventStatements) InsertPreviousEvent( @@ -94,28 +128,71 @@ func (s *previousEventStatements) InsertPreviousEvent( previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - var eventNIDs string eventNIDAsString := fmt.Sprintf("%d", eventNID) - selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs) - if err != nil && err != sql.ErrNoRows { - return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + // UNIQUE (previous_event_id, previous_reference_sha256) + // TODO: Check value + // docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256) + docId := previousEventID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + // SELECT 1 FROM roomserver_previous_events + // WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 + existing, err := getPreviousEvent(s, ctx, pk, cosmosDocId) + + if err != nil { + if err != cosmosdbutil.ErrNoRows { + return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err) + } } + + var dbData PreviousEventCosmosData + // Doesnt exist, create a new one + if existing == nil { + data := PreviousEventCosmos{ + EventNIDs: "", + PreviousEventID: previousEventID, + PreviousReferenceSha256: previousEventReferenceSHA256, + } + + dbData = PreviousEventCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + PreviousEvent: data, + } + } else { + dbData = *existing + } + var nids []string - if eventNIDs != "" { - nids = strings.Split(eventNIDs, ",") + if dbData.PreviousEvent.EventNIDs != "" { + nids = strings.Split(dbData.PreviousEvent.EventNIDs, ",") for _, nid := range nids { if nid == eventNIDAsString { return nil } } - eventNIDs = strings.Join(append(nids, eventNIDAsString), ",") + dbData.PreviousEvent.EventNIDs = strings.Join(append(nids, eventNIDAsString), ",") } else { - eventNIDs = eventNIDAsString + dbData.PreviousEvent.EventNIDs = eventNIDAsString } - insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) - _, err = insertStmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, eventNIDs, + + // INSERT OR REPLACE INTO roomserver_previous_events + // (previous_event_id, previous_reference_sha256, event_nids) + // VALUES ($1, $2, $3) + + var optionsReplace = cosmosdbapi.GetUpsertDocumentOptions(pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + optionsReplace, ) return err } @@ -125,7 +202,24 @@ func (s *previousEventStatements) InsertPreviousEvent( func (s *previousEventStatements) SelectPreviousEventExists( ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ) error { - var ok int64 - stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + // UNIQUE (previous_event_id, previous_reference_sha256) + // TODO: Check value + // docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256) + docId := eventID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, string(docId)) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + // SELECT 1 FROM roomserver_previous_events + // WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 + dbData, err := getPreviousEvent(s, ctx, pk, cosmosDocId) + if err != nil { + return err + } + + if dbData == nil { + return cosmosdbutil.ErrNoRows + } + return nil } diff --git a/roomserver/storage/cosmosdb/published_table.go b/roomserver/storage/cosmosdb/published_table.go index 0de948aa6..0d26faeeb 100644 --- a/roomserver/storage/cosmosdb/published_table.go +++ b/roomserver/storage/cosmosdb/published_table.go @@ -17,89 +17,199 @@ package cosmosdb import ( "context" "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) -const publishedSchema = ` --- Stores which rooms are published in the room directory -CREATE TABLE IF NOT EXISTS roomserver_published ( - -- The room ID of the room - room_id TEXT NOT NULL PRIMARY KEY, - -- Whether it is published or not - published BOOLEAN NOT NULL DEFAULT false -); -` +// const publishedSchema = ` +// -- Stores which rooms are published in the room directory +// CREATE TABLE IF NOT EXISTS roomserver_published ( +// -- The room ID of the room +// room_id TEXT NOT NULL PRIMARY KEY, +// -- Whether it is published or not +// published BOOLEAN NOT NULL DEFAULT false +// ); +// ` -const upsertPublishedSQL = "" + - "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)" - -const selectAllPublishedSQL = "" + - "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" - -const selectPublishedSQL = "" + - "SELECT published FROM roomserver_published WHERE room_id = $1" - -type publishedStatements struct { - db *sql.DB - upsertPublishedStmt *sql.Stmt - selectAllPublishedStmt *sql.Stmt - selectPublishedStmt *sql.Stmt +type PublishCosmos struct { + RoomID string `json:"room_id"` + Published bool `json:"published"` } -func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { - s := &publishedStatements{ - db: db, - } - _, err := db.Exec(publishedSchema) +type PublishCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Publish PublishCosmos `json:"mx_roomserver_publish"` +} + +// const upsertPublishedSQL = "" + +// "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)" + +// "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" +const selectAllPublishedSQL = "" + + "select * from c where c._cn = @x1 " + + " and c.mx_roomserver_publish.published = @x2" + + " order by c.mx_roomserver_publish.room_id asc" + +// const selectPublishedSQL = "" + +// "SELECT published FROM roomserver_published WHERE room_id = $1" + +type publishedStatements struct { + db *Database + // upsertPublishedStmt *sql.Stmt + selectAllPublishedStmt string + // selectPublishedStmt *sql.Stmt + tableName string +} + +func queryPublish(s *publishedStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PublishCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []PublishCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - return s, shared.StatementList{ - {&s.upsertPublishedStmt, upsertPublishedSQL}, - {&s.selectAllPublishedStmt, selectAllPublishedSQL}, - {&s.selectPublishedStmt, selectPublishedSQL}, - }.Prepare(db) + return response, nil +} + +func getPublish(s *publishedStatements, ctx context.Context, pk string, docId string) (*PublishCosmosData, error) { + response := PublishCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func setPublish(s *publishedStatements, ctx context.Context, pk string, publish PublishCosmosData) (*PublishCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, publish.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + publish.Id, + &publish, + optionsReplace) + return &publish, ex +} + +func NewCosmosDBPublishedTable(db *Database) (tables.Published, error) { + s := &publishedStatements{ + db: db, + } + // _, err := db.Exec(publishedSchema) + // if err != nil { + // return nil, err + // } + // return s, shared.StatementList{ + // {&s.upsertPublishedStmt, upsertPublishedSQL}, + s.selectAllPublishedStmt = selectAllPublishedSQL + // {&s.selectPublishedStmt, selectPublishedSQL}, + // }.Prepare(db) + s.tableName = "published" + return s, nil } func (s *publishedStatements) UpsertRoomPublished( ctx context.Context, txn *sql.Tx, roomID string, published bool, ) error { - stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) - _, err := stmt.ExecContext(ctx, roomID, published) + + // "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // room_id TEXT NOT NULL PRIMARY KEY, + docId := roomID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := PublishCosmos{ + RoomID: roomID, + Published: false, + } + + var dbData = PublishCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Publish: data, + } + + // "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)" + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return err } func (s *publishedStatements) SelectPublishedFromRoomID( ctx context.Context, roomID string, ) (published bool, err error) { - err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) - if err == sql.ErrNoRows { - return false, nil + + // "SELECT published FROM roomserver_published WHERE room_id = $1" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // room_id TEXT NOT NULL PRIMARY KEY, + docId := roomID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + response, err := getPublish(s, ctx, pk, cosmosDocId) + if err != nil { + return false, err } - return + return response.Publish.Published, nil } func (s *publishedStatements) SelectAllPublishedRooms( ctx context.Context, published bool, ) ([]string, error) { - rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + + // "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": published, + } + + response, err := queryPublish(s, ctx, s.selectAllPublishedStmt, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectAllPublishedStmt: rows.close() failed") var roomIDs []string - for rows.Next() { - var roomID string - if err = rows.Scan(&roomID); err != nil { - return nil, err - } - - roomIDs = append(roomIDs, roomID) + for _, item := range response { + roomIDs = append(roomIDs, item.Publish.RoomID) } - return roomIDs, rows.Err() + return roomIDs, nil } diff --git a/roomserver/storage/cosmosdb/redactions_table.go b/roomserver/storage/cosmosdb/redactions_table.go index 0d2ee27eb..91ed95d2d 100644 --- a/roomserver/storage/cosmosdb/redactions_table.go +++ b/roomserver/storage/cosmosdb/redactions_table.go @@ -17,84 +17,207 @@ package cosmosdb import ( "context" "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) -const redactionsSchema = ` --- Stores information about the redacted state of events. --- We need to track redactions rather than blindly updating the event JSON table on receipt of a redaction --- because we might receive the redaction BEFORE we receive the event which it redacts (think backfill). -CREATE TABLE IF NOT EXISTS roomserver_redactions ( - redaction_event_id TEXT PRIMARY KEY, - redacts_event_id TEXT NOT NULL, - -- Initially FALSE, set to TRUE when the redaction has been validated according to rooms v3+ spec - -- https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events - validated BOOLEAN NOT NULL -); -` +// const redactionsSchema = ` +// -- Stores information about the redacted state of events. +// -- We need to track redactions rather than blindly updating the event JSON table on receipt of a redaction +// -- because we might receive the redaction BEFORE we receive the event which it redacts (think backfill). +// CREATE TABLE IF NOT EXISTS roomserver_redactions ( +// redaction_event_id TEXT PRIMARY KEY, +// redacts_event_id TEXT NOT NULL, +// -- Initially FALSE, set to TRUE when the redaction has been validated according to rooms v3+ spec +// -- https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events +// validated BOOLEAN NOT NULL +// ); +// ` -const insertRedactionSQL = "" + - "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" + - " VALUES ($1, $2, $3)" - -const selectRedactionInfoByRedactionEventIDSQL = "" + - "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + - " WHERE redaction_event_id = $1" - -const selectRedactionInfoByEventBeingRedactedSQL = "" + - "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + - " WHERE redacts_event_id = $1" - -const markRedactionValidatedSQL = "" + - " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1" - -type redactionStatements struct { - db *sql.DB - insertRedactionStmt *sql.Stmt - selectRedactionInfoByRedactionEventIDStmt *sql.Stmt - selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt - markRedactionValidatedStmt *sql.Stmt +type RedactionCosmos struct { + RedactionEventID string `json:"redaction_event_id"` + RedactsEventID string `json:"redacts_event_id"` + Validated bool `json:"validated"` } -func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { - s := &redactionStatements{ - db: db, - } - _, err := db.Exec(redactionsSchema) +type RedactionCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Redaction RedactionCosmos `json:"mx_roomserver_redaction"` +} + +// const insertRedactionSQL = "" + +// "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" + +// " VALUES ($1, $2, $3)" + +// const selectRedactionInfoByRedactionEventIDSQL = "" + +// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + +// " WHERE redaction_event_id = $1" + +// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + +// " WHERE redacts_event_id = $1" +const selectRedactionInfoByEventBeingRedactedSQL = "" + + "select * from c where c._cn = @x1 " + + " and c.mx_roomserver_redaction.redacts_event_id = @x2" + +// const markRedactionValidatedSQL = "" + +// " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1" + +type redactionStatements struct { + db *Database + // insertRedactionStmt *sql.Stmt + // selectRedactionInfoByRedactionEventIDStmt *sql.Stmt + selectRedactionInfoByEventBeingRedactedStmt string + // markRedactionValidatedStmt *sql.Stmt + tableName string +} + +func queryRedaction(s *redactionStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RedactionCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []RedactionCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } + return response, nil +} - return s, shared.StatementList{ - {&s.insertRedactionStmt, insertRedactionSQL}, - {&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL}, - {&s.selectRedactionInfoByEventBeingRedactedStmt, selectRedactionInfoByEventBeingRedactedSQL}, - {&s.markRedactionValidatedStmt, markRedactionValidatedSQL}, - }.Prepare(db) +func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId string) (*RedactionCosmosData, error) { + response := RedactionCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func setRedaction(s *redactionStatements, ctx context.Context, pk string, redaction RedactionCosmosData) (*RedactionCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, redaction.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + redaction.Id, + &redaction, + optionsReplace) + return &redaction, ex +} + +func NewCosmosDBRedactionsTable(db *Database) (tables.Redactions, error) { + s := &redactionStatements{ + db: db, + } + + // return s, shared.StatementList{ + // {&s.insertRedactionStmt, insertRedactionSQL}, + // {&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL}, + s.selectRedactionInfoByEventBeingRedactedStmt = selectRedactionInfoByEventBeingRedactedSQL + // {&s.markRedactionValidatedStmt, markRedactionValidatedSQL}, + // }.Prepare(db) + s.tableName = "redactions" + return s, nil } func (s *redactionStatements) InsertRedaction( ctx context.Context, txn *sql.Tx, info tables.RedactionInfo, ) error { - stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) - _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) - return err + + // "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" + + // " VALUES ($1, $2, $3)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // redaction_event_id TEXT PRIMARY KEY, + docId := info.RedactionEventID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := RedactionCosmos{ + RedactionEventID: info.RedactionEventID, + RedactsEventID: info.RedactsEventID, + Validated: info.Validated, + } + + var dbData = RedactionCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Redaction: data, + } + + // "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + + // TODO: Just forDebug - Remove exception + if err != nil { + return err + } + + //Ignore Error + return nil } func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( ctx context.Context, txn *sql.Tx, redactionEventID string, ) (info *tables.RedactionInfo, err error) { info = &tables.RedactionInfo{} - stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByRedactionEventIDStmt) - err = stmt.QueryRowContext(ctx, redactionEventID).Scan( - &info.RedactionEventID, &info.RedactsEventID, &info.Validated, - ) - if err == sql.ErrNoRows { + + // "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + + // " WHERE redaction_event_id = $1" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // redaction_event_id TEXT PRIMARY KEY, + docId := redactionEventID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + response, err := getRedaction(s, ctx, pk, cosmosDocId) + if err != nil { + info = nil + err = err + return + } + + if response == nil { info = nil err = nil + return + } + info = &tables.RedactionInfo{ + RedactionEventID: response.Redaction.RedactionEventID, + RedactsEventID: response.Redaction.RedactsEventID, + Validated: response.Redaction.Validated, } return } @@ -102,14 +225,31 @@ func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted( ctx context.Context, txn *sql.Tx, eventID string, ) (info *tables.RedactionInfo, err error) { - info = &tables.RedactionInfo{} - stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByEventBeingRedactedStmt) - err = stmt.QueryRowContext(ctx, eventID).Scan( - &info.RedactionEventID, &info.RedactsEventID, &info.Validated, - ) - if err == sql.ErrNoRows { + + // "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + + // " WHERE redacts_event_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventID, + } + response, err := queryRedaction(s, ctx, s.selectRedactionInfoByEventBeingRedactedStmt, params) + + if err != nil { + return nil, err + } + + if len(response) == 0 { info = nil err = nil + return + } + // TODO: Check this is ok to return the 1st one + *info = tables.RedactionInfo{ + RedactionEventID: response[0].Redaction.RedactionEventID, + RedactsEventID: response[0].Redaction.RedactsEventID, + Validated: response[0].Redaction.Validated, } return } @@ -117,7 +257,22 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted( func (s *redactionStatements) MarkRedactionValidated( ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool, ) error { - stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) - _, err := stmt.ExecContext(ctx, redactionEventID, validated) + + // " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // redaction_event_id TEXT PRIMARY KEY, + docId := redactionEventID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + response, err := getRedaction(s, ctx, pk, cosmosDocId) + if err != nil { + return err + } + + response.Redaction.Validated = validated + + _, err = setRedaction(s, ctx, pk, *response) return err } diff --git a/roomserver/storage/cosmosdb/room_aliases_table.go b/roomserver/storage/cosmosdb/room_aliases_table.go index 3592257ba..e79f0ca14 100644 --- a/roomserver/storage/cosmosdb/room_aliases_table.go +++ b/roomserver/storage/cosmosdb/room_aliases_table.go @@ -18,84 +18,185 @@ package cosmosdb import ( "context" "database/sql" + "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) -const roomAliasesSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_room_aliases ( - alias TEXT NOT NULL PRIMARY KEY, - room_id TEXT NOT NULL, - creator_id TEXT NOT NULL - ); +// const roomAliasesSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_room_aliases ( +// alias TEXT NOT NULL PRIMARY KEY, +// room_id TEXT NOT NULL, +// creator_id TEXT NOT NULL +// ); - CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id); -` +// CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id); +// ` -const insertRoomAliasSQL = ` - INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3) -` - -const selectRoomIDFromAliasSQL = ` - SELECT room_id FROM roomserver_room_aliases WHERE alias = $1 -` - -const selectAliasesFromRoomIDSQL = ` - SELECT alias FROM roomserver_room_aliases WHERE room_id = $1 -` - -const selectCreatorIDFromAliasSQL = ` - SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1 -` - -const deleteRoomAliasSQL = ` - DELETE FROM roomserver_room_aliases WHERE alias = $1 -` - -type roomAliasesStatements struct { - db *sql.DB - insertRoomAliasStmt *sql.Stmt - selectRoomIDFromAliasStmt *sql.Stmt - selectAliasesFromRoomIDStmt *sql.Stmt - selectCreatorIDFromAliasStmt *sql.Stmt - deleteRoomAliasStmt *sql.Stmt +type RoomAliasCosmos struct { + Alias string `json:"alias"` + RoomID string `json:"room_id"` + CreatorID string `json:"creator_id"` } -func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { - s := &roomAliasesStatements{ - db: db, - } - _, err := db.Exec(roomAliasesSchema) +type RoomAliasCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + RoomAlias RoomAliasCosmos `json:"mx_roomserver_room_alias"` +} + +// const insertRoomAliasSQL = ` +// INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3) +// ` + +// const selectRoomIDFromAliasSQL = ` +// SELECT room_id FROM roomserver_room_aliases WHERE alias = $1 +// ` + +// SELECT alias FROM roomserver_room_aliases WHERE room_id = $1 +const selectAliasesFromRoomIDSQL = ` + select * from c where c._cn = @x1 and c.mx_roomserver_room_alias.room_id = @x2 +` + +// const selectCreatorIDFromAliasSQL = ` +// SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1 +// ` + +// const deleteRoomAliasSQL = ` +// DELETE FROM roomserver_room_aliases WHERE alias = $1 +// ` + +type roomAliasesStatements struct { + db *Database + // insertRoomAliasStmt *sql.Stmt + // selectRoomIDFromAliasStmt string + selectAliasesFromRoomIDStmt string + // selectCreatorIDFromAliasStmt string + // deleteRoomAliasStmt *sql.Stmt + tableName string +} + +func queryRoomAlias(s *roomAliasesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomAliasCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []RoomAliasCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - return s, shared.StatementList{ - {&s.insertRoomAliasStmt, insertRoomAliasSQL}, - {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, - {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, - {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, - {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, - }.Prepare(db) + return response, nil +} + +func getRoomAlias(s *roomAliasesStatements, ctx context.Context, pk string, docId string) (*RoomAliasCosmosData, error) { + response := RoomAliasCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func NewCosmosDBRoomAliasesTable(db *Database) (tables.RoomAliases, error) { + s := &roomAliasesStatements{ + db: db, + } + // _, err := db.Exec(roomAliasesSchema) + // if err != nil { + // return nil, err + // } + // return s, shared.StatementList{ + // {&s.insertRoomAliasStmt, insertRoomAliasSQL}, + // {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, + s.selectAliasesFromRoomIDStmt = selectAliasesFromRoomIDSQL + // {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, + // {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, + // }.Prepare(db) + s.tableName = "room_aliases" + return s, nil } func (s *roomAliasesStatements) InsertRoomAlias( ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, ) error { - stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt) - _, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID) + + // INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3) + data := RoomAliasCosmos{ + Alias: alias, + CreatorID: creatorUserID, + RoomID: roomID, + } + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + // alias TEXT NOT NULL PRIMARY KEY, + docId := alias + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + var dbData = RoomAliasCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + RoomAlias: data, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return err } func (s *roomAliasesStatements) SelectRoomIDFromAlias( ctx context.Context, alias string, ) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) - if err == sql.ErrNoRows { + + // SELECT room_id FROM roomserver_room_aliases WHERE alias = $1 + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + // alias TEXT NOT NULL PRIMARY KEY, + docId := alias + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + response, err := getRoomAlias(s, ctx, pk, cosmosDocId) + + if err != nil { + return "", err + } + + if response == nil { return "", nil } + roomID = response.RoomAlias.RoomID return } @@ -103,20 +204,23 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( ctx context.Context, roomID string, ) (aliases []string, err error) { aliases = []string{} - rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) - if err != nil { - return + + // SELECT alias FROM roomserver_room_aliases WHERE room_id = $1 + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, } - defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") + response, err := queryRoomAlias(s, ctx, s.selectAliasesFromRoomIDStmt, params) - for rows.Next() { - var alias string - if err = rows.Scan(&alias); err != nil { - return - } + if err != nil { + return nil, err + } - aliases = append(aliases, alias) + for _, item := range response { + aliases = append(aliases, item.RoomAlias.Alias) } return @@ -125,17 +229,48 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( func (s *roomAliasesStatements) SelectCreatorIDFromAlias( ctx context.Context, alias string, ) (creatorID string, err error) { - err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) - if err == sql.ErrNoRows { + + // SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1 + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + // alias TEXT NOT NULL PRIMARY KEY, + docId := alias + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + response, err := getRoomAlias(s, ctx, pk, cosmosDocId) + + if err != nil { + return "", err + } + + if response == nil { return "", nil } + creatorID = response.RoomAlias.CreatorID return } func (s *roomAliasesStatements) DeleteRoomAlias( ctx context.Context, txn *sql.Tx, alias string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt) - _, err := stmt.ExecContext(ctx, alias) + + // DELETE FROM roomserver_room_aliases WHERE alias = $1 + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + docId := alias + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var options = cosmosdbapi.GetDeleteDocumentOptions(pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + cosmosDocId, + options) + + if err != nil { + return err + } return err } diff --git a/roomserver/storage/cosmosdb/room_seq.go b/roomserver/storage/cosmosdb/room_seq.go new file mode 100644 index 000000000..97749fdc1 --- /dev/null +++ b/roomserver/storage/cosmosdb/room_seq.go @@ -0,0 +1,12 @@ +package cosmosdb + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" +) + +func GetNextRoomNID(s *roomStatements, ctx context.Context) (int64, error) { + const docId = "roomnid_seq" + return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) +} diff --git a/roomserver/storage/cosmosdb/rooms_table.go b/roomserver/storage/cosmosdb/rooms_table.go index 8570a9723..3348618ef 100644 --- a/roomserver/storage/cosmosdb/rooms_table.go +++ b/roomserver/storage/cosmosdb/rooms_table.go @@ -18,128 +18,227 @@ package cosmosdb import ( "context" "database/sql" - "encoding/json" "fmt" - "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) -const roomsSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_rooms ( - room_nid INTEGER PRIMARY KEY AUTOINCREMENT, - room_id TEXT NOT NULL UNIQUE, - latest_event_nids TEXT NOT NULL DEFAULT '[]', - last_event_sent_nid INTEGER NOT NULL DEFAULT 0, - state_snapshot_nid INTEGER NOT NULL DEFAULT 0, - room_version TEXT NOT NULL - ); -` +// const roomsSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_rooms ( +// room_nid INTEGER PRIMARY KEY AUTOINCREMENT, +// room_id TEXT NOT NULL UNIQUE, +// latest_event_nids TEXT NOT NULL DEFAULT '[]', +// last_event_sent_nid INTEGER NOT NULL DEFAULT 0, +// state_snapshot_nid INTEGER NOT NULL DEFAULT 0, +// room_version TEXT NOT NULL +// ); +// ` -// Same as insertEventTypeNIDSQL -const insertRoomNIDSQL = ` - INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2) - ON CONFLICT DO NOTHING; -` - -const selectRoomNIDSQL = "" + - "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" - -const selectLatestEventNIDsSQL = "" + - "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" - -const selectLatestEventNIDsForUpdateSQL = "" + - "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" - -const updateLatestEventNIDsSQL = "" + - "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" - -const selectRoomVersionsForRoomNIDsSQL = "" + - "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)" - -const selectRoomInfoSQL = "" + - "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" - -const selectRoomIDsSQL = "" + - "SELECT room_id FROM roomserver_rooms" - -const bulkSelectRoomIDsSQL = "" + - "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" - -const bulkSelectRoomNIDsSQL = "" + - "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" - -type roomStatements struct { - db *sql.DB - insertRoomNIDStmt *sql.Stmt - selectRoomNIDStmt *sql.Stmt - selectLatestEventNIDsStmt *sql.Stmt - selectLatestEventNIDsForUpdateStmt *sql.Stmt - updateLatestEventNIDsStmt *sql.Stmt - //selectRoomVersionForRoomNIDStmt *sql.Stmt - selectRoomInfoStmt *sql.Stmt - selectRoomIDsStmt *sql.Stmt +type RoomCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Room RoomCosmos `json:"mx_roomserver_room"` } -func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { +type RoomCosmos struct { + RoomNID int64 `json:"room_nid"` + RoomID string `json:"room_id"` + LatestEventNIDs []int64 `json:"latest_event_nids"` + LastEventSentNID int64 `json:"last_event_sent_nid"` + StateSnapshotNID int64 `json:"state_snapshot_nid"` + RoomVersion string `json:"room_version"` +} + +// Same as insertEventTypeNIDSQL +// const insertRoomNIDSQL = ` +// INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2) +// ON CONFLICT DO NOTHING; +// ` + +// "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" +// const selectRoomNIDSQL = "" + +// "select * from c where c._cn = @x1 and c.mx_roomserver_room.room_nid = @x1" + +// "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" +const selectLatestEventNIDsSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_roomserver_room.room_nid = @x2" + +// "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" +const selectLatestEventNIDsForUpdateSQL = "" + + "select * from c where c._cn = @x1 " + + " and c.mx_roomserver_room.room_nid = @x2" + +// const updateLatestEventNIDsSQL = "" + +// "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" + +// "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)" +const selectRoomVersionsForRoomNIDsSQL = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)" + +// "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" +// const selectRoomInfoSQL = "" + +// "select * from c where c._cn = @x1 and c.mx_roomserver_room.room_id = @x2" + +// "SELECT room_id FROM roomserver_rooms" +const selectRoomIDsSQL = "" + + "select * from c where c._cn = @x1" + +// "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" +const bulkSelectRoomIDsSQL = "" + + "select * from c where c._cn = @x1 " + + " and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)" + +// "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" +const bulkSelectRoomNIDsSQL = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)" + +type roomStatements struct { + db *Database + // insertRoomNIDStmt *sql.Stmt + // selectRoomNIDStmt string + selectLatestEventNIDsStmt string + selectLatestEventNIDsForUpdateStmt string + updateLatestEventNIDsStmt string + selectRoomVersionForRoomNIDStmt string + // selectRoomInfoStmt *sql.Stmt + selectRoomIDsStmt string + tableName string +} + +func NewCosmosDBRoomsTable(db *Database) (tables.Rooms, error) { s := &roomStatements{ db: db, } - _, err := db.Exec(roomsSchema) + // return s, shared.StatementList{ + // {&s.insertRoomNIDStmt, insertRoomNIDSQL}, + // {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + s.selectLatestEventNIDsStmt = selectLatestEventNIDsSQL + s.selectLatestEventNIDsForUpdateStmt = selectLatestEventNIDsForUpdateSQL + // {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, + //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL}, + // {&s.selectRoomInfoStmt, selectRoomInfoSQL}, + s.selectRoomIDsStmt = selectRoomIDsSQL + // }.Prepare(db) + s.tableName = "rooms" + return s, nil +} + +func mapToRoomEventNIDArray(eventNIDs []int64) []types.EventNID { + result := []types.EventNID{} + for i := 0; i < len(eventNIDs); i++ { + result = append(result, types.EventNID(eventNIDs[i])) + } + return result +} + +func queryRoom(s *roomStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []RoomCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - return s, shared.StatementList{ - {&s.insertRoomNIDStmt, insertRoomNIDSQL}, - {&s.selectRoomNIDStmt, selectRoomNIDSQL}, - {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, - {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, - {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, - //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL}, - {&s.selectRoomInfoStmt, selectRoomInfoSQL}, - {&s.selectRoomIDsStmt, selectRoomIDsSQL}, - }.Prepare(db) + return response, nil +} + +func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (*RoomCosmosData, error) { + response := RoomCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func setRoom(s *roomStatements, ctx context.Context, pk string, room RoomCosmosData) (*RoomCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, room.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + room.Id, + &room, + optionsReplace) + return &room, ex } func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { - rows, err := s.selectRoomIDsStmt.QueryContext(ctx) + + // "SELECT room_id FROM roomserver_rooms" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + response, err := queryRoom(s, ctx, s.selectRoomIDsStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") + var roomIDs []string - for rows.Next() { - var roomID string - if err = rows.Scan(&roomID); err != nil { - return nil, err - } - roomIDs = append(roomIDs, roomID) + for _, item := range response { + roomIDs = append(roomIDs, item.Room.RoomID) } return roomIDs, nil } func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { - var info types.RoomInfo - var latestNIDsJSON string - err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( - &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON, - ) + info := types.RoomInfo{} + + // "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // room_id TEXT NOT NULL UNIQUE, + docId := roomID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + room, err := getRoom(s, ctx, pk, cosmosDocId) + if err != nil { - if err == sql.ErrNoRows { + if err == cosmosdbutil.ErrNoRows { return nil, nil } return nil, err } - var latestNIDs []int64 - if err = json.Unmarshal([]byte(latestNIDsJSON), &latestNIDs); err != nil { - return nil, err - } - info.IsStub = len(latestNIDs) == 0 + + info.RoomVersion = gomatrixserverlib.RoomVersion(room.Room.RoomVersion) + info.RoomNID = types.RoomNID(room.Room.RoomNID) + info.StateSnapshotNID = types.StateSnapshotNID(room.Room.StateSnapshotNID) + info.IsStub = len(room.Room.LatestEventNIDs) == 0 return &info, err } @@ -147,60 +246,135 @@ func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (roomNID types.RoomNID, err error) { - insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) - _, err = insertStmt.ExecContext(ctx, roomID, roomVersion) - if err != nil { - return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + // INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2) + // ON CONFLICT DO NOTHING; + // room_id TEXT NOT NULL UNIQUE, + docId := roomID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + dbData, errGet := getRoom(s, ctx, pk, cosmosDocId) + + if errGet == cosmosdbutil.ErrNoRows { + // room_nid INTEGER PRIMARY KEY AUTOINCREMENT, + roomNIDSeq, seqErr := GetNextRoomNID(s, ctx) + if seqErr != nil { + return 0, seqErr + } + + data := RoomCosmos{ + RoomNID: int64(roomNIDSeq), + RoomID: roomID, + RoomVersion: string(roomVersion), + } + + dbData = &RoomCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Room: data, + } } - roomNID, err = s.SelectRoomNID(ctx, txn, roomID) + + // ON CONFLICT DO NOTHING; - Do Upsert + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + if err != nil { return 0, fmt.Errorf("s.SelectRoomNID: %w", err) } + + roomNID = types.RoomNID(dbData.Room.RoomNID) + return } func (s *roomStatements) SelectRoomNID( ctx context.Context, txn *sql.Tx, roomID string, ) (types.RoomNID, error) { - var roomNID int64 - stmt := sqlutil.TxStmt(txn, s.selectRoomNIDStmt) - err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) - return types.RoomNID(roomNID), err + + // "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // room_id TEXT NOT NULL UNIQUE, + docId := roomID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + room, err := getRoom(s, ctx, pk, cosmosDocId) + + if err != nil { + return 0, err + } + + if room == nil { + return 0, nil + } + return types.RoomNID(room.Room.RoomNID), err } func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { - var eventNIDs []types.EventNID - var nidsJSON string - var stateSnapshotNID int64 - stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt) - err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &stateSnapshotNID) + + // "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomNID, + } + + response, err := queryRoom(s, ctx, s.selectLatestEventNIDsStmt, params) + if err != nil { return nil, 0, err } - if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil { - return nil, 0, err + + // TODO: Check the error handling + if len(response) == 0 { + return nil, 0, cosmosdbutil.ErrNoRows } - return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil + + //Assume 1 per RoomNID + room := response[0] + return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil } func (s *roomStatements) SelectLatestEventsNIDsForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { - var eventNIDs []types.EventNID - var nidsJSON string - var lastEventSentNID int64 - var stateSnapshotNID int64 - stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) - err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &lastEventSentNID, &stateSnapshotNID) + + // "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomNID, + } + + response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params) + if err != nil { return nil, 0, 0, err } - if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil { - return nil, 0, 0, err + + // TODO: Check the error handling + if len(response) == 0 { + return nil, 0, 0, cosmosdbutil.ErrNoRows } - return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil + + //Assume 1 per RoomNID + room := response[0] + return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.EventNID(room.Room.LastEventSentNID), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil } func (s *roomStatements) UpdateLatestEventNIDs( @@ -211,86 +385,113 @@ func (s *roomStatements) UpdateLatestEventNIDs( lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) - _, err := stmt.ExecContext( - ctx, - eventNIDsAsArray(eventNIDs), - int64(lastEventSentNID), - int64(stateSnapshotNID), - roomNID, - ) + + // "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomNID, + } + + response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params) + + if err != nil { + return err + } + + // TODO: Check the error handling + if len(response) == 0 { + return cosmosdbutil.ErrNoRows + } + + //Assume 1 per RoomNID + room := response[0] + + room.Room.LatestEventNIDs = mapFromEventNIDArray(eventNIDs) + room.Room.LastEventSentNID = int64(lastEventSentNID) + room.Room.StateSnapshotNID = int64(stateSnapshotNID) + + _, err = setRoom(s, ctx, room.Pk, room) return err } func (s *roomStatements) SelectRoomVersionsForRoomNIDs( ctx context.Context, roomNIDs []types.RoomNID, ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { - sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) - sqlPrep, err := s.db.Prepare(sqlStr) + if roomNIDs == nil || len(roomNIDs) == 0 { + return make(map[types.RoomNID]gomatrixserverlib.RoomVersion), nil + } + + // "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomNIDs, + } + + response, err := queryRoom(s, ctx, selectRoomVersionsForRoomNIDsSQL, params) + if err != nil { return nil, err } - iRoomNIDs := make([]interface{}, len(roomNIDs)) - for i, v := range roomNIDs { - iRoomNIDs[i] = v - } - rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed") + result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion) - for rows.Next() { - var roomNID types.RoomNID - var roomVersion gomatrixserverlib.RoomVersion - if err = rows.Scan(&roomNID, &roomVersion); err != nil { - return nil, err - } - result[roomNID] = roomVersion + for _, item := range response { + result[types.RoomNID(item.Room.RoomNID)] = gomatrixserverlib.RoomVersion(item.Room.RoomVersion) } return result, nil } func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { - iRoomNIDs := make([]interface{}, len(roomNIDs)) - for i, v := range roomNIDs { - iRoomNIDs[i] = v + if roomNIDs == nil || len(roomNIDs) == 0 { + return []string{}, nil } - sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) - rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + + // "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomNIDs, + } + + response, err := queryRoom(s, ctx, bulkSelectRoomIDsSQL, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") + var roomIDs []string - for rows.Next() { - var roomID string - if err = rows.Scan(&roomID); err != nil { - return nil, err - } - roomIDs = append(roomIDs, roomID) + for _, item := range response { + roomIDs = append(roomIDs, item.Room.RoomID) } return roomIDs, nil } func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { - iRoomIDs := make([]interface{}, len(roomIDs)) - for i, v := range roomIDs { - iRoomIDs[i] = v + if roomIDs == nil || len(roomIDs) == 0 { + return []types.RoomNID{}, nil } - sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) - rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + + // "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomIDs, + } + + response, err := queryRoom(s, ctx, bulkSelectRoomNIDsSQL, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") + var roomNIDs []types.RoomNID - for rows.Next() { - var roomNID types.RoomNID - if err = rows.Scan(&roomNID); err != nil { - return nil, err - } - roomNIDs = append(roomNIDs, roomNID) + for _, item := range response { + roomNIDs = append(roomNIDs, types.RoomNID(item.Room.RoomNID)) } return roomNIDs, nil } diff --git a/roomserver/storage/cosmosdb/state_block_table.go b/roomserver/storage/cosmosdb/state_block_table.go index f0c8169dd..f8cb9f317 100644 --- a/roomserver/storage/cosmosdb/state_block_table.go +++ b/roomserver/storage/cosmosdb/state_block_table.go @@ -20,33 +20,54 @@ import ( "database/sql" "fmt" "sort" - "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) -const stateDataSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_state_block ( - state_block_nid INTEGER NOT NULL, - event_type_nid INTEGER NOT NULL, - event_state_key_nid INTEGER NOT NULL, - event_nid INTEGER NOT NULL, - UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) - ); -` +// const stateDataSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_state_block ( +// state_block_nid INTEGER NOT NULL, +// event_type_nid INTEGER NOT NULL, +// event_state_key_nid INTEGER NOT NULL, +// event_nid INTEGER NOT NULL, +// UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) +// ); +// ` -const insertStateDataSQL = "" + - "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - " VALUES ($1, $2, $3, $4)" +type StateBlockCosmos struct { + StateBlockNID int64 `json:"state_block_nid"` + EventTypeNID int64 `json:"event_type_nid"` + EventStateKeyNID int64 `json:"event_state_key_nid"` + EventNID int64 `json:"event_nid"` +} -const selectNextStateBlockNIDSQL = ` -SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block -` +type StateBlockCosmosMaxNID struct { + Max int64 `json:"maxstateblocknid"` +} + +type StateBlockCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + StateBlock StateBlockCosmos `json:"mx_roomserver_state_block"` +} + +// const insertStateDataSQL = "" + +// "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + +// " VALUES ($1, $2, $3, $4)" + +// SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block +const selectNextStateBlockNIDSQL = "" + + "select sub.maxinner != null ? sub.maxinner + 1 : 1 as maxstateblocknid " + + "from " + + "(select MAX(c.mx_roomserver_state_block.state_block_nid) maxinner from c where c._cn = @x1) as sub" // Bulk state lookup by numeric state block ID. // Sort by the state_block_nid, event_type_nid, event_state_key_nid @@ -54,10 +75,17 @@ SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block // together in the list and those entries will sorted by event_type_nid // and event_state_key_nid. This property makes it easier to merge two // state data blocks together. +// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + +// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + +// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" const bulkSelectStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " + + "order by c.mx_roomserver_state_block.state_block_nid " + + // Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from + // ", c.mx_roomserver_state_block.event_type_nid " + + // ", c.mx_roomserver_state_block.event_state_key_nid " + + " asc" // Bulk state lookup by numeric state block ID. // Filters the rows in each block to the requested types and state keys. @@ -66,35 +94,126 @@ const bulkSelectStateBlockEntriesSQL = "" + // of types and a list state_keys. So we have to filter the result in the // application to restrict it to the list of event types and state keys we // actually wanted. +// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + +// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + +// " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + +// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" const bulkSelectFilteredStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " + + "and ARRAY_CONTAINS(@x3, c.mx_roomserver_state_block.event_type_nid) " + + "and ARRAY_CONTAINS(@x4, c.mx_roomserver_state_block.event_state_key_nid) " + + "order by c.mx_roomserver_state_block.state_block_nid " + + // Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from + // ", c.mx_roomserver_state_block.event_type_nid " + + // ", c.mx_roomserver_state_block.event_state_key_nid " + + "asc" type stateBlockStatements struct { - db *sql.DB - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateBlockEntriesStmt *sql.Stmt - bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt + db *Database + // insertStateDataStmt *sql.Stmt + selectNextStateBlockNIDStmt string + bulkSelectStateBlockEntriesStmt string + bulkSelectFilteredStateBlockEntriesStmt string + tableName string } -func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { - s := &stateBlockStatements{ - db: db, - } - _, err := db.Exec(stateDataSchema) +func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StateBlockCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []StateBlockCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } + return response, nil +} - return s, shared.StatementList{ - {&s.insertStateDataStmt, insertStateDataSQL}, - {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, - {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, - {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, - }.Prepare(db) +func NewCosmosDBStateBlockTable(db *Database) (tables.StateBlock, error) { + s := &stateBlockStatements{ + db: db, + } + + // return s, shared.StatementList{ + // {&s.insertStateDataStmt, insertStateDataSQL}, + s.selectNextStateBlockNIDStmt = selectNextStateBlockNIDSQL + s.bulkSelectStateBlockEntriesStmt = bulkSelectStateBlockEntriesSQL + s.bulkSelectFilteredStateBlockEntriesStmt = bulkSelectFilteredStateBlockEntriesSQL + // }.Prepare(db) + s.tableName = "state_block" + return s, nil +} + +func inertStateBlockCore(s *stateBlockStatements, ctx context.Context, stateBlockNID types.StateBlockNID, entry types.StateEntry) error { + + // "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + + // " VALUES ($1, $2, $3, $4)" + data := StateBlockCosmos{ + EventNID: int64(entry.EventNID), + EventStateKeyNID: int64(entry.EventStateKeyNID), + EventTypeNID: int64(entry.EventTypeNID), + StateBlockNID: int64(stateBlockNID), + } + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + // UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + docId := fmt.Sprintf("%d_%d_%d", data.StateBlockNID, data.EventTypeNID, data.EventStateKeyNID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + var dbData = StateBlockCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + StateBlock: data, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + + return err + +} + +func getNextStateBlockNID(s *stateBlockStatements, ctx context.Context) (int64, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var stateBlockNext []StateBlockCosmosMaxNID + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() + var query = cosmosdbapi.GetQuery(s.selectNextStateBlockNIDStmt, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &stateBlockNext, + optionsQry) + + if err != nil { + return 0, err + } + + return stateBlockNext[0].Max, nil } func (s *stateBlockStatements) BulkInsertStateData( @@ -104,75 +223,64 @@ func (s *stateBlockStatements) BulkInsertStateData( if len(entries) == 0 { return 0, nil } - var stateBlockNID types.StateBlockNID - err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) - if err != nil { - return 0, err + + nextID, errNextID := getNextStateBlockNID(s, ctx) + if errNextID != nil { + return 0, errNextID } + + stateBlockNID := types.StateBlockNID(nextID) + for _, entry := range entries { - _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) + err := inertStateBlockCore(s, ctx, stateBlockNID, entry) if err != nil { return 0, err } } - return stateBlockNID, err + return stateBlockNID, nil } func (s *stateBlockStatements) BulkSelectStateBlockEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { - nids := make([]interface{}, len(stateBlockNIDs)) - for k, v := range stateBlockNIDs { - nids[k] = v + + // "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + // " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + + // " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var response []StateBlockCosmosData + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": stateBlockNIDs, } - selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) - selectStmt, err := s.db.Prepare(selectOrig) + + response, err := queryStateBlock(s, ctx, s.bulkSelectStateBlockEntriesStmt, params) + if err != nil { return nil, err } - rows, err := selectStmt.QueryContext(ctx, nids...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") results := make([]types.StateEntryList, len(stateBlockNIDs)) // current is a pointer to the StateEntryList to append the state entries to. var current *types.StateEntryList i := 0 - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + for _, item := range response { + entry := types.StateEntry{} + entry.EventTypeNID = types.EventTypeNID(item.StateBlock.EventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(item.StateBlock.EventStateKeyNID) + entry.EventNID = types.EventNID(item.StateBlock.EventNID) + + if current == nil || types.StateBlockNID(item.StateBlock.StateBlockNID) != current.StateBlockNID { // The state entry row is for a different state data block to the current one. // So we start appending to the next entry in the list. current = &results[i] - current.StateBlockNID = types.StateBlockNID(stateBlockNID) + current.StateBlockNID = types.StateBlockNID(item.StateBlock.StateBlockNID) i++ } current.StateEntries = append(current.StateEntries, entry) } - if i != len(nids) { - return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids)) + if i != len(stateBlockNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs)) } return results, nil } @@ -187,34 +295,33 @@ func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) + // sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) + // sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) + // sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) - var params []interface{} - for _, val := range stateBlockNIDs { - params = append(params, int64(val)) - } - for _, val := range eventTypeNIDArray { - params = append(params, val) - } - for _, val := range eventStateKeyNIDArray { - params = append(params, val) + // "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + // " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + + // " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + + // " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var response []StateBlockCosmosData + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": stateBlockNIDs, + "@x3": eventTypeNIDArray, + "@x4": eventStateKeyNIDArray, } - rows, err := s.db.QueryContext( - ctx, - sqlStatement, - params..., - ) + response, err := queryStateBlock(s, ctx, s.bulkSelectFilteredStateBlockEntriesStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") var results []types.StateEntryList var current types.StateEntryList - for rows.Next() { + for _, item := range response { var ( stateBlockNID int64 eventTypeNID int64 @@ -222,11 +329,10 @@ func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( eventNID int64 entry types.StateEntry ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } + stateBlockNID = item.StateBlock.StateBlockNID + eventTypeNID = item.StateBlock.EventTypeNID + eventStateKeyNID = item.StateBlock.EventStateKeyNID + eventNID = item.StateBlock.EventNID entry.EventTypeNID = types.EventTypeNID(eventTypeNID) entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) entry.EventNID = types.EventNID(eventNID) diff --git a/roomserver/storage/cosmosdb/state_snapshot_seq.go b/roomserver/storage/cosmosdb/state_snapshot_seq.go new file mode 100644 index 000000000..c6c1b66b5 --- /dev/null +++ b/roomserver/storage/cosmosdb/state_snapshot_seq.go @@ -0,0 +1,12 @@ +package cosmosdb + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" +) + +func GetNextStateSnapshotNID(s *stateSnapshotStatements, ctx context.Context) (int64, error) { + const docId = "statesnapshotnid_seq" + return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) +} diff --git a/roomserver/storage/cosmosdb/state_snapshot_table.go b/roomserver/storage/cosmosdb/state_snapshot_table.go index f75b18755..48dd89a51 100644 --- a/roomserver/storage/cosmosdb/state_snapshot_table.go +++ b/roomserver/storage/cosmosdb/state_snapshot_table.go @@ -18,106 +18,169 @@ package cosmosdb import ( "context" "database/sql" - "encoding/json" "fmt" - "strings" + "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -const stateSnapshotSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( - state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, - room_nid INTEGER NOT NULL, - state_block_nids TEXT NOT NULL DEFAULT '[]' - ); -` +// const stateSnapshotSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( +// state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, +// room_nid INTEGER NOT NULL, +// state_block_nids TEXT NOT NULL DEFAULT '[]' +// ); +// ` -const insertStateSQL = ` - INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) - VALUES ($1, $2);` +type StateSnapshotCosmos struct { + StateSnapshotNID int64 `json:"state_snapshot_nid"` + RoomNID int64 `json:"room_nid"` + StateBlockNIDs []int64 `json:"state_block_nids"` +} + +type StateSnapshotCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + StateSnapshot StateSnapshotCosmos `json:"mx_roomserver_state_snapshot"` +} + +// const insertStateSQL = ` +// INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) +// VALUES ($1, $2);` // Bulk state data NID lookup. // Sorting by state_snapshot_nid means we can use binary search over the result // to lookup the state data NIDs for a state snapshot NID. +// "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + +// " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" const bulkSelectStateBlockNIDsSQL = "" + - "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + - " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_snapshot.state_snapshot_nid) " + + "order by c.mx_roomserver_state_snapshot.state_snapshot_nid asc" type stateSnapshotStatements struct { - db *sql.DB - insertStateStmt *sql.Stmt - bulkSelectStateBlockNIDsStmt *sql.Stmt + db *Database + // insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt string + tableName string } -func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func mapFromStateBlockNIDArray(stateBlockNIDs []types.StateBlockNID) []int64 { + result := []int64{} + for i := 0; i < len(stateBlockNIDs); i++ { + result = append(result, int64(stateBlockNIDs[i])) + } + return result +} + +func mapToStateBlockNIDArray(stateBlockNIDs []int64) []types.StateBlockNID { + result := []types.StateBlockNID{} + for i := 0; i < len(stateBlockNIDs); i++ { + result = append(result, types.StateBlockNID(stateBlockNIDs[i])) + } + return result +} + +func NewCosmosDBStateSnapshotTable(db *Database) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{ db: db, } - _, err := db.Exec(stateSnapshotSchema) - if err != nil { - return nil, err - } - return s, shared.StatementList{ - {&s.insertStateStmt, insertStateSQL}, - {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, - }.Prepare(db) + // return s, shared.StatementList{ + // {&s.insertStateStmt, insertStateSQL}, + s.bulkSelectStateBlockNIDsStmt = bulkSelectStateBlockNIDsSQL + // }.Prepare(db) + s.tableName = "state_snapshots" + return s, nil } func (s *stateSnapshotStatements) InsertState( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ) (stateNID types.StateSnapshotNID, err error) { - stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) - if err != nil { - return + + // INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) + // VALUES ($1, $2);` + stateSnapshotNIDSeq, seqErr := GetNextStateSnapshotNID(s, ctx) + if seqErr != nil { + return 0, seqErr } - insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) - res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) + + data := StateSnapshotCosmos{ + RoomNID: int64(roomNID), + StateBlockNIDs: mapFromStateBlockNIDArray(stateBlockNIDs), + StateSnapshotNID: int64(stateSnapshotNIDSeq), + } + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + // state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, + docId := fmt.Sprintf("%d", stateSnapshotNIDSeq) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + var dbData = StateSnapshotCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + StateSnapshot: data, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + if err != nil { return 0, err } - lastRowID, err := res.LastInsertId() - if err != nil { - return 0, err - } - stateNID = types.StateSnapshotNID(lastRowID) + + stateNID = types.StateSnapshotNID(stateSnapshotNIDSeq) return } func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { - nids := make([]interface{}, len(stateNIDs)) - for k, v := range stateNIDs { - nids[k] = v + + // "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + + // " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []StateSnapshotCosmosData + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": stateNIDs, } - selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) - selectStmt, err := s.db.Prepare(selectOrig) + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(s.bulkSelectStateBlockNIDsStmt, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) if err != nil { return nil, err } - rows, err := selectStmt.QueryContext(ctx, nids...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed") results := make([]types.StateBlockNIDList, len(stateNIDs)) i := 0 - for ; rows.Next(); i++ { + for _, item := range response { result := &results[i] - var stateBlockNIDsJSON string - if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil { - return nil, err - } - if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil { - return nil, err - } + result.StateSnapshotNID = types.StateSnapshotNID(item.StateSnapshot.StateSnapshotNID) + result.StateBlockNIDs = mapToStateBlockNIDArray(item.StateSnapshot.StateBlockNIDs) + i++ } if i != len(stateNIDs) { return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) diff --git a/roomserver/storage/cosmosdb/storage.go b/roomserver/storage/cosmosdb/storage.go index bb3f6af2e..3aef24afa 100644 --- a/roomserver/storage/cosmosdb/storage.go +++ b/roomserver/storage/cosmosdb/storage.go @@ -17,14 +17,15 @@ package cosmosdb import ( "context" - "database/sql" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" _ "github.com/mattn/go-sqlite3" "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" - "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" @@ -33,16 +34,26 @@ import ( // A Database is used to store room events and stream offsets. type Database struct { shared.Database + connection cosmosdbapi.CosmosConnection + databaseName string + cosmosConfig cosmosdbapi.CosmosConfig + serverName gomatrixserverlib.ServerName } // Open a sqlite database. func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { - var d Database - var db *sql.DB - var err error - if db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err + conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString) + config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) + d := &Database{ + databaseName: "roomserver", + connection: conn, + cosmosConfig: config, } + // var db *sql.DB + // var err error + // if db, err = sqlutil.Open(dbProperties); err != nil { + // return nil, err + // } //db.Exec("PRAGMA journal_mode=WAL;") //db.Exec("PRAGMA read_uncommitted = true;") @@ -51,89 +62,91 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) // cause the roomserver to be unresponsive to new events because something will // acquire the global mutex and never unlock it because it is waiting for a connection // which it will never obtain. - db.SetMaxOpenConns(20) + // db.SetMaxOpenConns(20) // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns - ms := membershipStatements{} - if err := ms.execSchema(db); err != nil { - return nil, err - } - m := sqlutil.NewMigrations() - deltas.LoadAddForgottenColumn(m) - if err := m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - if err := d.prepare(db, cache); err != nil { + // ms := membershipStatements{} + // if err := ms.execSchema(db); err != nil { + // return nil, err + // } + // m := sqlutil.NewMigrations() + // deltas.LoadAddForgottenColumn(m) + // if err := m.RunDeltas(db, dbProperties); err != nil { + // return nil, err + // } + if err := d.prepare(cache); err != nil { return nil, err } - return &d, nil + return d, nil } // nolint: gocyclo -func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { +func (d *Database) prepare(cache caching.RoomServerCaches) error { var err error - eventStateKeys, err := NewSqliteEventStateKeysTable(db) + d.databaseName = "roomserver" + eventStateKeys, err := NewCosmosDBEventStateKeysTable(d) if err != nil { return err } - eventTypes, err := NewSqliteEventTypesTable(db) + eventTypes, err := NewCosmosDBEventTypesTable(d) if err != nil { return err } - eventJSON, err := NewSqliteEventJSONTable(db) + eventJSON, err := NewCosmosDBEventJSONTable(d) if err != nil { return err } - events, err := NewSqliteEventsTable(db) + events, err := NewCosmosDBEventsTable(d) if err != nil { return err } - rooms, err := NewSqliteRoomsTable(db) + rooms, err := NewCosmosDBRoomsTable(d) if err != nil { return err } - transactions, err := NewSqliteTransactionsTable(db) + transactions, err := NewCosmosDBTransactionsTable(d) if err != nil { return err } - stateBlock, err := NewSqliteStateBlockTable(db) + stateBlock, err := NewCosmosDBStateBlockTable(d) if err != nil { return err } - stateSnapshot, err := NewSqliteStateSnapshotTable(db) + stateSnapshot, err := NewCosmosDBStateSnapshotTable(d) if err != nil { return err } - prevEvents, err := NewSqlitePrevEventsTable(db) + prevEvents, err := NewCosmosDBPrevEventsTable(d) if err != nil { return err } - roomAliases, err := NewSqliteRoomAliasesTable(db) + roomAliases, err := NewCosmosDBRoomAliasesTable(d) if err != nil { return err } - invites, err := NewSqliteInvitesTable(db) + invites, err := NewCosmosDBInvitesTable(d) if err != nil { return err } - membership, err := NewSqliteMembershipTable(db) + membership, err := NewCosmosDBMembershipTable(d) if err != nil { return err } - published, err := NewSqlitePublishedTable(db) + published, err := NewCosmosDBPublishedTable(d) if err != nil { return err } - redactions, err := NewSqliteRedactionsTable(db) + redactions, err := NewCosmosDBRedactionsTable(d) if err != nil { return err } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: sqlutil.NewExclusiveWriter(), + DB: nil, + Cache: cache, + //Use the Fake SQL Writer here + Writer: cosmosdbutil.NewExclusiveWriterFake(), EventsTable: events, EventTypesTable: eventTypes, EventStateKeysTable: eventStateKeys, diff --git a/roomserver/storage/cosmosdb/transactions_table.go b/roomserver/storage/cosmosdb/transactions_table.go index 9be93ed34..3c3f20973 100644 --- a/roomserver/storage/cosmosdb/transactions_table.go +++ b/roomserver/storage/cosmosdb/transactions_table.go @@ -18,50 +18,84 @@ package cosmosdb import ( "context" "database/sql" + "fmt" + "time" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) -const transactionsSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_transactions ( - transaction_id TEXT NOT NULL, - session_id INTEGER NOT NULL, - user_id TEXT NOT NULL, - event_id TEXT NOT NULL, - PRIMARY KEY (transaction_id, session_id, user_id) - ); -` -const insertTransactionSQL = ` - INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id) - VALUES ($1, $2, $3, $4) -` +// const transactionsSchema = ` +// CREATE TABLE IF NOT EXISTS roomserver_transactions ( +// transaction_id TEXT NOT NULL, +// session_id INTEGER NOT NULL, +// user_id TEXT NOT NULL, +// event_id TEXT NOT NULL, +// PRIMARY KEY (transaction_id, session_id, user_id) +// ); +// ` -const selectTransactionEventIDSQL = ` - SELECT event_id FROM roomserver_transactions - WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3 -` - -type transactionStatements struct { - db *sql.DB - insertTransactionStmt *sql.Stmt - selectTransactionEventIDStmt *sql.Stmt +type TransactionCosmos struct { + TransactionID string `json:"transaction_id"` + SessionID int64 `json:"session_id"` + UserID string `json:"user_id"` + EventID string `json:"event_id"` } -func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { +type TransactionCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Transaction TransactionCosmos `json:"mx_roomserver_transaction"` +} + +// const insertTransactionSQL = ` +// INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id) +// VALUES ($1, $2, $3, $4) +// ` + +// const selectTransactionEventIDSQL = ` +// SELECT event_id FROM roomserver_transactions +// WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3 +// ` + +type transactionStatements struct { + db *Database + // insertTransactionStmt *sql.Stmt + selectTransactionEventIDStmt *sql.Stmt + tableName string +} + +func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*TransactionCosmosData, error) { + response := TransactionCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func NewCosmosDBTransactionsTable(db *Database) (tables.Transactions, error) { s := &transactionStatements{ db: db, } - _, err := db.Exec(transactionsSchema) - if err != nil { - return nil, err - } - - return s, shared.StatementList{ - {&s.insertTransactionStmt, insertTransactionSQL}, - {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, - }.Prepare(db) + // return s, shared.StatementList{ + // {&s.insertTransactionStmt, insertTransactionSQL}, + // {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, + // }.Prepare(db) + s.tableName = "transactions" + return s, nil } func (s *transactionStatements) InsertTransaction( @@ -71,10 +105,39 @@ func (s *transactionStatements) InsertTransaction( userID string, eventID string, ) error { - stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) - _, err := stmt.ExecContext( - ctx, transactionID, sessionID, userID, eventID, - ) + + // INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id) + // VALUES ($1, $2, $3, $4) + data := TransactionCosmos{ + EventID: eventID, + SessionID: sessionID, + TransactionID: transactionID, + UserID: userID, + } + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + + // PRIMARY KEY (transaction_id, session_id, user_id) + docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + var dbData = TransactionCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Transaction: data, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return err } @@ -84,8 +147,21 @@ func (s *transactionStatements) SelectTransactionEventID( sessionID int64, userID string, ) (eventID string, err error) { - err = s.selectTransactionEventIDStmt.QueryRowContext( - ctx, transactionID, sessionID, userID, - ).Scan(&eventID) - return + + // SELECT event_id FROM roomserver_transactions + // WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3 + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // PRIMARY KEY (transaction_id, session_id, user_id) + docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + response, err := getTransaction(s, ctx, pk, cosmosDocId) + + if err != nil { + return "", err + } + + return response.Transaction.EventID, err } diff --git a/userapi/storage/accounts/cosmosdb/account_data_table.go b/userapi/storage/accounts/cosmosdb/account_data_table.go index 6a471d07a..c690d84b3 100644 --- a/userapi/storage/accounts/cosmosdb/account_data_table.go +++ b/userapi/storage/accounts/cosmosdb/account_data_table.go @@ -71,6 +71,27 @@ func (s *accountDataStatements) prepare(db *Database) (err error) { return } +func queryAccountData(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []AccountDataCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + func (s *accountDataStatements) insertAccountData( ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { @@ -92,10 +113,14 @@ func (s *accountDataStatements) insertAccountData( id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type) } + docId := id + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var dbData = AccountDataCosmosData{ - Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, id), + Id: cosmosDocId, Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), + Pk: pk, Timestamp: time.Now().Unix(), AccountData: result, } @@ -120,24 +145,15 @@ func (s *accountDataStatements) selectAccountData( ) { // "SELECT room_id, type, content FROM account_data WHERE localpart = $1" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - response := []AccountDataCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, } - var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params) - var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - options) - if ex != nil { - return nil, nil, ex + response, err := queryAccountData(s, ctx, s.selectAccountDataStmt, params) + + if err != nil { + return nil, nil, err } global := map[string]json.RawMessage{} @@ -166,26 +182,17 @@ func (s *accountDataStatements) selectAccountDataByType( // "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - response := []AccountDataCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, "@x3": roomID, "@x4": dataType, } - var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params) - var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - options) - if ex != nil { - return nil, ex + response, err := queryAccountData(s, ctx, s.selectAccountDataByTypeStmt, params) + + if err != nil { + return nil, err } if len(response) == 0 { diff --git a/userapi/storage/accounts/cosmosdb/accounts_table.go b/userapi/storage/accounts/cosmosdb/accounts_table.go index d20e01af3..2e13c4047 100644 --- a/userapi/storage/accounts/cosmosdb/accounts_table.go +++ b/userapi/storage/accounts/cosmosdb/accounts_table.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/userapi/api" @@ -87,17 +88,42 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv return } -func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) { - response := AccountCosmosData{} - var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) - var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( +func queryAccount(s *accountsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []AccountCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) { + response := AccountCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, docId, - optionsGet, &response) - return &response, ex + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err } func setAccount(s *accountsStatements, ctx context.Context, pk string, account AccountCosmosData) (*AccountCosmosData, error) { @@ -155,10 +181,14 @@ func (s *accountsStatements) insertAccount( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) + docId := result.Localpart + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var dbData = AccountCosmosData{ - Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart), + Id: cosmosDocId, Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), + Pk: pk, Timestamp: time.Now().Unix(), Account: data, } @@ -184,10 +214,11 @@ func (s *accountsStatements) updatePassword( // "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + docId := localpart + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - var response, exGet = getAccount(s, ctx, pk, docId) + var response, exGet = getAccount(s, ctx, pk, cosmosDocId) if exGet != nil { return exGet } @@ -207,10 +238,12 @@ func (s *accountsStatements) deactivateAccount( // "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - var response, exGet = getAccount(s, ctx, pk, docId) + docId := localpart + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + var response, exGet = getAccount(s, ctx, pk, cosmosDocId) if exGet != nil { return exGet } @@ -230,24 +263,15 @@ func (s *accountsStatements) selectPasswordHash( // "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - response := []AccountCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, } - var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params) - var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - options) - if ex != nil { - return "", ex + response, err := queryAccount(s, ctx, s.selectPasswordHashStmt, params) + + if err != nil { + return "", err } if len(response) == 0 { @@ -268,24 +292,15 @@ func (s *accountsStatements) selectAccountByLocalpart( // "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - response := []AccountCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, } - var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params) - var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - options) - if ex != nil { - return nil, ex + response, err := queryAccount(s, ctx, s.selectAccountByLocalpartStmt, params) + + if err != nil { + return nil, err } if len(response) == 0 { diff --git a/userapi/storage/accounts/cosmosdb/openid_table.go b/userapi/storage/accounts/cosmosdb/openid_table.go index 2567b8857..c0cb65ec5 100644 --- a/userapi/storage/accounts/cosmosdb/openid_table.go +++ b/userapi/storage/accounts/cosmosdb/openid_table.go @@ -62,6 +62,27 @@ func mapToToken(api api.OpenIDToken) OpenIDTokenCosmos { } } +func queryOpenIdToken(s *tokenStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OpenIdTokenCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []OpenIdTokenCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db s.selectTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_openidtoken.token = @x2" @@ -87,10 +108,14 @@ func (s *tokenStatements) insertToken( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) + docId := result.Token + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var dbData = OpenIdTokenCosmosData{ - Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Token), + Id: cosmosDocId, Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), + Pk: pk, Timestamp: time.Now().Unix(), OpenIdToken: mapToToken(*result), } @@ -120,24 +145,14 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes( // "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - response := []OpenIdTokenCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": token, } - var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params) - var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - options) + response, err := queryOpenIdToken(s, ctx, s.selectTokenStmt, params) - if ex != nil { - return nil, ex + if err != nil { + return nil, err } if len(response) == 0 { diff --git a/userapi/storage/accounts/cosmosdb/profile_table.go b/userapi/storage/accounts/cosmosdb/profile_table.go index bb02f4867..dbe6bd392 100644 --- a/userapi/storage/accounts/cosmosdb/profile_table.go +++ b/userapi/storage/accounts/cosmosdb/profile_table.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -87,17 +88,42 @@ func (s *profilesStatements) prepare(db *Database) (err error) { return } -func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) { - response := ProfileCosmosData{} - var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) - var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( +func queryProfile(s *profilesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ProfileCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []ProfileCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) { + response := ProfileCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, docId, - optionsGet, &response) - return &response, ex + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err } func setProfile(s *profilesStatements, ctx context.Context, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) { @@ -123,10 +149,14 @@ func (s *profilesStatements) insertProfile( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) + docId := localpart + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var dbData = ProfileCosmosData{ - Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart), + Id: cosmosDocId, Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), + Pk: pk, Timestamp: time.Now().Unix(), Profile: mapToProfile(*result), } @@ -148,24 +178,15 @@ func (s *profilesStatements) selectProfileByLocalpart( // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - response := []ProfileCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, } - var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params) - var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - options) - if ex != nil { - return nil, ex + response, err := queryProfile(s, ctx, s.selectProfileByLocalpartStmt, params) + + if err != nil { + return nil, err } if len(response) == 0 { @@ -186,10 +207,11 @@ func (s *profilesStatements) setAvatarURL( // "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) + docId := localpart + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - var response, exGet = getProfile(s, ctx, pk, docId) + var response, exGet = getProfile(s, ctx, pk, cosmosDocId) if exGet != nil { return exGet } @@ -209,9 +231,10 @@ func (s *profilesStatements) setDisplayName( // "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) - var response, exGet = getProfile(s, ctx, pk, docId) + docId := localpart + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response, exGet = getProfile(s, ctx, pk, cosmosDocId) if exGet != nil { return exGet } @@ -232,25 +255,16 @@ func (s *profilesStatements) selectProfilesBySearch( // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - response := []ProfileCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": searchString, "@x3": limit, } - var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params) - var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - options) - if ex != nil { - return nil, ex + response, err := queryProfile(s, ctx, s.selectProfilesBySearchStmt, params) + + if err != nil { + return nil, err } for i := 0; i < len(response); i++ { diff --git a/userapi/storage/accounts/cosmosdb/storage.go b/userapi/storage/accounts/cosmosdb/storage.go index 20c4d2071..0f344945e 100644 --- a/userapi/storage/accounts/cosmosdb/storage.go +++ b/userapi/storage/accounts/cosmosdb/storage.go @@ -20,6 +20,8 @@ import ( "errors" "strconv" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/internal/cosmosdbutil" @@ -27,7 +29,6 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -37,6 +38,7 @@ import ( // Database represents an account database type Database struct { sqlutil.PartitionOffsetStatements + writer sqlutil.Writer accounts accountsStatements profiles profilesStatements accountDatas accountDataStatements @@ -62,7 +64,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver connection: conn, cosmosConfig: config, // db: db, - // writer: sqlutil.NewExclusiveWriter(), + writer: sqlutil.NewExclusiveWriter(), // bcryptCost: bcryptCost, // openIDTokenLifetimeMS: openIDTokenLifetimeMS, } diff --git a/userapi/storage/accounts/cosmosdb/threepid_table.go b/userapi/storage/accounts/cosmosdb/threepid_table.go index b8bf12263..358b370e6 100644 --- a/userapi/storage/accounts/cosmosdb/threepid_table.go +++ b/userapi/storage/accounts/cosmosdb/threepid_table.go @@ -69,31 +69,43 @@ func (s *threepidStatements) prepare(db *Database) (err error) { return } +func queryThreePID(s *threepidStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ThreePIDCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []ThreePIDCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + func (s *threepidStatements) selectLocalpartForThreePID( ctx context.Context, threepid string, medium string, ) (localpart string, err error) { // "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - response := []ThreePIDCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": threepid, "@x3": medium, } - var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params) - var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - options) - if ex != nil { - return "", ex + response, err := queryThreePID(s, ctx, s.selectLocalpartForThreePIDStmt, params) + + if err != nil { + return "", err } if len(response) == 0 { @@ -109,24 +121,14 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( // "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - response := []ThreePIDCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, } - var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params) - var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - options) + response, err := queryThreePID(s, ctx, s.selectThreePIDsForLocalpartStmt, params) - if ex != nil { - return threepids, ex + if err != nil { + return threepids, err } if len(response) == 0 { @@ -158,10 +160,11 @@ func (s *threepidStatements) insertThreePID( docId := fmt.Sprintf("%s_%s", threepid, medium) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) var dbData = ThreePIDCosmosData{ Id: cosmosDocId, Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), + Pk: pk, Timestamp: time.Now().Unix(), ThreePID: result, } diff --git a/userapi/storage/devices/cosmosdb/devices_table.go b/userapi/storage/devices/cosmosdb/devices_table.go index d968c6208..ae1062140 100644 --- a/userapi/storage/devices/cosmosdb/devices_table.go +++ b/userapi/storage/devices/cosmosdb/devices_table.go @@ -16,10 +16,11 @@ package cosmosdb import ( "context" - "errors" "fmt" "time" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/userapi/api" @@ -82,15 +83,15 @@ type DeviceCosmosSessionCount struct { } type devicesStatements struct { - db *Database + db *Database selectDevicesCountStmt string selectDeviceByTokenStmt string // selectDeviceByIDStmt *sql.Stmt selectDevicesByIDStmt string selectDevicesByLocalpartStmt string selectDevicesByLocalpartExceptIDStmt string - serverName gomatrixserverlib.ServerName - tableName string + serverName gomatrixserverlib.ServerName + tableName string } func mapFromDevice(db DeviceCosmos) api.Device { @@ -121,17 +122,42 @@ func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos { } } -func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) { - response := DeviceCosmosData{} - var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) - var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( +func queryDevice(s *devicesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []DeviceCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) { + response := DeviceCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, docId, - optionsGet, &response) - return &response, ex + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err } func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) { @@ -209,10 +235,11 @@ func (s *devicesStatements) insertDevice( // HACK: check for duplicate PK as we are using the UNIQUE key for the DocId docId := fmt.Sprintf("%s_%s", localpart, id) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + var dbData = DeviceCosmosData{ Id: cosmosDocId, Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), + Pk: pk, Timestamp: time.Now().Unix(), Device: data, } @@ -260,7 +287,6 @@ func (s *devicesStatements) deleteDevices( ) error { // "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) var response []DeviceCosmosData params := map[string]interface{}{ "@x1": dbCollectionName, @@ -268,15 +294,8 @@ func (s *devicesStatements) deleteDevices( "@x3": devices, } - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) + response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params) + if err != nil { return err } @@ -291,8 +310,6 @@ func (s *devicesStatements) deleteDevicesByLocalpart( ) error { // "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - var response []DeviceCosmosData exceptDevices := []string{ exceptDeviceID, } @@ -302,15 +319,8 @@ func (s *devicesStatements) deleteDevicesByLocalpart( "@x3": exceptDevices, } - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) + response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params) + if err != nil { return err } @@ -325,9 +335,9 @@ func (s *devicesStatements) updateDeviceName( ) error { // "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) docId := fmt.Sprintf("%s_%s", localpart, deviceID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) var response, exGet = getDevice(s, ctx, pk, cosmosDocId) if exGet != nil { return exGet @@ -347,27 +357,19 @@ func (s *devicesStatements) selectDeviceByToken( ) (*api.Device, error) { // "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) var response []DeviceCosmosData params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": accessToken, } - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectDeviceByTokenStmt, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) + response, err := queryDevice(s, ctx, s.selectDeviceByTokenStmt, params) + if err != nil { return nil, err } if len(response) == 0 { - return nil, errors.New(fmt.Sprintf("No Devices found with AccessToken %s", accessToken)) + return nil, cosmosdbutil.ErrNoRows } if err == nil { @@ -384,9 +386,9 @@ func (s *devicesStatements) selectDeviceByID( ) (*api.Device, error) { // "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) docId := fmt.Sprintf("%s_%s", localpart, deviceID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) var response, exGet = getDevice(s, ctx, pk, cosmosDocId) if exGet != nil { return nil, exGet @@ -401,23 +403,14 @@ func (s *devicesStatements) selectDevicesByLocalpart( devices := []api.Device{} // "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - var response []DeviceCosmosData params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, "@x3": exceptDeviceID, } - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartExceptIDStmt, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) + response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartExceptIDStmt, params) + if err != nil { return nil, err } @@ -435,22 +428,14 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s // "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" var devices []api.Device var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) var response []DeviceCosmosData params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": deviceIDs, } - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectDevicesByIDStmt, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) + response, err := queryDevice(s, ctx, s.selectDevicesByIDStmt, params) + if err != nil { return nil, err } @@ -466,9 +451,9 @@ func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart, // "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) docId := fmt.Sprintf("%s_%s", localpart, deviceID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) var response, exGet = getDevice(s, ctx, pk, cosmosDocId) if exGet != nil { return exGet diff --git a/userapi/storage/devices/cosmosdb/storage.go b/userapi/storage/devices/cosmosdb/storage.go index a5ddd5977..47c7a6d4e 100644 --- a/userapi/storage/devices/cosmosdb/storage.go +++ b/userapi/storage/devices/cosmosdb/storage.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/internal/cosmosdbutil" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" @@ -37,7 +36,6 @@ var deviceIDByteLength = 6 // Database represents a device database. type Database struct { - writer sqlutil.Writer devices devicesStatements connection cosmosdbapi.CosmosConnection databaseName string