Implement Cosmos DB for the RoomServer Service (#5)

* - Implement Cosmos for the devices_table
- Use the ConnectionString in the YAML to include the Tenant
- Revert all other non implemented tables back to use SQLLite3

* - Change the Config to use "test.criticicalarc.com" Container
- Add generic function GetDocumentOrNil to standardize GetDocument
- Add func to return CrossPartition queries for Aggregates
- Add func GetNextSequence() as generic seq generator for AutoIncrement
- Add cosmosdbutil.ErrNoRows to return (emulate) sql.ErrNoRows
- Add a "fake" ExclusiveWriterFake
- Add standard "getXX", "setXX" and "queryXX" to all TABLE class files
- Add specific Table SEQ for the Events table
- Add specific Table SEQ for the Rooms table
- Add specific Table SEQ for the StateSnapshot table
This commit is contained in:
alexfca 2021-05-20 14:42:33 +10:00 committed by GitHub
parent b696923333
commit 5d68daef80
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 4012 additions and 1564 deletions

View file

@ -291,7 +291,7 @@ room_server:
listen: http://localhost:7770
connect: http://localhost:7770
database:
connection_string: file:roomserver.db
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -354,12 +354,12 @@ user_api:
listen: http://localhost:7781
connect: http://localhost:7781
account_database:
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
device_database:
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1

View file

@ -1,8 +1,8 @@
package cosmosdbapi
import (
"context"
"fmt"
)
func GetDocumentId(tenantName string, collectionName string, id string) string {
@ -11,4 +11,25 @@ func GetDocumentId(tenantName string, collectionName string, id string) string {
func GetPartitionKey(tenantName string, collectionName string) string {
return fmt.Sprintf("%s,%s", collectionName, tenantName)
}
}
func GetDocumentOrNil(connection CosmosConnection, config CosmosConfig, ctx context.Context, partitionKey string, cosmosDocId string, dbData interface{}) error {
var _, err = GetClient(connection).GetDocument(
ctx,
config.DatabaseName,
config.ContainerName,
cosmosDocId,
GetGetDocumentOptions(partitionKey),
&dbData,
)
if err != nil {
if err.Error() == "Resource that no longer exists" {
dbData = nil
return nil
}
return err
}
return nil
}

View file

@ -6,14 +6,14 @@ import (
func GetCreateDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
return cosmosapi.CreateDocumentOptions{
IsUpsert: false,
IsUpsert: false,
PartitionKeyValue: pk,
}
}
func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
return cosmosapi.CreateDocumentOptions{
IsUpsert: true,
IsUpsert: true,
PartitionKeyValue: pk,
}
}
@ -21,8 +21,16 @@ func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions {
return cosmosapi.QueryDocumentsOptions{
PartitionKeyValue: pk,
IsQuery: true,
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
IsQuery: true,
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
}
}
func GetQueryAllPartitionsDocumentsOptions() cosmosapi.QueryDocumentsOptions {
return cosmosapi.QueryDocumentsOptions{
IsQuery: true,
EnableCrossPartition: true,
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
}
}
@ -35,7 +43,7 @@ func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions {
func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions {
return cosmosapi.ReplaceDocumentOptions{
PartitionKeyValue: pk,
IfMatch: etag,
IfMatch: etag,
}
}

View file

@ -0,0 +1,76 @@
package cosmosdbutil
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
)
type SequenceCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Value int64 `json:"_value"`
}
func GetNextSequence(
ctx context.Context,
connection cosmosdbapi.CosmosConnection,
config cosmosdbapi.CosmosConfig,
serviceName string,
tableName string,
seqId string,
initial int64,
) (int64, error) {
collName := fmt.Sprintf("%s_%s", tableName, seqId)
dbCollectionName := cosmosdbapi.GetCollectionName(serviceName, collName)
cosmosDocId := cosmosdbapi.GetDocumentId(config.ContainerName, dbCollectionName, seqId)
pk := cosmosDocId
dbData := SequenceCosmosData{}
cosmosdbapi.GetDocumentOrNil(
connection,
config,
ctx,
pk,
cosmosDocId,
&dbData,
)
if dbData.Id == "" {
dbData = SequenceCosmosData{}
dbData.Id = cosmosDocId
dbData.Pk = pk
dbData.Cn = dbCollectionName
dbData.Value = initial
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(connection).CreateDocument(
ctx,
config.DatabaseName,
config.ContainerName,
dbData,
optionsCreate,
)
if err != nil {
return -1, err
}
} else {
dbData.Value++
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(dbData.Pk, dbData.ETag)
var _, _, err = cosmosdbapi.GetClient(connection).ReplaceDocument(
ctx,
config.DatabaseName,
config.ContainerName,
cosmosDocId,
dbData,
optionsReplace,
)
if err != nil {
return -1, err
}
}
return dbData.Value, nil
}

View file

@ -0,0 +1,12 @@
package cosmosdbutil
import (
"database/sql"
"errors"
)
// ErrNoRows is returned by Scan when QueryRow doesn't return a
// row. Used to simulate the SQL responses as its used for business logic
var ErrNoRows = sql.ErrNoRows
var ErrNotImplemented = errors.New("cosmosdb: not implemented")

View file

@ -0,0 +1,77 @@
package cosmosdbutil
import (
"database/sql"
"errors"
"github.com/matrix-org/dendrite/internal/sqlutil"
"go.uber.org/atomic"
)
// ExclusiveWriter implements sqlutil.Writer.
// Allows non-SQL DBs to still use the same infrastructure
// as Matrix assumes SQL TXN
type ExclusiveWriterFake struct {
running atomic.Bool
todo chan transactionWriterTaskFake
}
func NewExclusiveWriterFake() sqlutil.Writer {
return &ExclusiveWriterFake{
todo: make(chan transactionWriterTaskFake),
}
}
// transactionWriterTask represents a specific task.
type transactionWriterTaskFake struct {
db *sql.DB
txn *sql.Tx
f func(txn *sql.Tx) error
wait chan error
}
func (w *ExclusiveWriterFake) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
if w.todo == nil {
return errors.New("not initialised")
}
if !w.running.Load() {
go w.run()
}
task := transactionWriterTaskFake{
db: db,
txn: txn,
f: f,
wait: make(chan error, 1),
}
w.todo <- task
return <-task.wait
}
// run processes the tasks for a given transaction writer. Only one
// of these goroutines will run at a time. A transaction will be
// opened using the database object from the task and then this will
// be passed as a parameter to the task function.
func (w *ExclusiveWriterFake) run() {
if !w.running.CAS(false, true) {
return
}
// if tracingEnabled {
// gid := goid()
// goidToWriter.Store(gid, w)
// defer goidToWriter.Delete(gid)
// }
defer w.running.Store(false)
for task := range w.todo {
if task.db != nil && task.txn != nil {
task.wait <- task.f(task.txn)
// } else if task.db != nil && task.txn == nil {
// task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
// return task.f(txn)
// })
} else {
task.wait <- task.f(nil)
}
close(task.wait)
}
}

View file

@ -0,0 +1,82 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func UpAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
target_local BOOLEAN NOT NULL DEFAULT false,
forgotten BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
INSERT
INTO roomserver_membership (
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
) SELECT
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
FROM roomserver_membership_tmp
;
DROP TABLE roomserver_membership_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
INSERT
INTO roomserver_membership (
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
) SELECT
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
FROM roomserver_membership_tmp
;
DROP TABLE roomserver_membership_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -18,76 +18,152 @@ package cosmosdb
import (
"context"
"database/sql"
"strings"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
const eventJSONSchema = `
CREATE TABLE IF NOT EXISTS roomserver_event_json (
event_nid INTEGER NOT NULL PRIMARY KEY,
event_json TEXT NOT NULL
);
`
// const eventJSONSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_event_json (
// event_nid INTEGER NOT NULL PRIMARY KEY,
// event_json TEXT NOT NULL
// );
// `
const insertEventJSONSQL = `
INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
`
type EventJSONCosmos struct {
EventNID int64 `json:"event_nid"`
EventJSON []byte `json:"event_json"`
}
type EventJSONCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
EventJSON EventJSONCosmos `json:"mx_roomserver_event_json"`
}
// const insertEventJSONSQL = `
// INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
// `
// Bulk event JSON lookup by numeric event ID.
// Sort by the numeric event ID.
// This means that we can use binary search to lookup by numeric event ID.
const bulkSelectEventJSONSQL = `
SELECT event_nid, event_json FROM roomserver_event_json
WHERE event_nid IN ($1)
ORDER BY event_nid ASC
`
// SELECT event_nid, event_json FROM roomserver_event_json
// WHERE event_nid IN ($1)
// ORDER BY event_nid ASC
const bulkSelectEventJSONSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_json.event_nid) " +
"order by c.mx_roomserver_event_json.event_nid asc"
type eventJSONStatements struct {
db *sql.DB
insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt *sql.Stmt
db *Database
// insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt string
tableName string
}
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
s := &eventJSONStatements{
db: db,
}
_, err := db.Exec(eventJSONSchema)
func queryEventJSON(s *eventJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventJSONCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []EventJSONCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertEventJSONStmt, insertEventJSONSQL},
{&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL},
}.Prepare(db)
return response, nil
}
func NewCosmosDBEventJSONTable(db *Database) (tables.EventJSON, error) {
s := &eventJSONStatements{
db: db,
}
// _, err := db.Exec(eventJSONSchema)
// if err != nil {
// return nil, err
// }
// return s, shared.StatementList{
// {&s.insertEventJSONStmt, insertEventJSONSQL},
s.bulkSelectEventJSONStmt = bulkSelectEventJSONSQL
// }.Prepare(db)
s.tableName = "event_json"
return s, nil
}
func (s *eventJSONStatements) InsertEventJSON(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
) error {
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
// _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
// INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d", eventNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := EventJSONCosmos{
EventNID: int64(eventNID),
EventJSON: eventJSON,
}
var dbData = EventJSONCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
EventJSON: data,
}
//Insert OR Replace
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) {
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
// SELECT event_nid, event_json FROM roomserver_event_json
// WHERE event_nid IN ($1)
// ORDER BY event_nid ASC
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventNIDs,
}
response, err := queryEventJSON(s, ctx, s.bulkSelectEventJSONStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventJSON: rows.close() failed")
// We know that we will only get as many results as event NIDs
// because of the unique constraint on event NIDs.
@ -95,13 +171,11 @@ func (s *eventJSONStatements) BulkSelectEventJSON(
// We might get fewer results than NIDs so we adjust the length of the slice before returning it.
results := make([]tables.EventJSONPair, len(eventNIDs))
i := 0
for ; rows.Next(); i++ {
for _, item := range response {
result := &results[i]
var eventNID int64
if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
return nil, err
}
result.EventNID = types.EventNID(eventNID)
result.EventNID = types.EventNID(item.EventJSON.EventNID)
result.EventJSON = item.EventJSON.EventJSON
i++
}
return results[:i], nil
}

View file

@ -0,0 +1,24 @@
package cosmosdb
import (
"context"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
)
func GetNextEventStateKeyNID(s *eventStateKeyStatements, ctx context.Context) (int64, error) {
const docId = "eventstatekeynid_seq"
//1 insert start at 2
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 2)
}
func GetNextEventTypeNID(s *eventTypeStatements, ctx context.Context) (int64, error) {
const docId = "eventtypenid_seq"
//7 inserts start at 8
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 8)
}
func GetNextEventNID(s *eventStatements, ctx context.Context) (int64, error) {
const docId = "eventnid_seq"
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
}

View file

@ -18,96 +18,248 @@ package cosmosdb
import (
"context"
"database/sql"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
const eventStateKeysSchema = `
CREATE TABLE IF NOT EXISTS roomserver_event_state_keys (
event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
event_state_key TEXT NOT NULL UNIQUE
);
INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
VALUES (1, '')
ON CONFLICT DO NOTHING;
`
// const eventStateKeysSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_event_state_keys (
// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
// event_state_key TEXT NOT NULL UNIQUE
// );
// INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
// VALUES (1, '')
// ON CONFLICT DO NOTHING;
// `
type EventStateKeysCosmos struct {
EventStateKeyNID int64 `json:"event_state_key_nid"`
EventStateKey string `json:"event_state_key"`
}
type EventStateKeysCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
EventStateKeys EventStateKeysCosmos `json:"mx_roomserver_event_state_keys"`
}
// Same as insertEventTypeNIDSQL
const insertEventStateKeyNIDSQL = `
INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
ON CONFLICT DO NOTHING;
`
// const insertEventStateKeyNIDSQL = `
// INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
// ON CONFLICT DO NOTHING;
// `
const selectEventStateKeyNIDSQL = `
SELECT event_state_key_nid FROM roomserver_event_state_keys
WHERE event_state_key = $1
`
// SELECT event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key = $1
const selectEventStateKeyNIDSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_roomserver_event_state_keys.event_state_key = @x2"
// Bulk lookup from string state key to numeric ID for that state key.
// Takes an array of strings as the query parameter.
const bulkSelectEventStateKeySQL = `
SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
WHERE event_state_key IN ($1)
`
// // Bulk lookup from string state key to numeric ID for that state key.
// // Takes an array of strings as the query parameter.
// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key IN ($1)
const bulkSelectEventStateKeySQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key_nid)"
// Bulk lookup from numeric ID to string state key for that state key.
// Takes an array of strings as the query parameter.
const bulkSelectEventStateKeyNIDSQL = `
SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
WHERE event_state_key_nid IN ($1)
`
// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key_nid IN ($1)
const bulkSelectEventStateKeyNIDSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key)"
type eventStateKeyStatements struct {
db *sql.DB
insertEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyStmt *sql.Stmt
db *Database
insertEventStateKeyNIDStmt string
selectEventStateKeyNIDStmt string
bulkSelectEventStateKeyNIDStmt string
bulkSelectEventStateKeyStmt string
tableName string
}
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
s := &eventStateKeyStatements{
db: db,
}
_, err := db.Exec(eventStateKeysSchema)
func queryEventStateKeys(s *eventStateKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventStateKeysCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []EventStateKeysCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
{&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
}.Prepare(db)
return response, nil
}
func getEventStateKeys(s *eventStateKeyStatements, ctx context.Context, pk string, docId string) (*EventStateKeysCosmosData, error) {
response := EventStateKeysCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, nil
}
return &response, err
}
func NewCosmosDBEventStateKeysTable(db *Database) (tables.EventStateKeys, error) {
s := &eventStateKeyStatements{
db: db,
}
// return s, shared.StatementList{
// {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
s.selectEventStateKeyNIDStmt = selectEventStateKeyNIDSQL
s.bulkSelectEventStateKeyNIDStmt = bulkSelectEventStateKeyNIDSQL
s.bulkSelectEventStateKeyStmt = bulkSelectEventStateKeySQL
// }.Prepare(db)
s.tableName = "event_state_keys"
//Add in the initial data
ensureEventStateKeys(s, context.Background())
return s, nil
}
func ensureEventStateKeys(s *eventStateKeyStatements, ctx context.Context) {
// INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
// VALUES (1, '')
// ON CONFLICT DO NOTHING;
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// event_state_key TEXT NOT NULL UNIQUE
docId := ""
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := EventStateKeysCosmos{
EventStateKey: "",
EventStateKeyNID: 1,
}
// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
dbData := EventStateKeysCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
EventStateKeys: data,
}
insertEventStateKeyCore(s, ctx, dbData)
}
func insertEventStateKeyCore(s *eventStateKeyStatements, ctx context.Context, dbData EventStateKeysCosmosData) error {
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
if err != nil {
return err
}
return nil
}
func (s *eventStateKeyStatements) InsertEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
res, err := insertStmt.ExecContext(ctx, eventStateKey)
if err != nil {
return 0, err
// INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
// ON CONFLICT DO NOTHING;
if len(eventStateKey) == 0 {
return 0, cosmosdbutil.ErrNoRows
}
eventStateKeyNID, err := res.LastInsertId()
if err != nil {
return 0, err
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// event_state_key TEXT NOT NULL UNIQUE
docId := eventStateKey
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
existing, _ := getEventStateKeys(s, ctx, pk, cosmosDocId)
var dbData EventStateKeysCosmosData
if existing == nil {
//Not exists, we need to create a new one with a SEQ
eventStateKeyNIDSeq, seqErr := GetNextEventStateKeyNID(s, ctx)
if seqErr != nil {
return -1, seqErr
}
data := EventStateKeysCosmos{
EventStateKey: eventStateKey,
EventStateKeyNID: eventStateKeyNIDSeq,
}
// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
dbData = EventStateKeysCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
EventStateKeys: data,
}
} else {
dbData.EventStateKeys = existing.EventStateKeys
}
return types.EventStateKeyNID(eventStateKeyNID), err
err := insertEventStateKeyCore(s, ctx, dbData)
return types.EventStateKeyNID(dbData.EventStateKeys.EventStateKeyNID), err
}
func (s *eventStateKeyStatements) SelectEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
stmt := sqlutil.TxStmt(txn, s.selectEventStateKeyNIDStmt)
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err
// SELECT event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventStateKey,
}
response, err := queryEventStateKeys(s, ctx, s.selectEventStateKeyNIDStmt, params)
if err != nil {
return 0, err
}
//See storage.assignStateKeyNID()
if len(response) == 0 {
return 0, cosmosdbutil.ErrNoRows
}
return types.EventStateKeyNID(response[0].EventStateKeys.EventStateKeyNID), err
}
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
@ -117,21 +269,25 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
for k, v := range eventStateKeys {
iEventStateKeys[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key IN ($1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventStateKeys,
}
response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyNIDStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed")
result := make(map[string]types.EventStateKeyNID, len(eventStateKeys))
for rows.Next() {
var stateKey string
var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
return nil, err
}
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
for _, item := range response {
result[item.EventStateKeys.EventStateKey] = types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID)
}
return result, nil
}
@ -139,25 +295,24 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
for k, v := range eventStateKeyNIDs {
iEventStateKeyNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key_nid IN ($1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventStateKeyNIDs,
}
response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed")
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
for rows.Next() {
var stateKey string
var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
return nil, err
}
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
for _, item := range response {
result[types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID)] = item.EventStateKeys.EventStateKey
}
return result, nil
}

View file

@ -18,30 +18,44 @@ package cosmosdb
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
const eventTypesSchema = `
CREATE TABLE IF NOT EXISTS roomserver_event_types (
event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT,
event_type TEXT NOT NULL UNIQUE
);
INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
(1, 'm.room.create'),
(2, 'm.room.power_levels'),
(3, 'm.room.join_rules'),
(4, 'm.room.third_party_invite'),
(5, 'm.room.member'),
(6, 'm.room.redaction'),
(7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
`
// const eventTypesSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_event_types (
// event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT,
// event_type TEXT NOT NULL UNIQUE
// );
// INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
// (1, 'm.room.create'),
// (2, 'm.room.power_levels'),
// (3, 'm.room.join_rules'),
// (4, 'm.room.third_party_invite'),
// (5, 'm.room.member'),
// (6, 'm.room.redaction'),
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
// `
type EventTypeCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
EventType EventTypeCosmos `json:"mx_roomserver_event_type"`
}
type EventTypeCosmos struct {
EventTypeNID int64 `json:"event_type_nid"`
EventType string `json:"event_type"`
}
// Assign a new numeric event type ID.
// The usual case is that the event type is not in the database.
@ -56,105 +70,243 @@ const eventTypesSchema = `
// return it. Modifying the rows will cause postgres to assign a new tuple for the
// row even though the data doesn't change resulting in unncesssary modifications
// to the indexes.
const insertEventTypeNIDSQL = `
INSERT INTO roomserver_event_types (event_type) VALUES ($1)
ON CONFLICT DO NOTHING;
`
// const insertEventTypeNIDSQL = `
// INSERT INTO roomserver_event_types (event_type) VALUES ($1)
// ON CONFLICT DO NOTHING;
// `
const insertEventTypeNIDResultSQL = `
SELECT event_type_nid FROM roomserver_event_types
WHERE rowid = last_insert_rowid();
`
// const insertEventTypeNIDResultSQL = `
// SELECT event_type_nid FROM roomserver_event_types
// WHERE rowid = last_insert_rowid();
// `
const selectEventTypeNIDSQL = `
SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
`
// const selectEventTypeNIDSQL = `
// SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
// `
// Bulk lookup from string event type to numeric ID for that event type.
// Takes an array of strings as the query parameter.
const bulkSelectEventTypeNIDSQL = `
SELECT event_type, event_type_nid FROM roomserver_event_types
WHERE event_type IN ($1)
`
// SELECT event_type, event_type_nid FROM roomserver_event_types
// WHERE event_type IN ($1)
const bulkSelectEventTypeNIDSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_type.event_type)"
type eventTypeStatements struct {
db *sql.DB
insertEventTypeNIDStmt *sql.Stmt
insertEventTypeNIDResultStmt *sql.Stmt
selectEventTypeNIDStmt *sql.Stmt
bulkSelectEventTypeNIDStmt *sql.Stmt
db *Database
// insertEventTypeNIDStmt *sql.Stmt
// insertEventTypeNIDResultStmt *sql.Stmt
// selectEventTypeNIDStmt *sql.Stmt
bulkSelectEventTypeNIDStmt string
tableName string
}
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
func NewCosmosDBEventTypesTable(db *Database) (tables.EventTypes, error) {
s := &eventTypeStatements{
db: db,
}
_, err := db.Exec(eventTypesSchema)
// return s, shared.StatementList{
// {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
// {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL},
// {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
s.bulkSelectEventTypeNIDStmt = bulkSelectEventTypeNIDSQL
// }.Prepare(db)
s.tableName = "event_types"
ensureEventTypes(s, context.Background())
return s, nil
}
func queryEventTypes(s *eventTypeStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventTypeCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []EventTypeCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
{&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL},
{&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
{&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL},
}.Prepare(db)
return response, nil
}
func (s *eventTypeStatements) InsertEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64
insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt)
resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt)
_, err := insertStmt.ExecContext(ctx, eventType)
//We need to create a new one with a SEQ
eventTypeNIDSeq, seqErr := GetNextEventTypeNID(s, ctx)
if seqErr != nil {
return -1, seqErr
}
data := EventTypeCosmos{
EventType: eventType,
EventTypeNID: eventTypeNIDSeq,
}
dbData, err := insertEventTypeCore(s, ctx, data)
if err != nil {
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
return 0, err
}
if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil {
return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err)
return types.EventTypeNID(dbData.EventTypeNID), err
}
func insertEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType EventTypeCosmos) (*EventTypeCosmos, error) {
// INSERT INTO roomserver_event_types (event_type) VALUES ($1)
// ON CONFLICT DO NOTHING;
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
//Unique on eventType
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, eventType.EventType)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = EventTypeCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
EventType: eventType,
}
return types.EventTypeNID(eventTypeNID), err
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
if err != nil {
dbData, errGet := selectEventTypeCore(s, ctx, eventType.EventType)
if errGet != nil {
return nil, errGet
}
return dbData, nil
}
return &dbData.EventType, err
}
func ensureEventTypes(s *eventTypeStatements, ctx context.Context) error {
// INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
// (1, 'm.room.create'),
// (2, 'm.room.power_levels'),
// (3, 'm.room.join_rules'),
// (4, 'm.room.third_party_invite'),
// (5, 'm.room.member'),
// (6, 'm.room.redaction'),
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
// (1, 'm.room.create'),
_, err := insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 1, EventType: "m.room.create"})
if err != nil {
return err
}
// (2, 'm.room.power_levels'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 2, EventType: "m.room.power_levels"})
if err != nil {
return err
}
// (3, 'm.room.join_rules'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 3, EventType: "m.room.join_rules"})
if err != nil {
return err
}
// (4, 'm.room.third_party_invite'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 4, EventType: "m.room.third_party_invite"})
if err != nil {
return err
}
// (5, 'm.room.member'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 5, EventType: "m.room.member"})
if err != nil {
return err
}
// (6, 'm.room.redaction'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 6, EventType: "m.room.redaction"})
if err != nil {
return err
}
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 7, EventType: "m.room.history_visibility"})
if err != nil {
return err
}
return nil
}
func selectEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType string) (*EventTypeCosmos, error) {
var response EventTypeCosmosData
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, eventType)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
cosmosDocId,
&response)
if err != nil {
return nil, err
}
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response.EventType, nil
}
func (s *eventTypeStatements) SelectEventTypeNID(
ctx context.Context, tx *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64
selectStmt := sqlutil.TxStmt(tx, s.selectEventTypeNIDStmt)
err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
return types.EventTypeNID(eventTypeNID), err
// SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
dbData, err := selectEventTypeCore(s, ctx, eventType)
if err != nil {
return -1, err
}
return types.EventTypeNID(dbData.EventTypeNID), nil
}
func (s *eventTypeStatements) BulkSelectEventTypeNID(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
///////////////
iEventTypes := make([]interface{}, len(eventTypes))
for k, v := range eventTypes {
iEventTypes[k] = v
}
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventTypes)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
rows, err := selectPrep.QueryContext(ctx, iEventTypes...)
// SELECT event_type, event_type_nid FROM roomserver_event_types
// WHERE event_type IN ($1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventTypes,
}
response, err := queryEventTypes(s, ctx, s.bulkSelectEventTypeNIDStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed")
result := make(map[string]types.EventTypeNID, len(eventTypes))
for rows.Next() {
for _, item := range response {
var eventType string
var eventTypeNID int64
if err := rows.Scan(&eventType, &eventTypeNID); err != nil {
return nil, err
}
eventType = item.EventType.EventType
eventTypeNID = item.EventType.EventTypeNID
result[eventType] = types.EventTypeNID(eventTypeNID)
}
return result, nil

File diff suppressed because it is too large Load diff

View file

@ -18,73 +18,148 @@ package cosmosdb
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
const inviteSchema = `
CREATE TABLE IF NOT EXISTS roomserver_invites (
invite_event_id TEXT PRIMARY KEY,
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0,
retired BOOLEAN NOT NULL DEFAULT FALSE,
invite_event_json TEXT NOT NULL
);
// const inviteSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_invites (
// invite_event_id TEXT PRIMARY KEY,
// room_nid INTEGER NOT NULL,
// target_nid INTEGER NOT NULL,
// sender_nid INTEGER NOT NULL DEFAULT 0,
// retired BOOLEAN NOT NULL DEFAULT FALSE,
// invite_event_json TEXT NOT NULL
// );
CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid)
WHERE NOT retired;
`
const insertInviteEventSQL = "" +
"INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
" sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT DO NOTHING"
// CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid)
// WHERE NOT retired;
// `
type InviteCosmos struct {
InviteEventID string `json:"invite_event_id"`
RoomNID int64 `json:"room_nid"`
TargetNID int64 `json:"target_nid"`
SenderNID int64 `json:"sender_nid"`
Retired bool `json:"retired"`
InviteEventJSON []byte `json:"invite_event_json"`
}
type InviteCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Invite InviteCosmos `json:"mx_roomserver_invite"`
}
// const insertInviteEventSQL = "" +
// "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
// " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
// " ON CONFLICT DO NOTHING"
// "SELECT invite_event_id, sender_nid FROM roomserver_invites" +
// " WHERE target_nid = $1 AND room_nid = $2" +
// " AND NOT retired"
const selectInviteActiveForUserInRoomSQL = "" +
"SELECT invite_event_id, sender_nid FROM roomserver_invites" +
" WHERE target_nid = $1 AND room_nid = $2" +
" AND NOT retired"
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_invite.target_nid = @x2" +
" and c.mx_roomserver_invite.room_nid = @x3" +
" and c.mx_roomserver_invite.retired = false"
// Retire every active invite for a user in a room.
// Ideally we'd know which invite events were retired by a given update so we
// wouldn't need to remove every active invite.
// However the matrix protocol doesn't give us a way to reliably identify the
// invites that were retired, so we are forced to retire all of them.
const updateInviteRetiredSQL = `
UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
`
// const updateInviteRetiredSQL = `
// UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
// `
const selectInvitesAboutToRetireSQL = `
SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
`
// SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
const selectInvitesAboutToRetireSQL = "" +
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_invite.room_nid = @x2" +
" and c.mx_roomserver_invite.target_nid = @x3" +
" and c.mx_roomserver_invite.retired = false"
type inviteStatements struct {
db *sql.DB
insertInviteEventStmt *sql.Stmt
selectInviteActiveForUserInRoomStmt *sql.Stmt
updateInviteRetiredStmt *sql.Stmt
selectInvitesAboutToRetireStmt *sql.Stmt
db *Database
// insertInviteEventStmt *sql.Stmt
selectInviteActiveForUserInRoomStmt string
// updateInviteRetiredStmt *sql.Stmt
selectInvitesAboutToRetireStmt string
tableName string
}
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
s := &inviteStatements{
db: db,
}
_, err := db.Exec(inviteSchema)
func queryInvite(s *inviteStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []InviteCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
return s, shared.StatementList{
{&s.insertInviteEventStmt, insertInviteEventSQL},
{&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL},
{&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
{&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL},
}.Prepare(db)
func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*InviteCosmosData, error) {
response := InviteCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setInvite(s *inviteStatements, ctx context.Context, invite InviteCosmosData) (*InviteCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
invite.Id,
&invite,
optionsReplace)
return &invite, ex
}
func NewCosmosDBInvitesTable(db *Database) (tables.Invites, error) {
s := &inviteStatements{
db: db,
}
// return s, shared.StatementList{
// {&s.insertInviteEventStmt, insertInviteEventSQL},
s.selectInviteActiveForUserInRoomStmt = selectInviteActiveForUserInRoomSQL
// {&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
s.selectInvitesAboutToRetireStmt = selectInvitesAboutToRetireSQL
// }.Prepare(db)
s.tableName = "invites"
return s, nil
}
func (s *inviteStatements) InsertInviteEvent(
@ -93,42 +168,84 @@ func (s *inviteStatements) InsertInviteEvent(
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
result, err := stmt.ExecContext(
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
)
// "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
// " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
// " ON CONFLICT DO NOTHING"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
data := InviteCosmos{
InviteEventID: inviteEventID,
InviteEventJSON: inviteEventJSON,
Retired: false,
RoomNID: int64(roomNID),
SenderNID: int64(senderUserNID),
TargetNID: int64(targetUserNID),
}
// invite_event_id TEXT PRIMARY KEY,
docId := inviteEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = InviteCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Invite: data,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
if err != nil {
return false, err
}
count, err = result.RowsAffected()
if err != nil {
return false, err
}
return count != 0, err
// TODO: Is this important?
// count, err = result.RowsAffected()
// return count != 0, err
return true, nil
}
func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) {
// "SELECT invite_event_id, sender_nid FROM roomserver_invites" +
// " WHERE target_nid = $1 AND room_nid = $2" +
// " AND NOT retired"
// gather all the event IDs we will retire
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": targetUserNID,
"@x3": roomNID,
}
response, err := queryInvite(s, ctx, s.selectInvitesAboutToRetireStmt, params)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed")
for rows.Next() {
var inviteEventID string
if err = rows.Scan(&inviteEventID); err != nil {
return
}
eventIDs = append(eventIDs, inviteEventID)
for _, item := range response {
eventIDs = append(eventIDs, item.Invite.InviteEventID)
// UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
// now retire the invites
item.Invite.Retired = true
_, err = setInvite(s, ctx, item)
}
// now retire the invites
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
return
}
@ -137,21 +254,27 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
ctx, targetUserNID, roomNID,
)
// SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNID,
"@x3": targetUserNID,
}
response, err := queryInvite(s, ctx, s.selectInviteActiveForUserInRoomStmt, params)
if err != nil {
return nil, nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
var result []types.EventStateKeyNID
var eventIDs []string
for rows.Next() {
var eventID string
var senderUserNID int64
if err := rows.Scan(&eventID, &senderUserNID); err != nil {
return nil, nil, err
}
for _, item := range response {
var eventID = item.Invite.InviteEventID
var senderUserNID = item.Invite.SenderNID
result = append(result, types.EventStateKeyNID(senderUserNID))
eventIDs = append(eventIDs, eventID)
}

View file

@ -19,125 +19,233 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
const membershipSchema = `
CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
target_local BOOLEAN NOT NULL DEFAULT false,
forgotten BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
`
// const membershipSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_membership (
// room_nid INTEGER NOT NULL,
// target_nid INTEGER NOT NULL,
// sender_nid INTEGER NOT NULL DEFAULT 0,
// membership_nid INTEGER NOT NULL DEFAULT 1,
// event_nid INTEGER NOT NULL DEFAULT 0,
// target_local BOOLEAN NOT NULL DEFAULT false,
// forgotten BOOLEAN NOT NULL DEFAULT false,
// UNIQUE (room_nid, target_nid)
// );
// `
type MembershipCosmos struct {
RoomNID int64 `json:"room_nid"`
TargetNID int64 `json:"target_nid"`
SenderNID int64 `json:"sender_nid"`
MembershipNID int64 `json:"membership_nid"`
EventNID int64 `json:"event_nid"`
TargetLocal bool `json:"target_local"`
Forgotten bool `json:"forgotten"`
}
type MembershipCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Membership MembershipCosmos `json:"mx_roomserver_membership"`
}
type MembershipJoinedCountCosmosData struct {
TargetNID int64 `json:"target_nid"`
RoomCount int `json:"room_count"`
}
// "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
// " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
// " GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
"select c.mx_roomserver_membership.target_nid, count(c.mx_roomserver_membership.room_id) as room_count from c where c._cn = @x1 " +
" and ARRAY_CONTAINS(@x2, c.mx_roomserver_membership.room_id)" +
" and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
" and c.mx_roomserver_membership.forgotten = false" +
" group by c.mx_roomserver_membership.target_nid"
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
"INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING"
// const insertMembershipSQL = "" +
// "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
// " VALUES ($1, $2, $3)" +
// " ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" +
"SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2"
// const selectMembershipFromRoomAndTargetSQL = "" +
// "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
// " WHERE room_nid = $1 AND target_nid = $2"
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_membership.room_nid = @x2" +
" and c.mx_roomserver_membership.membership_nid = @x3" +
" and c.mx_roomserver_membership.forgotten = false"
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND membership_nid = $2" +
// " AND target_local = true and forgotten = false"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" +
" AND target_local = true and forgotten = false"
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_membership.room_nid = @x2" +
" and c.mx_roomserver_membership.membership_nid = @x3" +
" and c.mx_roomserver_membership.target_local = true" +
" and c.mx_roomserver_membership.forgotten = false"
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 and forgotten = false"
const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 and forgotten = false"
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_membership.room_nid = @x2" +
" and c.mx_roomserver_membership.forgotten = false"
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1" +
// " AND target_local = true and forgotten = false"
const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" +
" AND target_local = true and forgotten = false"
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_membership.room_nid = @x2" +
" and c.mx_roomserver_membership.target_local = true" +
" and c.mx_roomserver_membership.forgotten = false"
const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2"
// const selectMembershipForUpdateSQL = "" +
// "SELECT membership_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND target_nid = $2"
const updateMembershipSQL = "" +
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
" WHERE room_nid = $5 AND target_nid = $6"
// const updateMembershipSQL = "" +
// "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
// " WHERE room_nid = $5 AND target_nid = $6"
const updateMembershipForgetRoom = "" +
"UPDATE roomserver_membership SET forgotten = $1" +
" WHERE room_nid = $2 AND target_nid = $3"
// const updateMembershipForgetRoom = "" +
// "UPDATE roomserver_membership SET forgotten = $1" +
// " WHERE room_nid = $2 AND target_nid = $3"
// "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_membership.membership_nid = @x2" +
" and c.mx_roomserver_membership.target_nid = true" +
" and c.mx_roomserver_membership.forgotten = false"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will
// only return users that the user would ordinarily be able to see anyway.
var selectKnownUsersSQL = "" +
"SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
"roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE room_nid IN (" +
" SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
// var selectKnownUsersSQL = "" +
// "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
// "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
// " WHERE room_nid IN (" +
// " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
// ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
var selectKnownUsersSQLRooms = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_membership.room_id)"
var selectKnownUsersSQLDistinctRoom = "" +
"select distinct top @x4 c.mx_roomserver_membership.room_nid as room_nid from c where c._cn = @x1 " +
"and c.mx_roomserver_membership.target_nid = @x2 " +
"and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " " +
"and contains(c.mx_roomserver_membership.event_state_key, @x3) "
type membershipStatements struct {
db *sql.DB
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt
selectLocalMembershipsFromRoomStmt *sql.Stmt
selectRoomsWithMembershipStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
updateMembershipForgetRoomStmt *sql.Stmt
db *Database
// insertMembershipStmt *sql.Stmt
// selectMembershipForUpdateStmt string
// selectMembershipFromRoomAndTargetStmt string
selectMembershipsFromRoomAndMembershipStmt string
selectLocalMembershipsFromRoomAndMembershipStmt string
selectMembershipsFromRoomStmt string
selectLocalMembershipsFromRoomStmt string
selectRoomsWithMembershipStmt string
// updateMembershipStmt *sql.Stmt
// selectKnownUsersStmt string
// updateMembershipForgetRoomStmt *sql.Stmt
tableName string
}
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
func queryMembership(s *membershipStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MembershipCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []MembershipCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getMembership(s *membershipStatements, ctx context.Context, pk string, docId string) (*MembershipCosmosData, error) {
response := MembershipCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setMembership(s *membershipStatements, ctx context.Context, pk string, membership MembershipCosmosData) (*MembershipCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, membership.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
membership.Id,
&membership,
optionsReplace)
return &membership, ex
}
func NewCosmosDBMembershipTable(db *Database) (tables.Membership, error) {
s := &membershipStatements{
db: db,
}
return s, shared.StatementList{
{&s.insertMembershipStmt, insertMembershipSQL},
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
{&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
}.Prepare(db)
}
// return s, shared.StatementList{
// {&s.insertMembershipStmt, insertMembershipSQL},
// s.selectMembershipForUpdateStmt = selectMembershipForUpdateSQL
// s.selectMembershipFromRoomAndTargetStmt = selectMembershipFromRoomAndTargetSQL
s.selectMembershipsFromRoomAndMembershipStmt = selectMembershipsFromRoomAndMembershipSQL
s.selectLocalMembershipsFromRoomAndMembershipStmt = selectLocalMembershipsFromRoomAndMembershipSQL
s.selectMembershipsFromRoomStmt = selectMembershipsFromRoomSQL
s.selectLocalMembershipsFromRoomStmt = selectLocalMembershipsFromRoomSQL
// {&s.updateMembershipStmt, updateMembershipSQL},
s.selectRoomsWithMembershipStmt = selectRoomsWithMembershipSQL
// {&s.selectKnownUsersStmt, selectKnownUsersSQL},
// {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
// }.Prepare(db)
func (s *membershipStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(membershipSchema)
return err
s.tableName = "memberships"
return s, nil
}
func (s *membershipStatements) InsertMembership(
@ -145,8 +253,45 @@ func (s *membershipStatements) InsertMembership(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool,
) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
// "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
// " VALUES ($1, $2, $3)" +
// " ON CONFLICT DO NOTHING"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_nid, target_nid)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := MembershipCosmos{
EventNID: 0,
Forgotten: false,
MembershipNID: 1,
RoomNID: int64(roomNID),
SenderNID: 0,
TargetLocal: false,
TargetNID: int64(targetUserNID),
}
var dbData = MembershipCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Membership: data,
}
// " ON CONFLICT DO NOTHING"
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
@ -154,10 +299,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership tables.MembershipState, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership)
// "SELECT membership_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND target_nid = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getMembership(s, ctx, pk, cosmosDocId)
if response != nil {
membership = tables.MembershipState(response.Membership.MembershipNID)
}
return
}
@ -165,9 +318,20 @@ func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten)
// "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
// " WHERE room_nid = $1 AND target_nid = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getMembership(s, ctx, pk, cosmosDocId)
if response != nil {
eventNID = types.EventNID(response.Membership.EventNID)
forgotten = response.Membership.Forgotten
membership = tables.MembershipState(response.Membership.MembershipNID)
}
return
}
@ -175,24 +339,31 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context,
roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var selectStmt *sql.Stmt
var selectStmt string
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNID,
}
if localOnly {
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1" +
// " AND target_local = true and forgotten = false"
selectStmt = s.selectLocalMembershipsFromRoomStmt
} else {
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 and forgotten = false"
selectStmt = s.selectMembershipsFromRoomStmt
}
rows, err := selectStmt.QueryContext(ctx, roomNID)
response, err := queryMembership(s, ctx, selectStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoom: rows.close() failed")
for rows.Next() {
var eNID types.EventNID
if err = rows.Scan(&eNID); err != nil {
return
}
eventNIDs = append(eventNIDs, eNID)
for _, item := range response {
eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID))
}
return
}
@ -201,24 +372,31 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt
var stmt string
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNID,
"@x3": membership,
}
if localOnly {
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND membership_nid = $2" +
// " AND target_local = true and forgotten = false"
stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt
} else {
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
stmt = s.selectMembershipsFromRoomAndMembershipStmt
}
rows, err := stmt.QueryContext(ctx, roomNID, membership)
response, err := queryMembership(s, ctx, stmt, params)
if err != nil {
return
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoomAndMembership: rows.close() failed")
for rows.Next() {
var eNID types.EventNID
if err = rows.Scan(&eNID); err != nil {
return
}
eventNIDs = append(eventNIDs, eNID)
for _, item := range response {
eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID))
}
return
}
@ -228,28 +406,48 @@ func (s *membershipStatements) UpdateMembership(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool,
) error {
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
_, err := stmt.ExecContext(
ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID,
)
// "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
// " WHERE room_nid = $5 AND target_nid = $6"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
dbData, err := getMembership(s, ctx, pk, cosmosDocId)
if err != nil {
return err
}
dbData.Membership.SenderNID = int64(senderUserNID)
dbData.Membership.MembershipNID = int64(membership)
dbData.Membership.EventNID = int64(eventNID)
dbData.Membership.Forgotten = forgotten
_, err = setMembership(s, ctx, pk, *dbData)
return err
}
func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
// "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": membershipState,
"@x3": userID,
}
response, err := queryMembership(s, ctx, s.selectRoomsWithMembershipStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
var roomNIDs []types.RoomNID
for rows.Next() {
var roomNID types.RoomNID
if err := rows.Scan(&roomNID); err != nil {
return nil, err
}
roomNIDs = append(roomNIDs, roomNID)
for _, item := range response {
roomNIDs = append(roomNIDs, types.RoomNID(item.Membership.RoomNID))
}
return roomNIDs, nil
}
@ -259,39 +457,136 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
// "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
// " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
// " GROUP BY target_nid"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNIDs,
}
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []MembershipJoinedCountCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectJoinedUsersSetForRoomsSQL, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
result := make(map[types.EventStateKeyNID]int)
for rows.Next() {
var userID types.EventStateKeyNID
var count int
if err := rows.Scan(&userID, &count); err != nil {
return nil, err
}
for _, item := range response {
userID := types.EventStateKeyNID(item.TargetNID)
count := item.RoomCount
result[userID] = count
}
return result, rows.Err()
return result, nil
}
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
// " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
// ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
"@x3": searchString,
"@x4": limit,
}
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var responseDistinctRoom []MembershipCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectKnownUsersSQLDistinctRoom, params) //
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&responseDistinctRoom,
optionsQry)
if err != nil {
return nil, err
}
rooms := []int64{}
for _, item := range responseDistinctRoom {
rooms = append(rooms, item.RoomNID)
}
// "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
// "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
// " WHERE room_nid IN (" +
params = map[string]interface{}{
"@x1": dbCollectionName,
"@x2": rooms,
}
var responseRooms []MembershipCosmos
query = cosmosdbapi.GetQuery(selectKnownUsersSQLRooms, params)
_, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&responseRooms,
optionsQry)
if err != nil {
return nil, err
}
targetNIDs := []int64{}
for _, item := range responseRooms {
targetNIDs = append(targetNIDs, item.TargetNID)
}
// HACK: Joined table
var dbCollectionNameEventStateKeys = cosmosdbapi.GetCollectionName(s.db.databaseName, "event_state_keys")
params = map[string]interface{}{
"@x1": dbCollectionNameEventStateKeys,
"@x2": targetNIDs,
}
bulkSelectEventStateKeyStmt := "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key_nid)"
var responseEventStateKeys []EventStateKeysCosmos
query = cosmosdbapi.GetQuery(bulkSelectEventStateKeyStmt, params)
_, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&responseEventStateKeys,
optionsQry)
if err != nil {
return nil, err
}
// SELECT DISTINCT event_state_key
result := []string{}
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
for _, item := range responseEventStateKeys {
userID := item.EventStateKey
result = append(result, userID)
}
return result, rows.Err()
return result, nil
}
func (s *membershipStatements) UpdateForgetMembership(
@ -299,8 +594,22 @@ func (s *membershipStatements) UpdateForgetMembership(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, forget, roomNID, targetUserNID,
)
// "UPDATE roomserver_membership SET forgotten = $1" +
// " WHERE room_nid = $2 AND target_nid = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
dbData, err := getMembership(s, ctx, pk, cosmosDocId)
if err != nil {
return err
}
dbData.Membership.Forgotten = forget
_, err = setMembership(s, ctx, pk, *dbData)
return err
}

View file

@ -20,9 +20,12 @@ import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
@ -32,59 +35,90 @@ import (
// In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it
// seems to care that it's empty and therefore hits a NOT NULL constraint on insert.
// We should really work out what the right thing to do here is.
const previousEventSchema = `
CREATE TABLE IF NOT EXISTS roomserver_previous_events (
previous_event_id TEXT NOT NULL,
previous_reference_sha256 BLOB,
event_nids TEXT NOT NULL,
UNIQUE (previous_event_id, previous_reference_sha256)
);
`
// const previousEventSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_previous_events (
// previous_event_id TEXT NOT NULL,
// previous_reference_sha256 BLOB,
// event_nids TEXT NOT NULL,
// UNIQUE (previous_event_id, previous_reference_sha256)
// );
// `
type PreviousEventCosmos struct {
PreviousEventID string `json:"previous_event_id"`
PreviousReferenceSha256 []byte `json:"previous_reference_sha256"`
EventNIDs string `json:"event_nids"`
}
type PreviousEventCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
PreviousEvent PreviousEventCosmos `json:"mx_roomserver_previous_event"`
}
// Insert an entry into the previous_events table.
// If there is already an entry indicating that an event references that previous event then
// add the event NID to the list to indicate that this event references that previous event as well.
// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
// The lock is necessary to avoid data races when checking whether an event is already referenced by another event.
const insertPreviousEventSQL = `
INSERT OR REPLACE INTO roomserver_previous_events
(previous_event_id, previous_reference_sha256, event_nids)
VALUES ($1, $2, $3)
`
// const insertPreviousEventSQL = `
// INSERT OR REPLACE INTO roomserver_previous_events
// (previous_event_id, previous_reference_sha256, event_nids)
// VALUES ($1, $2, $3)
// `
const selectPreviousEventNIDsSQL = `
SELECT event_nids FROM roomserver_previous_events
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
`
// const selectPreviousEventNIDsSQL = `
// SELECT event_nids FROM roomserver_previous_events
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
// `
// Check if the event is referenced by another event in the table.
// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
const selectPreviousEventExistsSQL = `
SELECT 1 FROM roomserver_previous_events
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
`
// const selectPreviousEventExistsSQL = `
// SELECT 1 FROM roomserver_previous_events
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
// `
type previousEventStatements struct {
db *sql.DB
insertPreviousEventStmt *sql.Stmt
selectPreviousEventNIDsStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt
db *Database
// insertPreviousEventStmt *sql.Stmt
// selectPreviousEventNIDsStmt *sql.Stmt
// selectPreviousEventExistsStmt *sql.Stmt
tableName string
}
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*PreviousEventCosmosData, error) {
response := PreviousEventCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func NewCosmosDBPrevEventsTable(db *Database) (tables.PreviousEvents, error) {
s := &previousEventStatements{
db: db,
}
_, err := db.Exec(previousEventSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
}.Prepare(db)
// return s, shared.StatementList{
// {&s.insertPreviousEventStmt, insertPreviousEventSQL},
// {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
// {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
// }.Prepare(db)
s.tableName = "previous_events"
return s, nil
}
func (s *previousEventStatements) InsertPreviousEvent(
@ -94,28 +128,71 @@ func (s *previousEventStatements) InsertPreviousEvent(
previousEventReferenceSHA256 []byte,
eventNID types.EventNID,
) error {
var eventNIDs string
eventNIDAsString := fmt.Sprintf("%d", eventNID)
selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs)
if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (previous_event_id, previous_reference_sha256)
// TODO: Check value
// docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256)
docId := previousEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// SELECT 1 FROM roomserver_previous_events
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
existing, err := getPreviousEvent(s, ctx, pk, cosmosDocId)
if err != nil {
if err != cosmosdbutil.ErrNoRows {
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
}
}
var dbData PreviousEventCosmosData
// Doesnt exist, create a new one
if existing == nil {
data := PreviousEventCosmos{
EventNIDs: "",
PreviousEventID: previousEventID,
PreviousReferenceSha256: previousEventReferenceSHA256,
}
dbData = PreviousEventCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
PreviousEvent: data,
}
} else {
dbData = *existing
}
var nids []string
if eventNIDs != "" {
nids = strings.Split(eventNIDs, ",")
if dbData.PreviousEvent.EventNIDs != "" {
nids = strings.Split(dbData.PreviousEvent.EventNIDs, ",")
for _, nid := range nids {
if nid == eventNIDAsString {
return nil
}
}
eventNIDs = strings.Join(append(nids, eventNIDAsString), ",")
dbData.PreviousEvent.EventNIDs = strings.Join(append(nids, eventNIDAsString), ",")
} else {
eventNIDs = eventNIDAsString
dbData.PreviousEvent.EventNIDs = eventNIDAsString
}
insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
_, err = insertStmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, eventNIDs,
// INSERT OR REPLACE INTO roomserver_previous_events
// (previous_event_id, previous_reference_sha256, event_nids)
// VALUES ($1, $2, $3)
var optionsReplace = cosmosdbapi.GetUpsertDocumentOptions(pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
optionsReplace,
)
return err
}
@ -125,7 +202,24 @@ func (s *previousEventStatements) InsertPreviousEvent(
func (s *previousEventStatements) SelectPreviousEventExists(
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
) error {
var ok int64
stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (previous_event_id, previous_reference_sha256)
// TODO: Check value
// docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256)
docId := eventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, string(docId))
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// SELECT 1 FROM roomserver_previous_events
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
dbData, err := getPreviousEvent(s, ctx, pk, cosmosDocId)
if err != nil {
return err
}
if dbData == nil {
return cosmosdbutil.ErrNoRows
}
return nil
}

View file

@ -17,89 +17,199 @@ package cosmosdb
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
)
const publishedSchema = `
-- Stores which rooms are published in the room directory
CREATE TABLE IF NOT EXISTS roomserver_published (
-- The room ID of the room
room_id TEXT NOT NULL PRIMARY KEY,
-- Whether it is published or not
published BOOLEAN NOT NULL DEFAULT false
);
`
// const publishedSchema = `
// -- Stores which rooms are published in the room directory
// CREATE TABLE IF NOT EXISTS roomserver_published (
// -- The room ID of the room
// room_id TEXT NOT NULL PRIMARY KEY,
// -- Whether it is published or not
// published BOOLEAN NOT NULL DEFAULT false
// );
// `
const upsertPublishedSQL = "" +
"INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
const selectAllPublishedSQL = "" +
"SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
const selectPublishedSQL = "" +
"SELECT published FROM roomserver_published WHERE room_id = $1"
type publishedStatements struct {
db *sql.DB
upsertPublishedStmt *sql.Stmt
selectAllPublishedStmt *sql.Stmt
selectPublishedStmt *sql.Stmt
type PublishCosmos struct {
RoomID string `json:"room_id"`
Published bool `json:"published"`
}
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
s := &publishedStatements{
db: db,
}
_, err := db.Exec(publishedSchema)
type PublishCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Publish PublishCosmos `json:"mx_roomserver_publish"`
}
// const upsertPublishedSQL = "" +
// "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
// "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
const selectAllPublishedSQL = "" +
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_publish.published = @x2" +
" order by c.mx_roomserver_publish.room_id asc"
// const selectPublishedSQL = "" +
// "SELECT published FROM roomserver_published WHERE room_id = $1"
type publishedStatements struct {
db *Database
// upsertPublishedStmt *sql.Stmt
selectAllPublishedStmt string
// selectPublishedStmt *sql.Stmt
tableName string
}
func queryPublish(s *publishedStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PublishCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []PublishCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.upsertPublishedStmt, upsertPublishedSQL},
{&s.selectAllPublishedStmt, selectAllPublishedSQL},
{&s.selectPublishedStmt, selectPublishedSQL},
}.Prepare(db)
return response, nil
}
func getPublish(s *publishedStatements, ctx context.Context, pk string, docId string) (*PublishCosmosData, error) {
response := PublishCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setPublish(s *publishedStatements, ctx context.Context, pk string, publish PublishCosmosData) (*PublishCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, publish.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
publish.Id,
&publish,
optionsReplace)
return &publish, ex
}
func NewCosmosDBPublishedTable(db *Database) (tables.Published, error) {
s := &publishedStatements{
db: db,
}
// _, err := db.Exec(publishedSchema)
// if err != nil {
// return nil, err
// }
// return s, shared.StatementList{
// {&s.upsertPublishedStmt, upsertPublishedSQL},
s.selectAllPublishedStmt = selectAllPublishedSQL
// {&s.selectPublishedStmt, selectPublishedSQL},
// }.Prepare(db)
s.tableName = "published"
return s, nil
}
func (s *publishedStatements) UpsertRoomPublished(
ctx context.Context, txn *sql.Tx, roomID string, published bool,
) error {
stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
_, err := stmt.ExecContext(ctx, roomID, published)
// "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL PRIMARY KEY,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := PublishCosmos{
RoomID: roomID,
Published: false,
}
var dbData = PublishCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Publish: data,
}
// "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string,
) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows {
return false, nil
// "SELECT published FROM roomserver_published WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL PRIMARY KEY,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getPublish(s, ctx, pk, cosmosDocId)
if err != nil {
return false, err
}
return
return response.Publish.Published, nil
}
func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool,
) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
// "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": published,
}
response, err := queryPublish(s, ctx, s.selectAllPublishedStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAllPublishedStmt: rows.close() failed")
var roomIDs []string
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
for _, item := range response {
roomIDs = append(roomIDs, item.Publish.RoomID)
}
return roomIDs, rows.Err()
return roomIDs, nil
}

View file

@ -17,84 +17,207 @@ package cosmosdb
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
)
const redactionsSchema = `
-- Stores information about the redacted state of events.
-- We need to track redactions rather than blindly updating the event JSON table on receipt of a redaction
-- because we might receive the redaction BEFORE we receive the event which it redacts (think backfill).
CREATE TABLE IF NOT EXISTS roomserver_redactions (
redaction_event_id TEXT PRIMARY KEY,
redacts_event_id TEXT NOT NULL,
-- Initially FALSE, set to TRUE when the redaction has been validated according to rooms v3+ spec
-- https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
validated BOOLEAN NOT NULL
);
`
// const redactionsSchema = `
// -- Stores information about the redacted state of events.
// -- We need to track redactions rather than blindly updating the event JSON table on receipt of a redaction
// -- because we might receive the redaction BEFORE we receive the event which it redacts (think backfill).
// CREATE TABLE IF NOT EXISTS roomserver_redactions (
// redaction_event_id TEXT PRIMARY KEY,
// redacts_event_id TEXT NOT NULL,
// -- Initially FALSE, set to TRUE when the redaction has been validated according to rooms v3+ spec
// -- https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
// validated BOOLEAN NOT NULL
// );
// `
const insertRedactionSQL = "" +
"INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
" VALUES ($1, $2, $3)"
const selectRedactionInfoByRedactionEventIDSQL = "" +
"SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
" WHERE redaction_event_id = $1"
const selectRedactionInfoByEventBeingRedactedSQL = "" +
"SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
" WHERE redacts_event_id = $1"
const markRedactionValidatedSQL = "" +
" UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
type redactionStatements struct {
db *sql.DB
insertRedactionStmt *sql.Stmt
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
markRedactionValidatedStmt *sql.Stmt
type RedactionCosmos struct {
RedactionEventID string `json:"redaction_event_id"`
RedactsEventID string `json:"redacts_event_id"`
Validated bool `json:"validated"`
}
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
s := &redactionStatements{
db: db,
}
_, err := db.Exec(redactionsSchema)
type RedactionCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Redaction RedactionCosmos `json:"mx_roomserver_redaction"`
}
// const insertRedactionSQL = "" +
// "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
// " VALUES ($1, $2, $3)"
// const selectRedactionInfoByRedactionEventIDSQL = "" +
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
// " WHERE redaction_event_id = $1"
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
// " WHERE redacts_event_id = $1"
const selectRedactionInfoByEventBeingRedactedSQL = "" +
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_redaction.redacts_event_id = @x2"
// const markRedactionValidatedSQL = "" +
// " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
type redactionStatements struct {
db *Database
// insertRedactionStmt *sql.Stmt
// selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
selectRedactionInfoByEventBeingRedactedStmt string
// markRedactionValidatedStmt *sql.Stmt
tableName string
}
func queryRedaction(s *redactionStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RedactionCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []RedactionCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
return s, shared.StatementList{
{&s.insertRedactionStmt, insertRedactionSQL},
{&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL},
{&s.selectRedactionInfoByEventBeingRedactedStmt, selectRedactionInfoByEventBeingRedactedSQL},
{&s.markRedactionValidatedStmt, markRedactionValidatedSQL},
}.Prepare(db)
func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId string) (*RedactionCosmosData, error) {
response := RedactionCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setRedaction(s *redactionStatements, ctx context.Context, pk string, redaction RedactionCosmosData) (*RedactionCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, redaction.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
redaction.Id,
&redaction,
optionsReplace)
return &redaction, ex
}
func NewCosmosDBRedactionsTable(db *Database) (tables.Redactions, error) {
s := &redactionStatements{
db: db,
}
// return s, shared.StatementList{
// {&s.insertRedactionStmt, insertRedactionSQL},
// {&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL},
s.selectRedactionInfoByEventBeingRedactedStmt = selectRedactionInfoByEventBeingRedactedSQL
// {&s.markRedactionValidatedStmt, markRedactionValidatedSQL},
// }.Prepare(db)
s.tableName = "redactions"
return s, nil
}
func (s *redactionStatements) InsertRedaction(
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
) error {
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
return err
// "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
// " VALUES ($1, $2, $3)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// redaction_event_id TEXT PRIMARY KEY,
docId := info.RedactionEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := RedactionCosmos{
RedactionEventID: info.RedactionEventID,
RedactsEventID: info.RedactsEventID,
Validated: info.Validated,
}
var dbData = RedactionCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Redaction: data,
}
// "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
// TODO: Just forDebug - Remove exception
if err != nil {
return err
}
//Ignore Error
return nil
}
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
ctx context.Context, txn *sql.Tx, redactionEventID string,
) (info *tables.RedactionInfo, err error) {
info = &tables.RedactionInfo{}
stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByRedactionEventIDStmt)
err = stmt.QueryRowContext(ctx, redactionEventID).Scan(
&info.RedactionEventID, &info.RedactsEventID, &info.Validated,
)
if err == sql.ErrNoRows {
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
// " WHERE redaction_event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// redaction_event_id TEXT PRIMARY KEY,
docId := redactionEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getRedaction(s, ctx, pk, cosmosDocId)
if err != nil {
info = nil
err = err
return
}
if response == nil {
info = nil
err = nil
return
}
info = &tables.RedactionInfo{
RedactionEventID: response.Redaction.RedactionEventID,
RedactsEventID: response.Redaction.RedactsEventID,
Validated: response.Redaction.Validated,
}
return
}
@ -102,14 +225,31 @@ func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
ctx context.Context, txn *sql.Tx, eventID string,
) (info *tables.RedactionInfo, err error) {
info = &tables.RedactionInfo{}
stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByEventBeingRedactedStmt)
err = stmt.QueryRowContext(ctx, eventID).Scan(
&info.RedactionEventID, &info.RedactsEventID, &info.Validated,
)
if err == sql.ErrNoRows {
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
// " WHERE redacts_event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventID,
}
response, err := queryRedaction(s, ctx, s.selectRedactionInfoByEventBeingRedactedStmt, params)
if err != nil {
return nil, err
}
if len(response) == 0 {
info = nil
err = nil
return
}
// TODO: Check this is ok to return the 1st one
*info = tables.RedactionInfo{
RedactionEventID: response[0].Redaction.RedactionEventID,
RedactsEventID: response[0].Redaction.RedactsEventID,
Validated: response[0].Redaction.Validated,
}
return
}
@ -117,7 +257,22 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
func (s *redactionStatements) MarkRedactionValidated(
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
) error {
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
// " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// redaction_event_id TEXT PRIMARY KEY,
docId := redactionEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getRedaction(s, ctx, pk, cosmosDocId)
if err != nil {
return err
}
response.Redaction.Validated = validated
_, err = setRedaction(s, ctx, pk, *response)
return err
}

View file

@ -18,84 +18,185 @@ package cosmosdb
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
)
const roomAliasesSchema = `
CREATE TABLE IF NOT EXISTS roomserver_room_aliases (
alias TEXT NOT NULL PRIMARY KEY,
room_id TEXT NOT NULL,
creator_id TEXT NOT NULL
);
// const roomAliasesSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_room_aliases (
// alias TEXT NOT NULL PRIMARY KEY,
// room_id TEXT NOT NULL,
// creator_id TEXT NOT NULL
// );
CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
`
// CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
// `
const insertRoomAliasSQL = `
INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
`
const selectRoomIDFromAliasSQL = `
SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
`
const selectAliasesFromRoomIDSQL = `
SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
`
const selectCreatorIDFromAliasSQL = `
SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
`
const deleteRoomAliasSQL = `
DELETE FROM roomserver_room_aliases WHERE alias = $1
`
type roomAliasesStatements struct {
db *sql.DB
insertRoomAliasStmt *sql.Stmt
selectRoomIDFromAliasStmt *sql.Stmt
selectAliasesFromRoomIDStmt *sql.Stmt
selectCreatorIDFromAliasStmt *sql.Stmt
deleteRoomAliasStmt *sql.Stmt
type RoomAliasCosmos struct {
Alias string `json:"alias"`
RoomID string `json:"room_id"`
CreatorID string `json:"creator_id"`
}
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{
db: db,
}
_, err := db.Exec(roomAliasesSchema)
type RoomAliasCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
RoomAlias RoomAliasCosmos `json:"mx_roomserver_room_alias"`
}
// const insertRoomAliasSQL = `
// INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
// `
// const selectRoomIDFromAliasSQL = `
// SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
// `
// SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
const selectAliasesFromRoomIDSQL = `
select * from c where c._cn = @x1 and c.mx_roomserver_room_alias.room_id = @x2
`
// const selectCreatorIDFromAliasSQL = `
// SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
// `
// const deleteRoomAliasSQL = `
// DELETE FROM roomserver_room_aliases WHERE alias = $1
// `
type roomAliasesStatements struct {
db *Database
// insertRoomAliasStmt *sql.Stmt
// selectRoomIDFromAliasStmt string
selectAliasesFromRoomIDStmt string
// selectCreatorIDFromAliasStmt string
// deleteRoomAliasStmt *sql.Stmt
tableName string
}
func queryRoomAlias(s *roomAliasesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomAliasCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []RoomAliasCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertRoomAliasStmt, insertRoomAliasSQL},
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
{&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL},
{&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL},
{&s.deleteRoomAliasStmt, deleteRoomAliasSQL},
}.Prepare(db)
return response, nil
}
func getRoomAlias(s *roomAliasesStatements, ctx context.Context, pk string, docId string) (*RoomAliasCosmosData, error) {
response := RoomAliasCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func NewCosmosDBRoomAliasesTable(db *Database) (tables.RoomAliases, error) {
s := &roomAliasesStatements{
db: db,
}
// _, err := db.Exec(roomAliasesSchema)
// if err != nil {
// return nil, err
// }
// return s, shared.StatementList{
// {&s.insertRoomAliasStmt, insertRoomAliasSQL},
// {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
s.selectAliasesFromRoomIDStmt = selectAliasesFromRoomIDSQL
// {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL},
// {&s.deleteRoomAliasStmt, deleteRoomAliasSQL},
// }.Prepare(db)
s.tableName = "room_aliases"
return s, nil
}
func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string,
) error {
stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt)
_, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID)
// INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
data := RoomAliasCosmos{
Alias: alias,
CreatorID: creatorUserID,
RoomID: roomID,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// alias TEXT NOT NULL PRIMARY KEY,
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = RoomAliasCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
RoomAlias: data,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string,
) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows {
// SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// alias TEXT NOT NULL PRIMARY KEY,
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getRoomAlias(s, ctx, pk, cosmosDocId)
if err != nil {
return "", err
}
if response == nil {
return "", nil
}
roomID = response.RoomAlias.RoomID
return
}
@ -103,20 +204,23 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string,
) (aliases []string, err error) {
aliases = []string{}
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
if err != nil {
return
// SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
response, err := queryRoomAlias(s, ctx, s.selectAliasesFromRoomIDStmt, params)
for rows.Next() {
var alias string
if err = rows.Scan(&alias); err != nil {
return
}
if err != nil {
return nil, err
}
aliases = append(aliases, alias)
for _, item := range response {
aliases = append(aliases, item.RoomAlias.Alias)
}
return
@ -125,17 +229,48 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string,
) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows {
// SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// alias TEXT NOT NULL PRIMARY KEY,
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getRoomAlias(s, ctx, pk, cosmosDocId)
if err != nil {
return "", err
}
if response == nil {
return "", nil
}
creatorID = response.RoomAlias.CreatorID
return
}
func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, txn *sql.Tx, alias string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt)
_, err := stmt.ExecContext(ctx, alias)
// DELETE FROM roomserver_room_aliases WHERE alias = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
cosmosDocId,
options)
if err != nil {
return err
}
return err
}

View file

@ -0,0 +1,12 @@
package cosmosdb
import (
"context"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
)
func GetNextRoomNID(s *roomStatements, ctx context.Context) (int64, error) {
const docId = "roomnid_seq"
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
}

View file

@ -18,128 +18,227 @@ package cosmosdb
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
const roomsSchema = `
CREATE TABLE IF NOT EXISTS roomserver_rooms (
room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
room_id TEXT NOT NULL UNIQUE,
latest_event_nids TEXT NOT NULL DEFAULT '[]',
last_event_sent_nid INTEGER NOT NULL DEFAULT 0,
state_snapshot_nid INTEGER NOT NULL DEFAULT 0,
room_version TEXT NOT NULL
);
`
// const roomsSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_rooms (
// room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
// room_id TEXT NOT NULL UNIQUE,
// latest_event_nids TEXT NOT NULL DEFAULT '[]',
// last_event_sent_nid INTEGER NOT NULL DEFAULT 0,
// state_snapshot_nid INTEGER NOT NULL DEFAULT 0,
// room_version TEXT NOT NULL
// );
// `
// Same as insertEventTypeNIDSQL
const insertRoomNIDSQL = `
INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
ON CONFLICT DO NOTHING;
`
const selectRoomNIDSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
const selectLatestEventNIDsSQL = "" +
"SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
const selectLatestEventNIDsForUpdateSQL = "" +
"SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
const updateLatestEventNIDsSQL = "" +
"UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
const selectRoomVersionsForRoomNIDsSQL = "" +
"SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
const selectRoomIDsSQL = "" +
"SELECT room_id FROM roomserver_rooms"
const bulkSelectRoomIDsSQL = "" +
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
const bulkSelectRoomNIDsSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
type roomStatements struct {
db *sql.DB
insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt
//selectRoomVersionForRoomNIDStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt
type RoomCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Room RoomCosmos `json:"mx_roomserver_room"`
}
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
type RoomCosmos struct {
RoomNID int64 `json:"room_nid"`
RoomID string `json:"room_id"`
LatestEventNIDs []int64 `json:"latest_event_nids"`
LastEventSentNID int64 `json:"last_event_sent_nid"`
StateSnapshotNID int64 `json:"state_snapshot_nid"`
RoomVersion string `json:"room_version"`
}
// Same as insertEventTypeNIDSQL
// const insertRoomNIDSQL = `
// INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
// ON CONFLICT DO NOTHING;
// `
// "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
// const selectRoomNIDSQL = "" +
// "select * from c where c._cn = @x1 and c.mx_roomserver_room.room_nid = @x1"
// "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
const selectLatestEventNIDsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_roomserver_room.room_nid = @x2"
// "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
const selectLatestEventNIDsForUpdateSQL = "" +
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_room.room_nid = @x2"
// const updateLatestEventNIDsSQL = "" +
// "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
// "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
const selectRoomVersionsForRoomNIDsSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)"
// "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
// const selectRoomInfoSQL = "" +
// "select * from c where c._cn = @x1 and c.mx_roomserver_room.room_id = @x2"
// "SELECT room_id FROM roomserver_rooms"
const selectRoomIDsSQL = "" +
"select * from c where c._cn = @x1"
// "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
const bulkSelectRoomIDsSQL = "" +
"select * from c where c._cn = @x1 " +
" and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)"
// "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
const bulkSelectRoomNIDsSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)"
type roomStatements struct {
db *Database
// insertRoomNIDStmt *sql.Stmt
// selectRoomNIDStmt string
selectLatestEventNIDsStmt string
selectLatestEventNIDsForUpdateStmt string
updateLatestEventNIDsStmt string
selectRoomVersionForRoomNIDStmt string
// selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt string
tableName string
}
func NewCosmosDBRoomsTable(db *Database) (tables.Rooms, error) {
s := &roomStatements{
db: db,
}
_, err := db.Exec(roomsSchema)
// return s, shared.StatementList{
// {&s.insertRoomNIDStmt, insertRoomNIDSQL},
// {&s.selectRoomNIDStmt, selectRoomNIDSQL},
s.selectLatestEventNIDsStmt = selectLatestEventNIDsSQL
s.selectLatestEventNIDsForUpdateStmt = selectLatestEventNIDsForUpdateSQL
// {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
// {&s.selectRoomInfoStmt, selectRoomInfoSQL},
s.selectRoomIDsStmt = selectRoomIDsSQL
// }.Prepare(db)
s.tableName = "rooms"
return s, nil
}
func mapToRoomEventNIDArray(eventNIDs []int64) []types.EventNID {
result := []types.EventNID{}
for i := 0; i < len(eventNIDs); i++ {
result = append(result, types.EventNID(eventNIDs[i]))
}
return result
}
func queryRoom(s *roomStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []RoomCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
}.Prepare(db)
return response, nil
}
func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (*RoomCosmosData, error) {
response := RoomCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setRoom(s *roomStatements, ctx context.Context, pk string, room RoomCosmosData) (*RoomCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, room.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
room.Id,
&room,
optionsReplace)
return &room, ex
}
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
// "SELECT room_id FROM roomserver_rooms"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
response, err := queryRoom(s, ctx, s.selectRoomIDsStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
for _, item := range response {
roomIDs = append(roomIDs, item.Room.RoomID)
}
return roomIDs, nil
}
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo
var latestNIDsJSON string
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
)
info := types.RoomInfo{}
// "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL UNIQUE,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
room, err := getRoom(s, ctx, pk, cosmosDocId)
if err != nil {
if err == sql.ErrNoRows {
if err == cosmosdbutil.ErrNoRows {
return nil, nil
}
return nil, err
}
var latestNIDs []int64
if err = json.Unmarshal([]byte(latestNIDsJSON), &latestNIDs); err != nil {
return nil, err
}
info.IsStub = len(latestNIDs) == 0
info.RoomVersion = gomatrixserverlib.RoomVersion(room.Room.RoomVersion)
info.RoomNID = types.RoomNID(room.Room.RoomNID)
info.StateSnapshotNID = types.StateSnapshotNID(room.Room.StateSnapshotNID)
info.IsStub = len(room.Room.LatestEventNIDs) == 0
return &info, err
}
@ -147,60 +246,135 @@ func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (roomNID types.RoomNID, err error) {
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
_, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
if err != nil {
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
// ON CONFLICT DO NOTHING;
// room_id TEXT NOT NULL UNIQUE,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
dbData, errGet := getRoom(s, ctx, pk, cosmosDocId)
if errGet == cosmosdbutil.ErrNoRows {
// room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
roomNIDSeq, seqErr := GetNextRoomNID(s, ctx)
if seqErr != nil {
return 0, seqErr
}
data := RoomCosmos{
RoomNID: int64(roomNIDSeq),
RoomID: roomID,
RoomVersion: string(roomVersion),
}
dbData = &RoomCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Room: data,
}
}
roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
// ON CONFLICT DO NOTHING; - Do Upsert
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
if err != nil {
return 0, fmt.Errorf("s.SelectRoomNID: %w", err)
}
roomNID = types.RoomNID(dbData.Room.RoomNID)
return
}
func (s *roomStatements) SelectRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
var roomNID int64
stmt := sqlutil.TxStmt(txn, s.selectRoomNIDStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err
// "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL UNIQUE,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
room, err := getRoom(s, ctx, pk, cosmosDocId)
if err != nil {
return 0, err
}
if room == nil {
return 0, nil
}
return types.RoomNID(room.Room.RoomNID), err
}
func (s *roomStatements) SelectLatestEventNIDs(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.StateSnapshotNID, error) {
var eventNIDs []types.EventNID
var nidsJSON string
var stateSnapshotNID int64
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &stateSnapshotNID)
// "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNID,
}
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsStmt, params)
if err != nil {
return nil, 0, err
}
if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil {
return nil, 0, err
// TODO: Check the error handling
if len(response) == 0 {
return nil, 0, cosmosdbutil.ErrNoRows
}
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
//Assume 1 per RoomNID
room := response[0]
return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil
}
func (s *roomStatements) SelectLatestEventsNIDsForUpdate(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
var eventNIDs []types.EventNID
var nidsJSON string
var lastEventSentNID int64
var stateSnapshotNID int64
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &lastEventSentNID, &stateSnapshotNID)
// "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNID,
}
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params)
if err != nil {
return nil, 0, 0, err
}
if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil {
return nil, 0, 0, err
// TODO: Check the error handling
if len(response) == 0 {
return nil, 0, 0, cosmosdbutil.ErrNoRows
}
return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil
//Assume 1 per RoomNID
room := response[0]
return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.EventNID(room.Room.LastEventSentNID), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil
}
func (s *roomStatements) UpdateLatestEventNIDs(
@ -211,86 +385,113 @@ func (s *roomStatements) UpdateLatestEventNIDs(
lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID,
) error {
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
_, err := stmt.ExecContext(
ctx,
eventNIDsAsArray(eventNIDs),
int64(lastEventSentNID),
int64(stateSnapshotNID),
roomNID,
)
// "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNID,
}
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params)
if err != nil {
return err
}
// TODO: Check the error handling
if len(response) == 0 {
return cosmosdbutil.ErrNoRows
}
//Assume 1 per RoomNID
room := response[0]
room.Room.LatestEventNIDs = mapFromEventNIDArray(eventNIDs)
room.Room.LastEventSentNID = int64(lastEventSentNID)
room.Room.StateSnapshotNID = int64(stateSnapshotNID)
_, err = setRoom(s, ctx, room.Pk, room)
return err
}
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr)
if roomNIDs == nil || len(roomNIDs) == 0 {
return make(map[types.RoomNID]gomatrixserverlib.RoomVersion), nil
}
// "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNIDs,
}
response, err := queryRoom(s, ctx, selectRoomVersionsForRoomNIDsSQL, params)
if err != nil {
return nil, err
}
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
result[roomNID] = roomVersion
for _, item := range response {
result[types.RoomNID(item.Room.RoomNID)] = gomatrixserverlib.RoomVersion(item.Room.RoomVersion)
}
return result, nil
}
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
if roomNIDs == nil || len(roomNIDs) == 0 {
return []string{}, nil
}
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
// "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNIDs,
}
response, err := queryRoom(s, ctx, bulkSelectRoomIDsSQL, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
var roomIDs []string
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
for _, item := range response {
roomIDs = append(roomIDs, item.Room.RoomID)
}
return roomIDs, nil
}
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs {
iRoomIDs[i] = v
if roomIDs == nil || len(roomIDs) == 0 {
return []types.RoomNID{}, nil
}
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
// "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomIDs,
}
response, err := queryRoom(s, ctx, bulkSelectRoomNIDsSQL, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID
for rows.Next() {
var roomNID types.RoomNID
if err = rows.Scan(&roomNID); err != nil {
return nil, err
}
roomNIDs = append(roomNIDs, roomNID)
for _, item := range response {
roomNIDs = append(roomNIDs, types.RoomNID(item.Room.RoomNID))
}
return roomNIDs, nil
}

View file

@ -20,33 +20,54 @@ import (
"database/sql"
"fmt"
"sort"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util"
)
const stateDataSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_block (
state_block_nid INTEGER NOT NULL,
event_type_nid INTEGER NOT NULL,
event_state_key_nid INTEGER NOT NULL,
event_nid INTEGER NOT NULL,
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
);
`
// const stateDataSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_state_block (
// state_block_nid INTEGER NOT NULL,
// event_type_nid INTEGER NOT NULL,
// event_state_key_nid INTEGER NOT NULL,
// event_nid INTEGER NOT NULL,
// UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
// );
// `
const insertStateDataSQL = "" +
"INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
" VALUES ($1, $2, $3, $4)"
type StateBlockCosmos struct {
StateBlockNID int64 `json:"state_block_nid"`
EventTypeNID int64 `json:"event_type_nid"`
EventStateKeyNID int64 `json:"event_state_key_nid"`
EventNID int64 `json:"event_nid"`
}
const selectNextStateBlockNIDSQL = `
SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
`
type StateBlockCosmosMaxNID struct {
Max int64 `json:"maxstateblocknid"`
}
type StateBlockCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
StateBlock StateBlockCosmos `json:"mx_roomserver_state_block"`
}
// const insertStateDataSQL = "" +
// "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
// " VALUES ($1, $2, $3, $4)"
// SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
const selectNextStateBlockNIDSQL = "" +
"select sub.maxinner != null ? sub.maxinner + 1 : 1 as maxstateblocknid " +
"from " +
"(select MAX(c.mx_roomserver_state_block.state_block_nid) maxinner from c where c._cn = @x1) as sub"
// Bulk state lookup by numeric state block ID.
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
@ -54,10 +75,17 @@ SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
// together in the list and those entries will sorted by event_type_nid
// and event_state_key_nid. This property makes it easier to merge two
// state data blocks together.
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
const bulkSelectStateBlockEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " +
"order by c.mx_roomserver_state_block.state_block_nid " +
// Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from
// ", c.mx_roomserver_state_block.event_type_nid " +
// ", c.mx_roomserver_state_block.event_state_key_nid " +
" asc"
// Bulk state lookup by numeric state block ID.
// Filters the rows in each block to the requested types and state keys.
@ -66,35 +94,126 @@ const bulkSelectStateBlockEntriesSQL = "" +
// of types and a list state_keys. So we have to filter the result in the
// application to restrict it to the list of event types and state keys we
// actually wanted.
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
// " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
const bulkSelectFilteredStateBlockEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
" AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " +
"and ARRAY_CONTAINS(@x3, c.mx_roomserver_state_block.event_type_nid) " +
"and ARRAY_CONTAINS(@x4, c.mx_roomserver_state_block.event_state_key_nid) " +
"order by c.mx_roomserver_state_block.state_block_nid " +
// Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from
// ", c.mx_roomserver_state_block.event_type_nid " +
// ", c.mx_roomserver_state_block.event_state_key_nid " +
"asc"
type stateBlockStatements struct {
db *sql.DB
insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt
bulkSelectStateBlockEntriesStmt *sql.Stmt
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
db *Database
// insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt string
bulkSelectStateBlockEntriesStmt string
bulkSelectFilteredStateBlockEntriesStmt string
tableName string
}
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{
db: db,
}
_, err := db.Exec(stateDataSchema)
func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StateBlockCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []StateBlockCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
return s, shared.StatementList{
{&s.insertStateDataStmt, insertStateDataSQL},
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
}.Prepare(db)
func NewCosmosDBStateBlockTable(db *Database) (tables.StateBlock, error) {
s := &stateBlockStatements{
db: db,
}
// return s, shared.StatementList{
// {&s.insertStateDataStmt, insertStateDataSQL},
s.selectNextStateBlockNIDStmt = selectNextStateBlockNIDSQL
s.bulkSelectStateBlockEntriesStmt = bulkSelectStateBlockEntriesSQL
s.bulkSelectFilteredStateBlockEntriesStmt = bulkSelectFilteredStateBlockEntriesSQL
// }.Prepare(db)
s.tableName = "state_block"
return s, nil
}
func inertStateBlockCore(s *stateBlockStatements, ctx context.Context, stateBlockNID types.StateBlockNID, entry types.StateEntry) error {
// "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
// " VALUES ($1, $2, $3, $4)"
data := StateBlockCosmos{
EventNID: int64(entry.EventNID),
EventStateKeyNID: int64(entry.EventStateKeyNID),
EventTypeNID: int64(entry.EventTypeNID),
StateBlockNID: int64(stateBlockNID),
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
docId := fmt.Sprintf("%d_%d_%d", data.StateBlockNID, data.EventTypeNID, data.EventStateKeyNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = StateBlockCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
StateBlock: data,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
func getNextStateBlockNID(s *stateBlockStatements, ctx context.Context) (int64, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var stateBlockNext []StateBlockCosmosMaxNID
params := map[string]interface{}{
"@x1": dbCollectionName,
}
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(s.selectNextStateBlockNIDStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&stateBlockNext,
optionsQry)
if err != nil {
return 0, err
}
return stateBlockNext[0].Max, nil
}
func (s *stateBlockStatements) BulkInsertStateData(
@ -104,75 +223,64 @@ func (s *stateBlockStatements) BulkInsertStateData(
if len(entries) == 0 {
return 0, nil
}
var stateBlockNID types.StateBlockNID
err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
if err != nil {
return 0, err
nextID, errNextID := getNextStateBlockNID(s, ctx)
if errNextID != nil {
return 0, errNextID
}
stateBlockNID := types.StateBlockNID(nextID)
for _, entry := range entries {
_, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
ctx,
int64(stateBlockNID),
int64(entry.EventTypeNID),
int64(entry.EventStateKeyNID),
int64(entry.EventNID),
)
err := inertStateBlockCore(s, ctx, stateBlockNID, entry)
if err != nil {
return 0, err
}
}
return stateBlockNID, err
return stateBlockNID, nil
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
nids := make([]interface{}, len(stateBlockNIDs))
for k, v := range stateBlockNIDs {
nids[k] = v
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var response []StateBlockCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": stateBlockNIDs,
}
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
selectStmt, err := s.db.Prepare(selectOrig)
response, err := queryStateBlock(s, ctx, s.bulkSelectStateBlockEntriesStmt, params)
if err != nil {
return nil, err
}
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
results := make([]types.StateEntryList, len(stateBlockNIDs))
// current is a pointer to the StateEntryList to append the state entries to.
var current *types.StateEntryList
i := 0
for rows.Next() {
var (
stateBlockNID int64
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
entry types.StateEntry
)
if err := rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
return nil, err
}
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID)
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
for _, item := range response {
entry := types.StateEntry{}
entry.EventTypeNID = types.EventTypeNID(item.StateBlock.EventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(item.StateBlock.EventStateKeyNID)
entry.EventNID = types.EventNID(item.StateBlock.EventNID)
if current == nil || types.StateBlockNID(item.StateBlock.StateBlockNID) != current.StateBlockNID {
// The state entry row is for a different state data block to the current one.
// So we start appending to the next entry in the list.
current = &results[i]
current.StateBlockNID = types.StateBlockNID(stateBlockNID)
current.StateBlockNID = types.StateBlockNID(item.StateBlock.StateBlockNID)
i++
}
current.StateEntries = append(current.StateEntries, entry)
}
if i != len(nids) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids))
if i != len(stateBlockNIDs) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
}
return results, nil
}
@ -187,34 +295,33 @@ func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
// sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
// sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
// sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
var params []interface{}
for _, val := range stateBlockNIDs {
params = append(params, int64(val))
}
for _, val := range eventTypeNIDArray {
params = append(params, val)
}
for _, val := range eventStateKeyNIDArray {
params = append(params, val)
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
// " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var response []StateBlockCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": stateBlockNIDs,
"@x3": eventTypeNIDArray,
"@x4": eventStateKeyNIDArray,
}
rows, err := s.db.QueryContext(
ctx,
sqlStatement,
params...,
)
response, err := queryStateBlock(s, ctx, s.bulkSelectFilteredStateBlockEntriesStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
var results []types.StateEntryList
var current types.StateEntryList
for rows.Next() {
for _, item := range response {
var (
stateBlockNID int64
eventTypeNID int64
@ -222,11 +329,10 @@ func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
eventNID int64
entry types.StateEntry
)
if err := rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
return nil, err
}
stateBlockNID = item.StateBlock.StateBlockNID
eventTypeNID = item.StateBlock.EventTypeNID
eventStateKeyNID = item.StateBlock.EventStateKeyNID
eventNID = item.StateBlock.EventNID
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID)

View file

@ -0,0 +1,12 @@
package cosmosdb
import (
"context"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
)
func GetNextStateSnapshotNID(s *stateSnapshotStatements, ctx context.Context) (int64, error) {
const docId = "statesnapshotnid_seq"
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
}

View file

@ -18,106 +18,169 @@ package cosmosdb
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
const stateSnapshotSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
room_nid INTEGER NOT NULL,
state_block_nids TEXT NOT NULL DEFAULT '[]'
);
`
// const stateSnapshotSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
// state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
// room_nid INTEGER NOT NULL,
// state_block_nids TEXT NOT NULL DEFAULT '[]'
// );
// `
const insertStateSQL = `
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
VALUES ($1, $2);`
type StateSnapshotCosmos struct {
StateSnapshotNID int64 `json:"state_snapshot_nid"`
RoomNID int64 `json:"room_nid"`
StateBlockNIDs []int64 `json:"state_block_nids"`
}
type StateSnapshotCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
StateSnapshot StateSnapshotCosmos `json:"mx_roomserver_state_snapshot"`
}
// const insertStateSQL = `
// INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
// VALUES ($1, $2);`
// Bulk state data NID lookup.
// Sorting by state_snapshot_nid means we can use binary search over the result
// to lookup the state data NIDs for a state snapshot NID.
// "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
// " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
const bulkSelectStateBlockNIDsSQL = "" +
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_snapshot.state_snapshot_nid) " +
"order by c.mx_roomserver_state_snapshot.state_snapshot_nid asc"
type stateSnapshotStatements struct {
db *sql.DB
insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
db *Database
// insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt string
tableName string
}
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
func mapFromStateBlockNIDArray(stateBlockNIDs []types.StateBlockNID) []int64 {
result := []int64{}
for i := 0; i < len(stateBlockNIDs); i++ {
result = append(result, int64(stateBlockNIDs[i]))
}
return result
}
func mapToStateBlockNIDArray(stateBlockNIDs []int64) []types.StateBlockNID {
result := []types.StateBlockNID{}
for i := 0; i < len(stateBlockNIDs); i++ {
result = append(result, types.StateBlockNID(stateBlockNIDs[i]))
}
return result
}
func NewCosmosDBStateSnapshotTable(db *Database) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{
db: db,
}
_, err := db.Exec(stateSnapshotSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
}.Prepare(db)
// return s, shared.StatementList{
// {&s.insertStateStmt, insertStateSQL},
s.bulkSelectStateBlockNIDsStmt = bulkSelectStateBlockNIDsSQL
// }.Prepare(db)
s.tableName = "state_snapshots"
return s, nil
}
func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
) (stateNID types.StateSnapshotNID, err error) {
stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs)
if err != nil {
return
// INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
// VALUES ($1, $2);`
stateSnapshotNIDSeq, seqErr := GetNextStateSnapshotNID(s, ctx)
if seqErr != nil {
return 0, seqErr
}
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
data := StateSnapshotCosmos{
RoomNID: int64(roomNID),
StateBlockNIDs: mapFromStateBlockNIDArray(stateBlockNIDs),
StateSnapshotNID: int64(stateSnapshotNIDSeq),
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
docId := fmt.Sprintf("%d", stateSnapshotNIDSeq)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = StateSnapshotCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
StateSnapshot: data,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
if err != nil {
return 0, err
}
lastRowID, err := res.LastInsertId()
if err != nil {
return 0, err
}
stateNID = types.StateSnapshotNID(lastRowID)
stateNID = types.StateSnapshotNID(stateSnapshotNIDSeq)
return
}
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]interface{}, len(stateNIDs))
for k, v := range stateNIDs {
nids[k] = v
// "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
// " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []StateSnapshotCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": stateNIDs,
}
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
selectStmt, err := s.db.Prepare(selectOrig)
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.bulkSelectStateBlockNIDsStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
for ; rows.Next(); i++ {
for _, item := range response {
result := &results[i]
var stateBlockNIDsJSON string
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
return nil, err
}
if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil {
return nil, err
}
result.StateSnapshotNID = types.StateSnapshotNID(item.StateSnapshot.StateSnapshotNID)
result.StateBlockNIDs = mapToStateBlockNIDArray(item.StateSnapshot.StateBlockNIDs)
i++
}
if i != len(stateNIDs) {
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))

View file

@ -17,14 +17,15 @@ package cosmosdb
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
_ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
@ -33,16 +34,26 @@ import (
// A Database is used to store room events and stream offsets.
type Database struct {
shared.Database
connection cosmosdbapi.CosmosConnection
databaseName string
cosmosConfig cosmosdbapi.CosmosConfig
serverName gomatrixserverlib.ServerName
}
// Open a sqlite database.
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
var d Database
var db *sql.DB
var err error
if db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
d := &Database{
databaseName: "roomserver",
connection: conn,
cosmosConfig: config,
}
// var db *sql.DB
// var err error
// if db, err = sqlutil.Open(dbProperties); err != nil {
// return nil, err
// }
//db.Exec("PRAGMA journal_mode=WAL;")
//db.Exec("PRAGMA read_uncommitted = true;")
@ -51,89 +62,91 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
// cause the roomserver to be unresponsive to new events because something will
// acquire the global mutex and never unlock it because it is waiting for a connection
// which it will never obtain.
db.SetMaxOpenConns(20)
// db.SetMaxOpenConns(20)
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
ms := membershipStatements{}
if err := ms.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadAddForgottenColumn(m)
if err := m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err := d.prepare(db, cache); err != nil {
// ms := membershipStatements{}
// if err := ms.execSchema(db); err != nil {
// return nil, err
// }
// m := sqlutil.NewMigrations()
// deltas.LoadAddForgottenColumn(m)
// if err := m.RunDeltas(db, dbProperties); err != nil {
// return nil, err
// }
if err := d.prepare(cache); err != nil {
return nil, err
}
return &d, nil
return d, nil
}
// nolint: gocyclo
func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
func (d *Database) prepare(cache caching.RoomServerCaches) error {
var err error
eventStateKeys, err := NewSqliteEventStateKeysTable(db)
d.databaseName = "roomserver"
eventStateKeys, err := NewCosmosDBEventStateKeysTable(d)
if err != nil {
return err
}
eventTypes, err := NewSqliteEventTypesTable(db)
eventTypes, err := NewCosmosDBEventTypesTable(d)
if err != nil {
return err
}
eventJSON, err := NewSqliteEventJSONTable(db)
eventJSON, err := NewCosmosDBEventJSONTable(d)
if err != nil {
return err
}
events, err := NewSqliteEventsTable(db)
events, err := NewCosmosDBEventsTable(d)
if err != nil {
return err
}
rooms, err := NewSqliteRoomsTable(db)
rooms, err := NewCosmosDBRoomsTable(d)
if err != nil {
return err
}
transactions, err := NewSqliteTransactionsTable(db)
transactions, err := NewCosmosDBTransactionsTable(d)
if err != nil {
return err
}
stateBlock, err := NewSqliteStateBlockTable(db)
stateBlock, err := NewCosmosDBStateBlockTable(d)
if err != nil {
return err
}
stateSnapshot, err := NewSqliteStateSnapshotTable(db)
stateSnapshot, err := NewCosmosDBStateSnapshotTable(d)
if err != nil {
return err
}
prevEvents, err := NewSqlitePrevEventsTable(db)
prevEvents, err := NewCosmosDBPrevEventsTable(d)
if err != nil {
return err
}
roomAliases, err := NewSqliteRoomAliasesTable(db)
roomAliases, err := NewCosmosDBRoomAliasesTable(d)
if err != nil {
return err
}
invites, err := NewSqliteInvitesTable(db)
invites, err := NewCosmosDBInvitesTable(d)
if err != nil {
return err
}
membership, err := NewSqliteMembershipTable(db)
membership, err := NewCosmosDBMembershipTable(d)
if err != nil {
return err
}
published, err := NewSqlitePublishedTable(db)
published, err := NewCosmosDBPublishedTable(d)
if err != nil {
return err
}
redactions, err := NewSqliteRedactionsTable(db)
redactions, err := NewCosmosDBRedactionsTable(d)
if err != nil {
return err
}
d.Database = shared.Database{
DB: db,
Cache: cache,
Writer: sqlutil.NewExclusiveWriter(),
DB: nil,
Cache: cache,
//Use the Fake SQL Writer here
Writer: cosmosdbutil.NewExclusiveWriterFake(),
EventsTable: events,
EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys,

View file

@ -18,50 +18,84 @@ package cosmosdb
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
)
const transactionsSchema = `
CREATE TABLE IF NOT EXISTS roomserver_transactions (
transaction_id TEXT NOT NULL,
session_id INTEGER NOT NULL,
user_id TEXT NOT NULL,
event_id TEXT NOT NULL,
PRIMARY KEY (transaction_id, session_id, user_id)
);
`
const insertTransactionSQL = `
INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
VALUES ($1, $2, $3, $4)
`
// const transactionsSchema = `
// CREATE TABLE IF NOT EXISTS roomserver_transactions (
// transaction_id TEXT NOT NULL,
// session_id INTEGER NOT NULL,
// user_id TEXT NOT NULL,
// event_id TEXT NOT NULL,
// PRIMARY KEY (transaction_id, session_id, user_id)
// );
// `
const selectTransactionEventIDSQL = `
SELECT event_id FROM roomserver_transactions
WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
`
type transactionStatements struct {
db *sql.DB
insertTransactionStmt *sql.Stmt
selectTransactionEventIDStmt *sql.Stmt
type TransactionCosmos struct {
TransactionID string `json:"transaction_id"`
SessionID int64 `json:"session_id"`
UserID string `json:"user_id"`
EventID string `json:"event_id"`
}
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
type TransactionCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Transaction TransactionCosmos `json:"mx_roomserver_transaction"`
}
// const insertTransactionSQL = `
// INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
// VALUES ($1, $2, $3, $4)
// `
// const selectTransactionEventIDSQL = `
// SELECT event_id FROM roomserver_transactions
// WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
// `
type transactionStatements struct {
db *Database
// insertTransactionStmt *sql.Stmt
selectTransactionEventIDStmt *sql.Stmt
tableName string
}
func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*TransactionCosmosData, error) {
response := TransactionCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func NewCosmosDBTransactionsTable(db *Database) (tables.Transactions, error) {
s := &transactionStatements{
db: db,
}
_, err := db.Exec(transactionsSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertTransactionStmt, insertTransactionSQL},
{&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
}.Prepare(db)
// return s, shared.StatementList{
// {&s.insertTransactionStmt, insertTransactionSQL},
// {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
// }.Prepare(db)
s.tableName = "transactions"
return s, nil
}
func (s *transactionStatements) InsertTransaction(
@ -71,10 +105,39 @@ func (s *transactionStatements) InsertTransaction(
userID string,
eventID string,
) error {
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
_, err := stmt.ExecContext(
ctx, transactionID, sessionID, userID, eventID,
)
// INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
// VALUES ($1, $2, $3, $4)
data := TransactionCosmos{
EventID: eventID,
SessionID: sessionID,
TransactionID: transactionID,
UserID: userID,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// PRIMARY KEY (transaction_id, session_id, user_id)
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = TransactionCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Transaction: data,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
@ -84,8 +147,21 @@ func (s *transactionStatements) SelectTransactionEventID(
sessionID int64,
userID string,
) (eventID string, err error) {
err = s.selectTransactionEventIDStmt.QueryRowContext(
ctx, transactionID, sessionID, userID,
).Scan(&eventID)
return
// SELECT event_id FROM roomserver_transactions
// WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// PRIMARY KEY (transaction_id, session_id, user_id)
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getTransaction(s, ctx, pk, cosmosDocId)
if err != nil {
return "", err
}
return response.Transaction.EventID, err
}

View file

@ -71,6 +71,27 @@ func (s *accountDataStatements) prepare(db *Database) (err error) {
return
}
func queryAccountData(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []AccountDataCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
@ -92,10 +113,14 @@ func (s *accountDataStatements) insertAccountData(
id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type)
}
docId := id
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = AccountDataCosmosData{
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, id),
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Pk: pk,
Timestamp: time.Now().Unix(),
AccountData: result,
}
@ -120,24 +145,15 @@ func (s *accountDataStatements) selectAccountData(
) {
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountDataCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil {
return nil, nil, ex
response, err := queryAccountData(s, ctx, s.selectAccountDataStmt, params)
if err != nil {
return nil, nil, err
}
global := map[string]json.RawMessage{}
@ -166,26 +182,17 @@ func (s *accountDataStatements) selectAccountDataByType(
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountDataCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": roomID,
"@x4": dataType,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil {
return nil, ex
response, err := queryAccountData(s, ctx, s.selectAccountDataByTypeStmt, params)
if err != nil {
return nil, err
}
if len(response) == 0 {

View file

@ -21,6 +21,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/userapi/api"
@ -87,17 +88,42 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv
return
}
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) {
response := AccountCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
func queryAccount(s *accountsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []AccountCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) {
response := AccountCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
optionsGet,
&response)
return &response, ex
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setAccount(s *accountsStatements, ctx context.Context, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
@ -155,10 +181,14 @@ func (s *accountsStatements) insertAccount(
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
docId := result.Localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = AccountCosmosData{
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Pk: pk,
Timestamp: time.Now().Unix(),
Account: data,
}
@ -184,10 +214,11 @@ func (s *accountsStatements) updatePassword(
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getAccount(s, ctx, pk, docId)
var response, exGet = getAccount(s, ctx, pk, cosmosDocId)
if exGet != nil {
return exGet
}
@ -207,10 +238,12 @@ func (s *accountsStatements) deactivateAccount(
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getAccount(s, ctx, pk, docId)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getAccount(s, ctx, pk, cosmosDocId)
if exGet != nil {
return exGet
}
@ -230,24 +263,15 @@ func (s *accountsStatements) selectPasswordHash(
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil {
return "", ex
response, err := queryAccount(s, ctx, s.selectPasswordHashStmt, params)
if err != nil {
return "", err
}
if len(response) == 0 {
@ -268,24 +292,15 @@ func (s *accountsStatements) selectAccountByLocalpart(
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil {
return nil, ex
response, err := queryAccount(s, ctx, s.selectAccountByLocalpartStmt, params)
if err != nil {
return nil, err
}
if len(response) == 0 {

View file

@ -62,6 +62,27 @@ func mapToToken(api api.OpenIDToken) OpenIDTokenCosmos {
}
}
func queryOpenIdToken(s *tokenStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OpenIdTokenCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []OpenIdTokenCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.selectTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_openidtoken.token = @x2"
@ -87,10 +108,14 @@ func (s *tokenStatements) insertToken(
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
docId := result.Token
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = OpenIdTokenCosmosData{
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Token),
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Pk: pk,
Timestamp: time.Now().Unix(),
OpenIdToken: mapToToken(*result),
}
@ -120,24 +145,14 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []OpenIdTokenCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": token,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
response, err := queryOpenIdToken(s, ctx, s.selectTokenStmt, params)
if ex != nil {
return nil, ex
if err != nil {
return nil, err
}
if len(response) == 0 {

View file

@ -21,6 +21,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@ -87,17 +88,42 @@ func (s *profilesStatements) prepare(db *Database) (err error) {
return
}
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) {
response := ProfileCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
func queryProfile(s *profilesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ProfileCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []ProfileCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) {
response := ProfileCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
optionsGet,
&response)
return &response, ex
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setProfile(s *profilesStatements, ctx context.Context, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
@ -123,10 +149,14 @@ func (s *profilesStatements) insertProfile(
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = ProfileCosmosData{
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Pk: pk,
Timestamp: time.Now().Unix(),
Profile: mapToProfile(*result),
}
@ -148,24 +178,15 @@ func (s *profilesStatements) selectProfileByLocalpart(
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ProfileCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil {
return nil, ex
response, err := queryProfile(s, ctx, s.selectProfileByLocalpartStmt, params)
if err != nil {
return nil, err
}
if len(response) == 0 {
@ -186,10 +207,11 @@ func (s *profilesStatements) setAvatarURL(
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getProfile(s, ctx, pk, docId)
var response, exGet = getProfile(s, ctx, pk, cosmosDocId)
if exGet != nil {
return exGet
}
@ -209,9 +231,10 @@ func (s *profilesStatements) setDisplayName(
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var response, exGet = getProfile(s, ctx, pk, docId)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getProfile(s, ctx, pk, cosmosDocId)
if exGet != nil {
return exGet
}
@ -232,25 +255,16 @@ func (s *profilesStatements) selectProfilesBySearch(
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ProfileCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": searchString,
"@x3": limit,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil {
return nil, ex
response, err := queryProfile(s, ctx, s.selectProfilesBySearchStmt, params)
if err != nil {
return nil, err
}
for i := 0; i < len(response); i++ {

View file

@ -20,6 +20,8 @@ import (
"errors"
"strconv"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
@ -27,7 +29,6 @@ import (
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -37,6 +38,7 @@ import (
// Database represents an account database
type Database struct {
sqlutil.PartitionOffsetStatements
writer sqlutil.Writer
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
@ -62,7 +64,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
connection: conn,
cosmosConfig: config,
// db: db,
// writer: sqlutil.NewExclusiveWriter(),
writer: sqlutil.NewExclusiveWriter(),
// bcryptCost: bcryptCost,
// openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}

View file

@ -69,31 +69,43 @@ func (s *threepidStatements) prepare(db *Database) (err error) {
return
}
func queryThreePID(s *threepidStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ThreePIDCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []ThreePIDCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *threepidStatements) selectLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ThreePIDCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": threepid,
"@x3": medium,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil {
return "", ex
response, err := queryThreePID(s, ctx, s.selectLocalpartForThreePIDStmt, params)
if err != nil {
return "", err
}
if len(response) == 0 {
@ -109,24 +121,14 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ThreePIDCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
response, err := queryThreePID(s, ctx, s.selectThreePIDsForLocalpartStmt, params)
if ex != nil {
return threepids, ex
if err != nil {
return threepids, err
}
if len(response) == 0 {
@ -158,10 +160,11 @@ func (s *threepidStatements) insertThreePID(
docId := fmt.Sprintf("%s_%s", threepid, medium)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = ThreePIDCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Pk: pk,
Timestamp: time.Now().Unix(),
ThreePID: result,
}

View file

@ -16,10 +16,11 @@ package cosmosdb
import (
"context"
"errors"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/userapi/api"
@ -82,15 +83,15 @@ type DeviceCosmosSessionCount struct {
}
type devicesStatements struct {
db *Database
db *Database
selectDevicesCountStmt string
selectDeviceByTokenStmt string
// selectDeviceByIDStmt *sql.Stmt
selectDevicesByIDStmt string
selectDevicesByLocalpartStmt string
selectDevicesByLocalpartExceptIDStmt string
serverName gomatrixserverlib.ServerName
tableName string
serverName gomatrixserverlib.ServerName
tableName string
}
func mapFromDevice(db DeviceCosmos) api.Device {
@ -121,17 +122,42 @@ func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos {
}
}
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
response := DeviceCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
func queryDevice(s *devicesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
response := DeviceCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
optionsGet,
&response)
return &response, ex
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) {
@ -209,10 +235,11 @@ func (s *devicesStatements) insertDevice(
// HACK: check for duplicate PK as we are using the UNIQUE key for the DocId
docId := fmt.Sprintf("%s_%s", localpart, id)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = DeviceCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Pk: pk,
Timestamp: time.Now().Unix(),
Device: data,
}
@ -260,7 +287,6 @@ func (s *devicesStatements) deleteDevices(
) error {
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -268,15 +294,8 @@ func (s *devicesStatements) deleteDevices(
"@x3": devices,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params)
if err != nil {
return err
}
@ -291,8 +310,6 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
) error {
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
exceptDevices := []string{
exceptDeviceID,
}
@ -302,15 +319,8 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
"@x3": exceptDevices,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params)
if err != nil {
return err
}
@ -325,9 +335,9 @@ func (s *devicesStatements) updateDeviceName(
) error {
// "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil {
return exGet
@ -347,27 +357,19 @@ func (s *devicesStatements) selectDeviceByToken(
) (*api.Device, error) {
// "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": accessToken,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDeviceByTokenStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
response, err := queryDevice(s, ctx, s.selectDeviceByTokenStmt, params)
if err != nil {
return nil, err
}
if len(response) == 0 {
return nil, errors.New(fmt.Sprintf("No Devices found with AccessToken %s", accessToken))
return nil, cosmosdbutil.ErrNoRows
}
if err == nil {
@ -384,9 +386,9 @@ func (s *devicesStatements) selectDeviceByID(
) (*api.Device, error) {
// "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil {
return nil, exGet
@ -401,23 +403,14 @@ func (s *devicesStatements) selectDevicesByLocalpart(
devices := []api.Device{}
// "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": exceptDeviceID,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartExceptIDStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartExceptIDStmt, params)
if err != nil {
return nil, err
}
@ -435,22 +428,14 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
// "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
var devices []api.Device
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": deviceIDs,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByIDStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
response, err := queryDevice(s, ctx, s.selectDevicesByIDStmt, params)
if err != nil {
return nil, err
}
@ -466,9 +451,9 @@ func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart,
// "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil {
return exGet

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
@ -37,7 +36,6 @@ var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
writer sqlutil.Writer
devices devicesStatements
connection cosmosdbapi.CosmosConnection
databaseName string