Implement Cosmos DB for the RoomServer Service (#5)

* - Implement Cosmos for the devices_table
- Use the ConnectionString in the YAML to include the Tenant
- Revert all other non implemented tables back to use SQLLite3

* - Change the Config to use "test.criticicalarc.com" Container
- Add generic function GetDocumentOrNil to standardize GetDocument
- Add func to return CrossPartition queries for Aggregates
- Add func GetNextSequence() as generic seq generator for AutoIncrement
- Add cosmosdbutil.ErrNoRows to return (emulate) sql.ErrNoRows
- Add a "fake" ExclusiveWriterFake
- Add standard "getXX", "setXX" and "queryXX" to all TABLE class files
- Add specific Table SEQ for the Events table
- Add specific Table SEQ for the Rooms table
- Add specific Table SEQ for the StateSnapshot table
This commit is contained in:
alexfca 2021-05-20 14:42:33 +10:00 committed by GitHub
parent b696923333
commit 5d68daef80
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 4012 additions and 1564 deletions

View file

@ -291,7 +291,7 @@ room_server:
listen: http://localhost:7770 listen: http://localhost:7770
connect: http://localhost:7770 connect: http://localhost:7770
database: database:
connection_string: file:roomserver.db connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -354,12 +354,12 @@ user_api:
listen: http://localhost:7781 listen: http://localhost:7781
connect: http://localhost:7781 connect: http://localhost:7781
account_database: account_database:
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;" connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database: device_database:
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;" connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -1,8 +1,8 @@
package cosmosdbapi package cosmosdbapi
import ( import (
"context"
"fmt" "fmt"
) )
func GetDocumentId(tenantName string, collectionName string, id string) string { func GetDocumentId(tenantName string, collectionName string, id string) string {
@ -11,4 +11,25 @@ func GetDocumentId(tenantName string, collectionName string, id string) string {
func GetPartitionKey(tenantName string, collectionName string) string { func GetPartitionKey(tenantName string, collectionName string) string {
return fmt.Sprintf("%s,%s", collectionName, tenantName) return fmt.Sprintf("%s,%s", collectionName, tenantName)
} }
func GetDocumentOrNil(connection CosmosConnection, config CosmosConfig, ctx context.Context, partitionKey string, cosmosDocId string, dbData interface{}) error {
var _, err = GetClient(connection).GetDocument(
ctx,
config.DatabaseName,
config.ContainerName,
cosmosDocId,
GetGetDocumentOptions(partitionKey),
&dbData,
)
if err != nil {
if err.Error() == "Resource that no longer exists" {
dbData = nil
return nil
}
return err
}
return nil
}

View file

@ -6,14 +6,14 @@ import (
func GetCreateDocumentOptions(pk string) cosmosapi.CreateDocumentOptions { func GetCreateDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
return cosmosapi.CreateDocumentOptions{ return cosmosapi.CreateDocumentOptions{
IsUpsert: false, IsUpsert: false,
PartitionKeyValue: pk, PartitionKeyValue: pk,
} }
} }
func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions { func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
return cosmosapi.CreateDocumentOptions{ return cosmosapi.CreateDocumentOptions{
IsUpsert: true, IsUpsert: true,
PartitionKeyValue: pk, PartitionKeyValue: pk,
} }
} }
@ -21,8 +21,16 @@ func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions { func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions {
return cosmosapi.QueryDocumentsOptions{ return cosmosapi.QueryDocumentsOptions{
PartitionKeyValue: pk, PartitionKeyValue: pk,
IsQuery: true, IsQuery: true,
ContentType: cosmosapi.QUERY_CONTENT_TYPE, ContentType: cosmosapi.QUERY_CONTENT_TYPE,
}
}
func GetQueryAllPartitionsDocumentsOptions() cosmosapi.QueryDocumentsOptions {
return cosmosapi.QueryDocumentsOptions{
IsQuery: true,
EnableCrossPartition: true,
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
} }
} }
@ -35,7 +43,7 @@ func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions {
func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions { func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions {
return cosmosapi.ReplaceDocumentOptions{ return cosmosapi.ReplaceDocumentOptions{
PartitionKeyValue: pk, PartitionKeyValue: pk,
IfMatch: etag, IfMatch: etag,
} }
} }

View file

@ -0,0 +1,76 @@
package cosmosdbutil
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
)
type SequenceCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Value int64 `json:"_value"`
}
func GetNextSequence(
ctx context.Context,
connection cosmosdbapi.CosmosConnection,
config cosmosdbapi.CosmosConfig,
serviceName string,
tableName string,
seqId string,
initial int64,
) (int64, error) {
collName := fmt.Sprintf("%s_%s", tableName, seqId)
dbCollectionName := cosmosdbapi.GetCollectionName(serviceName, collName)
cosmosDocId := cosmosdbapi.GetDocumentId(config.ContainerName, dbCollectionName, seqId)
pk := cosmosDocId
dbData := SequenceCosmosData{}
cosmosdbapi.GetDocumentOrNil(
connection,
config,
ctx,
pk,
cosmosDocId,
&dbData,
)
if dbData.Id == "" {
dbData = SequenceCosmosData{}
dbData.Id = cosmosDocId
dbData.Pk = pk
dbData.Cn = dbCollectionName
dbData.Value = initial
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(connection).CreateDocument(
ctx,
config.DatabaseName,
config.ContainerName,
dbData,
optionsCreate,
)
if err != nil {
return -1, err
}
} else {
dbData.Value++
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(dbData.Pk, dbData.ETag)
var _, _, err = cosmosdbapi.GetClient(connection).ReplaceDocument(
ctx,
config.DatabaseName,
config.ContainerName,
cosmosDocId,
dbData,
optionsReplace,
)
if err != nil {
return -1, err
}
}
return dbData.Value, nil
}

View file

@ -0,0 +1,12 @@
package cosmosdbutil
import (
"database/sql"
"errors"
)
// ErrNoRows is returned by Scan when QueryRow doesn't return a
// row. Used to simulate the SQL responses as its used for business logic
var ErrNoRows = sql.ErrNoRows
var ErrNotImplemented = errors.New("cosmosdb: not implemented")

View file

@ -0,0 +1,77 @@
package cosmosdbutil
import (
"database/sql"
"errors"
"github.com/matrix-org/dendrite/internal/sqlutil"
"go.uber.org/atomic"
)
// ExclusiveWriter implements sqlutil.Writer.
// Allows non-SQL DBs to still use the same infrastructure
// as Matrix assumes SQL TXN
type ExclusiveWriterFake struct {
running atomic.Bool
todo chan transactionWriterTaskFake
}
func NewExclusiveWriterFake() sqlutil.Writer {
return &ExclusiveWriterFake{
todo: make(chan transactionWriterTaskFake),
}
}
// transactionWriterTask represents a specific task.
type transactionWriterTaskFake struct {
db *sql.DB
txn *sql.Tx
f func(txn *sql.Tx) error
wait chan error
}
func (w *ExclusiveWriterFake) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
if w.todo == nil {
return errors.New("not initialised")
}
if !w.running.Load() {
go w.run()
}
task := transactionWriterTaskFake{
db: db,
txn: txn,
f: f,
wait: make(chan error, 1),
}
w.todo <- task
return <-task.wait
}
// run processes the tasks for a given transaction writer. Only one
// of these goroutines will run at a time. A transaction will be
// opened using the database object from the task and then this will
// be passed as a parameter to the task function.
func (w *ExclusiveWriterFake) run() {
if !w.running.CAS(false, true) {
return
}
// if tracingEnabled {
// gid := goid()
// goidToWriter.Store(gid, w)
// defer goidToWriter.Delete(gid)
// }
defer w.running.Store(false)
for task := range w.todo {
if task.db != nil && task.txn != nil {
task.wait <- task.f(task.txn)
// } else if task.db != nil && task.txn == nil {
// task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
// return task.f(txn)
// })
} else {
task.wait <- task.f(nil)
}
close(task.wait)
}
}

View file

@ -0,0 +1,82 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func UpAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
target_local BOOLEAN NOT NULL DEFAULT false,
forgotten BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
INSERT
INTO roomserver_membership (
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
) SELECT
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
FROM roomserver_membership_tmp
;
DROP TABLE roomserver_membership_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
INSERT
INTO roomserver_membership (
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
) SELECT
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
FROM roomserver_membership_tmp
;
DROP TABLE roomserver_membership_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -18,76 +18,152 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings" "fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
const eventJSONSchema = ` // const eventJSONSchema = `
CREATE TABLE IF NOT EXISTS roomserver_event_json ( // CREATE TABLE IF NOT EXISTS roomserver_event_json (
event_nid INTEGER NOT NULL PRIMARY KEY, // event_nid INTEGER NOT NULL PRIMARY KEY,
event_json TEXT NOT NULL // event_json TEXT NOT NULL
); // );
` // `
const insertEventJSONSQL = ` type EventJSONCosmos struct {
INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2) EventNID int64 `json:"event_nid"`
` EventJSON []byte `json:"event_json"`
}
type EventJSONCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
EventJSON EventJSONCosmos `json:"mx_roomserver_event_json"`
}
// const insertEventJSONSQL = `
// INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
// `
// Bulk event JSON lookup by numeric event ID. // Bulk event JSON lookup by numeric event ID.
// Sort by the numeric event ID. // Sort by the numeric event ID.
// This means that we can use binary search to lookup by numeric event ID. // This means that we can use binary search to lookup by numeric event ID.
const bulkSelectEventJSONSQL = ` // SELECT event_nid, event_json FROM roomserver_event_json
SELECT event_nid, event_json FROM roomserver_event_json // WHERE event_nid IN ($1)
WHERE event_nid IN ($1) // ORDER BY event_nid ASC
ORDER BY event_nid ASC const bulkSelectEventJSONSQL = "" +
` "select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_json.event_nid) " +
"order by c.mx_roomserver_event_json.event_nid asc"
type eventJSONStatements struct { type eventJSONStatements struct {
db *sql.DB db *Database
insertEventJSONStmt *sql.Stmt // insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt *sql.Stmt bulkSelectEventJSONStmt string
tableName string
} }
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { func queryEventJSON(s *eventJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventJSONCosmosData, error) {
s := &eventJSONStatements{ var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
db: db, var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
} var response []EventJSONCosmosData
_, err := db.Exec(eventJSONSchema)
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s, shared.StatementList{ return response, nil
{&s.insertEventJSONStmt, insertEventJSONSQL}, }
{&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL},
}.Prepare(db) func NewCosmosDBEventJSONTable(db *Database) (tables.EventJSON, error) {
s := &eventJSONStatements{
db: db,
}
// _, err := db.Exec(eventJSONSchema)
// if err != nil {
// return nil, err
// }
// return s, shared.StatementList{
// {&s.insertEventJSONStmt, insertEventJSONSQL},
s.bulkSelectEventJSONStmt = bulkSelectEventJSONSQL
// }.Prepare(db)
s.tableName = "event_json"
return s, nil
} }
func (s *eventJSONStatements) InsertEventJSON( func (s *eventJSONStatements) InsertEventJSON(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
// _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
// INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d", eventNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := EventJSONCosmos{
EventNID: int64(eventNID),
EventJSON: eventJSON,
}
var dbData = EventJSONCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
EventJSON: data,
}
//Insert OR Replace
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err return err
} }
func (s *eventJSONStatements) BulkSelectEventJSON( func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) { ) ([]tables.EventJSONPair, error) {
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) // SELECT event_nid, event_json FROM roomserver_event_json
// WHERE event_nid IN ($1)
// ORDER BY event_nid ASC
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventNIDs,
}
response, err := queryEventJSON(s, ctx, s.bulkSelectEventJSONStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventJSON: rows.close() failed")
// We know that we will only get as many results as event NIDs // We know that we will only get as many results as event NIDs
// because of the unique constraint on event NIDs. // because of the unique constraint on event NIDs.
@ -95,13 +171,11 @@ func (s *eventJSONStatements) BulkSelectEventJSON(
// We might get fewer results than NIDs so we adjust the length of the slice before returning it. // We might get fewer results than NIDs so we adjust the length of the slice before returning it.
results := make([]tables.EventJSONPair, len(eventNIDs)) results := make([]tables.EventJSONPair, len(eventNIDs))
i := 0 i := 0
for ; rows.Next(); i++ { for _, item := range response {
result := &results[i] result := &results[i]
var eventNID int64 result.EventNID = types.EventNID(item.EventJSON.EventNID)
if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { result.EventJSON = item.EventJSON.EventJSON
return nil, err i++
}
result.EventNID = types.EventNID(eventNID)
} }
return results[:i], nil return results[:i], nil
} }

View file

@ -0,0 +1,24 @@
package cosmosdb
import (
"context"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
)
func GetNextEventStateKeyNID(s *eventStateKeyStatements, ctx context.Context) (int64, error) {
const docId = "eventstatekeynid_seq"
//1 insert start at 2
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 2)
}
func GetNextEventTypeNID(s *eventTypeStatements, ctx context.Context) (int64, error) {
const docId = "eventtypenid_seq"
//7 inserts start at 8
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 8)
}
func GetNextEventNID(s *eventStatements, ctx context.Context) (int64, error) {
const docId = "eventnid_seq"
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
}

View file

@ -18,96 +18,248 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
const eventStateKeysSchema = ` // const eventStateKeysSchema = `
CREATE TABLE IF NOT EXISTS roomserver_event_state_keys ( // CREATE TABLE IF NOT EXISTS roomserver_event_state_keys (
event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT, // event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
event_state_key TEXT NOT NULL UNIQUE // event_state_key TEXT NOT NULL UNIQUE
); // );
INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key) // INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
VALUES (1, '') // VALUES (1, '')
ON CONFLICT DO NOTHING; // ON CONFLICT DO NOTHING;
` // `
type EventStateKeysCosmos struct {
EventStateKeyNID int64 `json:"event_state_key_nid"`
EventStateKey string `json:"event_state_key"`
}
type EventStateKeysCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
EventStateKeys EventStateKeysCosmos `json:"mx_roomserver_event_state_keys"`
}
// Same as insertEventTypeNIDSQL // Same as insertEventTypeNIDSQL
const insertEventStateKeyNIDSQL = ` // const insertEventStateKeyNIDSQL = `
INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1) // INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
ON CONFLICT DO NOTHING; // ON CONFLICT DO NOTHING;
` // `
const selectEventStateKeyNIDSQL = ` // SELECT event_state_key_nid FROM roomserver_event_state_keys
SELECT event_state_key_nid FROM roomserver_event_state_keys // WHERE event_state_key = $1
WHERE event_state_key = $1 const selectEventStateKeyNIDSQL = "" +
` "select * from c where c._cn = @x1 " +
"and c.mx_roomserver_event_state_keys.event_state_key = @x2"
// Bulk lookup from string state key to numeric ID for that state key. // // Bulk lookup from string state key to numeric ID for that state key.
// Takes an array of strings as the query parameter. // // Takes an array of strings as the query parameter.
const bulkSelectEventStateKeySQL = ` // SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys // WHERE event_state_key IN ($1)
WHERE event_state_key IN ($1) const bulkSelectEventStateKeySQL = "" +
` "select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key_nid)"
// Bulk lookup from numeric ID to string state key for that state key. // Bulk lookup from numeric ID to string state key for that state key.
// Takes an array of strings as the query parameter. // Takes an array of strings as the query parameter.
const bulkSelectEventStateKeyNIDSQL = ` // SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys // WHERE event_state_key_nid IN ($1)
WHERE event_state_key_nid IN ($1) const bulkSelectEventStateKeyNIDSQL = "" +
` "select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key)"
type eventStateKeyStatements struct { type eventStateKeyStatements struct {
db *sql.DB db *Database
insertEventStateKeyNIDStmt *sql.Stmt insertEventStateKeyNIDStmt string
selectEventStateKeyNIDStmt *sql.Stmt selectEventStateKeyNIDStmt string
bulkSelectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyNIDStmt string
bulkSelectEventStateKeyStmt *sql.Stmt bulkSelectEventStateKeyStmt string
tableName string
} }
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { func queryEventStateKeys(s *eventStateKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventStateKeysCosmosData, error) {
s := &eventStateKeyStatements{ var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
db: db, var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
} var response []EventStateKeysCosmosData
_, err := db.Exec(eventStateKeysSchema)
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s, shared.StatementList{ return response, nil
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, }
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, func getEventStateKeys(s *eventStateKeyStatements, ctx context.Context, pk string, docId string) (*EventStateKeysCosmosData, error) {
{&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, response := EventStateKeysCosmosData{}
}.Prepare(db) err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, nil
}
return &response, err
}
func NewCosmosDBEventStateKeysTable(db *Database) (tables.EventStateKeys, error) {
s := &eventStateKeyStatements{
db: db,
}
// return s, shared.StatementList{
// {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
s.selectEventStateKeyNIDStmt = selectEventStateKeyNIDSQL
s.bulkSelectEventStateKeyNIDStmt = bulkSelectEventStateKeyNIDSQL
s.bulkSelectEventStateKeyStmt = bulkSelectEventStateKeySQL
// }.Prepare(db)
s.tableName = "event_state_keys"
//Add in the initial data
ensureEventStateKeys(s, context.Background())
return s, nil
}
func ensureEventStateKeys(s *eventStateKeyStatements, ctx context.Context) {
// INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
// VALUES (1, '')
// ON CONFLICT DO NOTHING;
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// event_state_key TEXT NOT NULL UNIQUE
docId := ""
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := EventStateKeysCosmos{
EventStateKey: "",
EventStateKeyNID: 1,
}
// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
dbData := EventStateKeysCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
EventStateKeys: data,
}
insertEventStateKeyCore(s, ctx, dbData)
}
func insertEventStateKeyCore(s *eventStateKeyStatements, ctx context.Context, dbData EventStateKeysCosmosData) error {
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
if err != nil {
return err
}
return nil
} }
func (s *eventStateKeyStatements) InsertEventStateKeyNID( func (s *eventStateKeyStatements) InsertEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string, ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
res, err := insertStmt.ExecContext(ctx, eventStateKey) // INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
if err != nil { // ON CONFLICT DO NOTHING;
return 0, err if len(eventStateKey) == 0 {
return 0, cosmosdbutil.ErrNoRows
} }
eventStateKeyNID, err := res.LastInsertId()
if err != nil { var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
return 0, err // event_state_key TEXT NOT NULL UNIQUE
docId := eventStateKey
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
existing, _ := getEventStateKeys(s, ctx, pk, cosmosDocId)
var dbData EventStateKeysCosmosData
if existing == nil {
//Not exists, we need to create a new one with a SEQ
eventStateKeyNIDSeq, seqErr := GetNextEventStateKeyNID(s, ctx)
if seqErr != nil {
return -1, seqErr
}
data := EventStateKeysCosmos{
EventStateKey: eventStateKey,
EventStateKeyNID: eventStateKeyNIDSeq,
}
// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
dbData = EventStateKeysCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
EventStateKeys: data,
}
} else {
dbData.EventStateKeys = existing.EventStateKeys
} }
return types.EventStateKeyNID(eventStateKeyNID), err
err := insertEventStateKeyCore(s, ctx, dbData)
return types.EventStateKeyNID(dbData.EventStateKeys.EventStateKeyNID), err
} }
func (s *eventStateKeyStatements) SelectEventStateKeyNID( func (s *eventStateKeyStatements) SelectEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string, ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
stmt := sqlutil.TxStmt(txn, s.selectEventStateKeyNIDStmt) // SELECT event_state_key_nid FROM roomserver_event_state_keys
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) // WHERE event_state_key = $1
return types.EventStateKeyNID(eventStateKeyNID), err
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventStateKey,
}
response, err := queryEventStateKeys(s, ctx, s.selectEventStateKeyNIDStmt, params)
if err != nil {
return 0, err
}
//See storage.assignStateKeyNID()
if len(response) == 0 {
return 0, cosmosdbutil.ErrNoRows
}
return types.EventStateKeyNID(response[0].EventStateKeys.EventStateKeyNID), err
} }
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
@ -117,21 +269,25 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
for k, v := range eventStateKeys { for k, v := range eventStateKeys {
iEventStateKeys[k] = v iEventStateKeys[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) // SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key IN ($1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventStateKeys,
}
response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyNIDStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed")
result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) result := make(map[string]types.EventStateKeyNID, len(eventStateKeys))
for rows.Next() { for _, item := range response {
var stateKey string result[item.EventStateKeys.EventStateKey] = types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID)
var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
return nil, err
}
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
} }
return result, nil return result, nil
} }
@ -139,25 +295,24 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
func (s *eventStateKeyStatements) BulkSelectEventStateKey( func (s *eventStateKeyStatements) BulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) { ) (map[types.EventStateKeyNID]string, error) {
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
for k, v := range eventStateKeyNIDs {
iEventStateKeyNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) // SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key_nid IN ($1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventStateKeyNIDs,
}
response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed")
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
for rows.Next() { for _, item := range response {
var stateKey string result[types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID)] = item.EventStateKeys.EventStateKey
var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
return nil, err
}
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
} }
return result, nil return result, nil
} }

View file

@ -18,30 +18,44 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt" "time"
"strings"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
const eventTypesSchema = ` // const eventTypesSchema = `
CREATE TABLE IF NOT EXISTS roomserver_event_types ( // CREATE TABLE IF NOT EXISTS roomserver_event_types (
event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT, // event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT,
event_type TEXT NOT NULL UNIQUE // event_type TEXT NOT NULL UNIQUE
); // );
INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES // INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
(1, 'm.room.create'), // (1, 'm.room.create'),
(2, 'm.room.power_levels'), // (2, 'm.room.power_levels'),
(3, 'm.room.join_rules'), // (3, 'm.room.join_rules'),
(4, 'm.room.third_party_invite'), // (4, 'm.room.third_party_invite'),
(5, 'm.room.member'), // (5, 'm.room.member'),
(6, 'm.room.redaction'), // (6, 'm.room.redaction'),
(7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; // (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
` // `
type EventTypeCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
EventType EventTypeCosmos `json:"mx_roomserver_event_type"`
}
type EventTypeCosmos struct {
EventTypeNID int64 `json:"event_type_nid"`
EventType string `json:"event_type"`
}
// Assign a new numeric event type ID. // Assign a new numeric event type ID.
// The usual case is that the event type is not in the database. // The usual case is that the event type is not in the database.
@ -56,105 +70,243 @@ const eventTypesSchema = `
// return it. Modifying the rows will cause postgres to assign a new tuple for the // return it. Modifying the rows will cause postgres to assign a new tuple for the
// row even though the data doesn't change resulting in unncesssary modifications // row even though the data doesn't change resulting in unncesssary modifications
// to the indexes. // to the indexes.
const insertEventTypeNIDSQL = ` // const insertEventTypeNIDSQL = `
INSERT INTO roomserver_event_types (event_type) VALUES ($1) // INSERT INTO roomserver_event_types (event_type) VALUES ($1)
ON CONFLICT DO NOTHING; // ON CONFLICT DO NOTHING;
` // `
const insertEventTypeNIDResultSQL = ` // const insertEventTypeNIDResultSQL = `
SELECT event_type_nid FROM roomserver_event_types // SELECT event_type_nid FROM roomserver_event_types
WHERE rowid = last_insert_rowid(); // WHERE rowid = last_insert_rowid();
` // `
const selectEventTypeNIDSQL = ` // const selectEventTypeNIDSQL = `
SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1 // SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
` // `
// Bulk lookup from string event type to numeric ID for that event type. // Bulk lookup from string event type to numeric ID for that event type.
// Takes an array of strings as the query parameter. // Takes an array of strings as the query parameter.
const bulkSelectEventTypeNIDSQL = ` // SELECT event_type, event_type_nid FROM roomserver_event_types
SELECT event_type, event_type_nid FROM roomserver_event_types // WHERE event_type IN ($1)
WHERE event_type IN ($1) const bulkSelectEventTypeNIDSQL = "" +
` "select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_type.event_type)"
type eventTypeStatements struct { type eventTypeStatements struct {
db *sql.DB db *Database
insertEventTypeNIDStmt *sql.Stmt // insertEventTypeNIDStmt *sql.Stmt
insertEventTypeNIDResultStmt *sql.Stmt // insertEventTypeNIDResultStmt *sql.Stmt
selectEventTypeNIDStmt *sql.Stmt // selectEventTypeNIDStmt *sql.Stmt
bulkSelectEventTypeNIDStmt *sql.Stmt bulkSelectEventTypeNIDStmt string
tableName string
} }
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { func NewCosmosDBEventTypesTable(db *Database) (tables.EventTypes, error) {
s := &eventTypeStatements{ s := &eventTypeStatements{
db: db, db: db,
} }
_, err := db.Exec(eventTypesSchema)
// return s, shared.StatementList{
// {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
// {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL},
// {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
s.bulkSelectEventTypeNIDStmt = bulkSelectEventTypeNIDSQL
// }.Prepare(db)
s.tableName = "event_types"
ensureEventTypes(s, context.Background())
return s, nil
}
func queryEventTypes(s *eventTypeStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventTypeCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []EventTypeCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return response, nil
return s, shared.StatementList{
{&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
{&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL},
{&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
{&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL},
}.Prepare(db)
} }
func (s *eventTypeStatements) InsertEventTypeNID( func (s *eventTypeStatements) InsertEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string, ctx context.Context, txn *sql.Tx, eventType string,
) (types.EventTypeNID, error) { ) (types.EventTypeNID, error) {
var eventTypeNID int64 //We need to create a new one with a SEQ
insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt) eventTypeNIDSeq, seqErr := GetNextEventTypeNID(s, ctx)
resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt) if seqErr != nil {
_, err := insertStmt.ExecContext(ctx, eventType) return -1, seqErr
}
data := EventTypeCosmos{
EventType: eventType,
EventTypeNID: eventTypeNIDSeq,
}
dbData, err := insertEventTypeCore(s, ctx, data)
if err != nil { if err != nil {
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) return 0, err
} }
if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil {
return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err) return types.EventTypeNID(dbData.EventTypeNID), err
}
func insertEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType EventTypeCosmos) (*EventTypeCosmos, error) {
// INSERT INTO roomserver_event_types (event_type) VALUES ($1)
// ON CONFLICT DO NOTHING;
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
//Unique on eventType
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, eventType.EventType)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = EventTypeCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
EventType: eventType,
} }
return types.EventTypeNID(eventTypeNID), err
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
if err != nil {
dbData, errGet := selectEventTypeCore(s, ctx, eventType.EventType)
if errGet != nil {
return nil, errGet
}
return dbData, nil
}
return &dbData.EventType, err
}
func ensureEventTypes(s *eventTypeStatements, ctx context.Context) error {
// INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
// (1, 'm.room.create'),
// (2, 'm.room.power_levels'),
// (3, 'm.room.join_rules'),
// (4, 'm.room.third_party_invite'),
// (5, 'm.room.member'),
// (6, 'm.room.redaction'),
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
// (1, 'm.room.create'),
_, err := insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 1, EventType: "m.room.create"})
if err != nil {
return err
}
// (2, 'm.room.power_levels'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 2, EventType: "m.room.power_levels"})
if err != nil {
return err
}
// (3, 'm.room.join_rules'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 3, EventType: "m.room.join_rules"})
if err != nil {
return err
}
// (4, 'm.room.third_party_invite'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 4, EventType: "m.room.third_party_invite"})
if err != nil {
return err
}
// (5, 'm.room.member'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 5, EventType: "m.room.member"})
if err != nil {
return err
}
// (6, 'm.room.redaction'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 6, EventType: "m.room.redaction"})
if err != nil {
return err
}
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 7, EventType: "m.room.history_visibility"})
if err != nil {
return err
}
return nil
}
func selectEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType string) (*EventTypeCosmos, error) {
var response EventTypeCosmosData
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, eventType)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
cosmosDocId,
&response)
if err != nil {
return nil, err
}
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response.EventType, nil
} }
func (s *eventTypeStatements) SelectEventTypeNID( func (s *eventTypeStatements) SelectEventTypeNID(
ctx context.Context, tx *sql.Tx, eventType string, ctx context.Context, tx *sql.Tx, eventType string,
) (types.EventTypeNID, error) { ) (types.EventTypeNID, error) {
var eventTypeNID int64
selectStmt := sqlutil.TxStmt(tx, s.selectEventTypeNIDStmt) // SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
return types.EventTypeNID(eventTypeNID), err dbData, err := selectEventTypeCore(s, ctx, eventType)
if err != nil {
return -1, err
}
return types.EventTypeNID(dbData.EventTypeNID), nil
} }
func (s *eventTypeStatements) BulkSelectEventTypeNID( func (s *eventTypeStatements) BulkSelectEventTypeNID(
ctx context.Context, eventTypes []string, ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {
///////////////
iEventTypes := make([]interface{}, len(eventTypes))
for k, v := range eventTypes {
iEventTypes[k] = v
}
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventTypes)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
rows, err := selectPrep.QueryContext(ctx, iEventTypes...) // SELECT event_type, event_type_nid FROM roomserver_event_types
// WHERE event_type IN ($1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventTypes,
}
response, err := queryEventTypes(s, ctx, s.bulkSelectEventTypeNIDStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed")
result := make(map[string]types.EventTypeNID, len(eventTypes)) result := make(map[string]types.EventTypeNID, len(eventTypes))
for rows.Next() { for _, item := range response {
var eventType string var eventType string
var eventTypeNID int64 var eventTypeNID int64
if err := rows.Scan(&eventType, &eventTypeNID); err != nil { eventType = item.EventType.EventType
return nil, err eventTypeNID = item.EventType.EventTypeNID
}
result[eventType] = types.EventTypeNID(eventTypeNID) result[eventType] = types.EventTypeNID(eventTypeNID)
} }
return result, nil return result, nil

File diff suppressed because it is too large Load diff

View file

@ -18,73 +18,148 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
const inviteSchema = ` // const inviteSchema = `
CREATE TABLE IF NOT EXISTS roomserver_invites ( // CREATE TABLE IF NOT EXISTS roomserver_invites (
invite_event_id TEXT PRIMARY KEY, // invite_event_id TEXT PRIMARY KEY,
room_nid INTEGER NOT NULL, // room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL, // target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0, // sender_nid INTEGER NOT NULL DEFAULT 0,
retired BOOLEAN NOT NULL DEFAULT FALSE, // retired BOOLEAN NOT NULL DEFAULT FALSE,
invite_event_json TEXT NOT NULL // invite_event_json TEXT NOT NULL
); // );
CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid) // CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid)
WHERE NOT retired; // WHERE NOT retired;
` // `
const insertInviteEventSQL = "" +
"INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
" sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT DO NOTHING"
type InviteCosmos struct {
InviteEventID string `json:"invite_event_id"`
RoomNID int64 `json:"room_nid"`
TargetNID int64 `json:"target_nid"`
SenderNID int64 `json:"sender_nid"`
Retired bool `json:"retired"`
InviteEventJSON []byte `json:"invite_event_json"`
}
type InviteCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Invite InviteCosmos `json:"mx_roomserver_invite"`
}
// const insertInviteEventSQL = "" +
// "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
// " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
// " ON CONFLICT DO NOTHING"
// "SELECT invite_event_id, sender_nid FROM roomserver_invites" +
// " WHERE target_nid = $1 AND room_nid = $2" +
// " AND NOT retired"
const selectInviteActiveForUserInRoomSQL = "" + const selectInviteActiveForUserInRoomSQL = "" +
"SELECT invite_event_id, sender_nid FROM roomserver_invites" + "select * from c where c._cn = @x1 " +
" WHERE target_nid = $1 AND room_nid = $2" + " and c.mx_roomserver_invite.target_nid = @x2" +
" AND NOT retired" " and c.mx_roomserver_invite.room_nid = @x3" +
" and c.mx_roomserver_invite.retired = false"
// Retire every active invite for a user in a room. // Retire every active invite for a user in a room.
// Ideally we'd know which invite events were retired by a given update so we // Ideally we'd know which invite events were retired by a given update so we
// wouldn't need to remove every active invite. // wouldn't need to remove every active invite.
// However the matrix protocol doesn't give us a way to reliably identify the // However the matrix protocol doesn't give us a way to reliably identify the
// invites that were retired, so we are forced to retire all of them. // invites that were retired, so we are forced to retire all of them.
const updateInviteRetiredSQL = ` // const updateInviteRetiredSQL = `
UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired // UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
` // `
const selectInvitesAboutToRetireSQL = ` // SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired const selectInvitesAboutToRetireSQL = "" +
` "select * from c where c._cn = @x1 " +
" and c.mx_roomserver_invite.room_nid = @x2" +
" and c.mx_roomserver_invite.target_nid = @x3" +
" and c.mx_roomserver_invite.retired = false"
type inviteStatements struct { type inviteStatements struct {
db *sql.DB db *Database
insertInviteEventStmt *sql.Stmt // insertInviteEventStmt *sql.Stmt
selectInviteActiveForUserInRoomStmt *sql.Stmt selectInviteActiveForUserInRoomStmt string
updateInviteRetiredStmt *sql.Stmt // updateInviteRetiredStmt *sql.Stmt
selectInvitesAboutToRetireStmt *sql.Stmt selectInvitesAboutToRetireStmt string
tableName string
} }
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { func queryInvite(s *inviteStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteCosmosData, error) {
s := &inviteStatements{ var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
db: db, var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
} var response []InviteCosmosData
_, err := db.Exec(inviteSchema)
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return response, nil
}
return s, shared.StatementList{ func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*InviteCosmosData, error) {
{&s.insertInviteEventStmt, insertInviteEventSQL}, response := InviteCosmosData{}
{&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, err := cosmosdbapi.GetDocumentOrNil(
{&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, s.db.connection,
{&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL}, s.db.cosmosConfig,
}.Prepare(db) ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setInvite(s *inviteStatements, ctx context.Context, invite InviteCosmosData) (*InviteCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
invite.Id,
&invite,
optionsReplace)
return &invite, ex
}
func NewCosmosDBInvitesTable(db *Database) (tables.Invites, error) {
s := &inviteStatements{
db: db,
}
// return s, shared.StatementList{
// {&s.insertInviteEventStmt, insertInviteEventSQL},
s.selectInviteActiveForUserInRoomStmt = selectInviteActiveForUserInRoomSQL
// {&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
s.selectInvitesAboutToRetireStmt = selectInvitesAboutToRetireSQL
// }.Prepare(db)
s.tableName = "invites"
return s, nil
} }
func (s *inviteStatements) InsertInviteEvent( func (s *inviteStatements) InsertInviteEvent(
@ -93,42 +168,84 @@ func (s *inviteStatements) InsertInviteEvent(
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte, inviteEventJSON []byte,
) (bool, error) { ) (bool, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) // "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
result, err := stmt.ExecContext( // " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, // " ON CONFLICT DO NOTHING"
)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
data := InviteCosmos{
InviteEventID: inviteEventID,
InviteEventJSON: inviteEventJSON,
Retired: false,
RoomNID: int64(roomNID),
SenderNID: int64(senderUserNID),
TargetNID: int64(targetUserNID),
}
// invite_event_id TEXT PRIMARY KEY,
docId := inviteEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = InviteCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Invite: data,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
if err != nil { if err != nil {
return false, err return false, err
} }
count, err = result.RowsAffected() // TODO: Is this important?
if err != nil { // count, err = result.RowsAffected()
return false, err // return count != 0, err
} return true, nil
return count != 0, err
} }
func (s *inviteStatements) UpdateInviteRetired( func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) { ) (eventIDs []string, err error) {
// "SELECT invite_event_id, sender_nid FROM roomserver_invites" +
// " WHERE target_nid = $1 AND room_nid = $2" +
// " AND NOT retired"
// gather all the event IDs we will retire // gather all the event IDs we will retire
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": targetUserNID,
"@x3": roomNID,
}
response, err := queryInvite(s, ctx, s.selectInvitesAboutToRetireStmt, params)
if err != nil { if err != nil {
return return
} }
defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed")
for rows.Next() { for _, item := range response {
var inviteEventID string eventIDs = append(eventIDs, item.Invite.InviteEventID)
if err = rows.Scan(&inviteEventID); err != nil { // UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
return
} // now retire the invites
eventIDs = append(eventIDs, inviteEventID) item.Invite.Retired = true
_, err = setInvite(s, ctx, item)
} }
// now retire the invites
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
return return
} }
@ -137,21 +254,27 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) { ) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
ctx, targetUserNID, roomNID, // SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNID,
"@x3": targetUserNID,
}
response, err := queryInvite(s, ctx, s.selectInviteActiveForUserInRoomStmt, params)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
var result []types.EventStateKeyNID var result []types.EventStateKeyNID
var eventIDs []string var eventIDs []string
for rows.Next() { for _, item := range response {
var eventID string var eventID = item.Invite.InviteEventID
var senderUserNID int64 var senderUserNID = item.Invite.SenderNID
if err := rows.Scan(&eventID, &senderUserNID); err != nil {
return nil, nil, err
}
result = append(result, types.EventStateKeyNID(senderUserNID)) result = append(result, types.EventStateKeyNID(senderUserNID))
eventIDs = append(eventIDs, eventID) eventIDs = append(eventIDs, eventID)
} }

View file

@ -19,125 +19,233 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
const membershipSchema = ` // const membershipSchema = `
CREATE TABLE IF NOT EXISTS roomserver_membership ( // CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL, // room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL, // target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0, // sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1, // membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0, // event_nid INTEGER NOT NULL DEFAULT 0,
target_local BOOLEAN NOT NULL DEFAULT false, // target_local BOOLEAN NOT NULL DEFAULT false,
forgotten BOOLEAN NOT NULL DEFAULT false, // forgotten BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid) // UNIQUE (room_nid, target_nid)
); // );
` // `
type MembershipCosmos struct {
RoomNID int64 `json:"room_nid"`
TargetNID int64 `json:"target_nid"`
SenderNID int64 `json:"sender_nid"`
MembershipNID int64 `json:"membership_nid"`
EventNID int64 `json:"event_nid"`
TargetLocal bool `json:"target_local"`
Forgotten bool `json:"forgotten"`
}
type MembershipCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Membership MembershipCosmos `json:"mx_roomserver_membership"`
}
type MembershipJoinedCountCosmosData struct {
TargetNID int64 `json:"target_nid"`
RoomCount int `json:"room_count"`
}
// "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
// " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
// " GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + "select c.mx_roomserver_membership.target_nid, count(c.mx_roomserver_membership.room_id) as room_count from c where c._cn = @x1 " +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " and ARRAY_CONTAINS(@x2, c.mx_roomserver_membership.room_id)" +
" GROUP BY target_nid" " and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
" and c.mx_roomserver_membership.forgotten = false" +
" group by c.mx_roomserver_membership.target_nid"
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE // SELECT FOR UPDATE
const insertMembershipSQL = "" + // const insertMembershipSQL = "" +
"INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + // "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
" VALUES ($1, $2, $3)" + // " VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING" // " ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" + // const selectMembershipFromRoomAndTargetSQL = "" +
"SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + // "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2" // " WHERE room_nid = $1 AND target_nid = $2"
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
const selectMembershipsFromRoomAndMembershipSQL = "" + const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "select * from c where c._cn = @x1 " +
" WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" " and c.mx_roomserver_membership.room_nid = @x2" +
" and c.mx_roomserver_membership.membership_nid = @x3" +
" and c.mx_roomserver_membership.forgotten = false"
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND membership_nid = $2" +
// " AND target_local = true and forgotten = false"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" + const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "select * from c where c._cn = @x1 " +
" WHERE room_nid = $1 AND membership_nid = $2" + " and c.mx_roomserver_membership.room_nid = @x2" +
" AND target_local = true and forgotten = false" " and c.mx_roomserver_membership.membership_nid = @x3" +
" and c.mx_roomserver_membership.target_local = true" +
" and c.mx_roomserver_membership.forgotten = false"
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 and forgotten = false"
const selectMembershipsFromRoomSQL = "" + const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "select * from c where c._cn = @x1 " +
" WHERE room_nid = $1 and forgotten = false" " and c.mx_roomserver_membership.room_nid = @x2" +
" and c.mx_roomserver_membership.forgotten = false"
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1" +
// " AND target_local = true and forgotten = false"
const selectLocalMembershipsFromRoomSQL = "" + const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "select * from c where c._cn = @x1 " +
" WHERE room_nid = $1" + " and c.mx_roomserver_membership.room_nid = @x2" +
" AND target_local = true and forgotten = false" " and c.mx_roomserver_membership.target_local = true" +
" and c.mx_roomserver_membership.forgotten = false"
const selectMembershipForUpdateSQL = "" + // const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" + // "SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2" // " WHERE room_nid = $1 AND target_nid = $2"
const updateMembershipSQL = "" + // const updateMembershipSQL = "" +
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" + // "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
" WHERE room_nid = $5 AND target_nid = $6" // " WHERE room_nid = $5 AND target_nid = $6"
const updateMembershipForgetRoom = "" + // const updateMembershipForgetRoom = "" +
"UPDATE roomserver_membership SET forgotten = $1" + // "UPDATE roomserver_membership SET forgotten = $1" +
" WHERE room_nid = $2 AND target_nid = $3" // " WHERE room_nid = $2 AND target_nid = $3"
// "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
const selectRoomsWithMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" "select * from c where c._cn = @x1 " +
" and c.mx_roomserver_membership.membership_nid = @x2" +
" and c.mx_roomserver_membership.target_nid = true" +
" and c.mx_roomserver_membership.forgotten = false"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will // joined to. Since this information is used to populate the user directory, we will
// only return users that the user would ordinarily be able to see anyway. // only return users that the user would ordinarily be able to see anyway.
var selectKnownUsersSQL = "" + // var selectKnownUsersSQL = "" +
"SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + // "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
"roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + // "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE room_nid IN (" + // " WHERE room_nid IN (" +
" SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + // " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" // ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
var selectKnownUsersSQLRooms = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_membership.room_id)"
var selectKnownUsersSQLDistinctRoom = "" +
"select distinct top @x4 c.mx_roomserver_membership.room_nid as room_nid from c where c._cn = @x1 " +
"and c.mx_roomserver_membership.target_nid = @x2 " +
"and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " " +
"and contains(c.mx_roomserver_membership.event_state_key, @x3) "
type membershipStatements struct { type membershipStatements struct {
db *sql.DB db *Database
insertMembershipStmt *sql.Stmt // insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt // selectMembershipForUpdateStmt string
selectMembershipFromRoomAndTargetStmt *sql.Stmt // selectMembershipFromRoomAndTargetStmt string
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt selectMembershipsFromRoomAndMembershipStmt string
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt selectLocalMembershipsFromRoomAndMembershipStmt string
selectMembershipsFromRoomStmt *sql.Stmt selectMembershipsFromRoomStmt string
selectLocalMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomStmt string
selectRoomsWithMembershipStmt *sql.Stmt selectRoomsWithMembershipStmt string
updateMembershipStmt *sql.Stmt // updateMembershipStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt // selectKnownUsersStmt string
updateMembershipForgetRoomStmt *sql.Stmt // updateMembershipForgetRoomStmt *sql.Stmt
tableName string
} }
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { func queryMembership(s *membershipStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MembershipCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []MembershipCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getMembership(s *membershipStatements, ctx context.Context, pk string, docId string) (*MembershipCosmosData, error) {
response := MembershipCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setMembership(s *membershipStatements, ctx context.Context, pk string, membership MembershipCosmosData) (*MembershipCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, membership.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
membership.Id,
&membership,
optionsReplace)
return &membership, ex
}
func NewCosmosDBMembershipTable(db *Database) (tables.Membership, error) {
s := &membershipStatements{ s := &membershipStatements{
db: db, db: db,
} }
return s, shared.StatementList{ // return s, shared.StatementList{
{&s.insertMembershipStmt, insertMembershipSQL}, // {&s.insertMembershipStmt, insertMembershipSQL},
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, // s.selectMembershipForUpdateStmt = selectMembershipForUpdateSQL
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, // s.selectMembershipFromRoomAndTargetStmt = selectMembershipFromRoomAndTargetSQL
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, s.selectMembershipsFromRoomAndMembershipStmt = selectMembershipsFromRoomAndMembershipSQL
{&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL}, s.selectLocalMembershipsFromRoomAndMembershipStmt = selectLocalMembershipsFromRoomAndMembershipSQL
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, s.selectMembershipsFromRoomStmt = selectMembershipsFromRoomSQL
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, s.selectLocalMembershipsFromRoomStmt = selectLocalMembershipsFromRoomSQL
{&s.updateMembershipStmt, updateMembershipSQL}, // {&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, s.selectRoomsWithMembershipStmt = selectRoomsWithMembershipSQL
{&s.selectKnownUsersStmt, selectKnownUsersSQL}, // {&s.selectKnownUsersStmt, selectKnownUsersSQL},
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, // {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
}.Prepare(db) // }.Prepare(db)
}
func (s *membershipStatements) execSchema(db *sql.DB) error { s.tableName = "memberships"
_, err := db.Exec(membershipSchema) return s, nil
return err
} }
func (s *membershipStatements) InsertMembership( func (s *membershipStatements) InsertMembership(
@ -145,8 +253,45 @@ func (s *membershipStatements) InsertMembership(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool, localTarget bool,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) // "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
// " VALUES ($1, $2, $3)" +
// " ON CONFLICT DO NOTHING"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_nid, target_nid)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := MembershipCosmos{
EventNID: 0,
Forgotten: false,
MembershipNID: 1,
RoomNID: int64(roomNID),
SenderNID: 0,
TargetLocal: false,
TargetNID: int64(targetUserNID),
}
var dbData = MembershipCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Membership: data,
}
// " ON CONFLICT DO NOTHING"
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err return err
} }
@ -154,10 +299,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership tables.MembershipState, err error) { ) (membership tables.MembershipState, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt)
err = stmt.QueryRowContext( // "SELECT membership_nid FROM roomserver_membership" +
ctx, roomNID, targetUserNID, // " WHERE room_nid = $1 AND target_nid = $2"
).Scan(&membership)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getMembership(s, ctx, pk, cosmosDocId)
if response != nil {
membership = tables.MembershipState(response.Membership.MembershipNID)
}
return return
} }
@ -165,9 +318,20 @@ func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
ctx, roomNID, targetUserNID, // "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
).Scan(&membership, &eventNID, &forgotten) // " WHERE room_nid = $1 AND target_nid = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getMembership(s, ctx, pk, cosmosDocId)
if response != nil {
eventNID = types.EventNID(response.Membership.EventNID)
forgotten = response.Membership.Forgotten
membership = tables.MembershipState(response.Membership.MembershipNID)
}
return return
} }
@ -175,24 +339,31 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, localOnly bool, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var selectStmt *sql.Stmt var selectStmt string
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNID,
}
if localOnly { if localOnly {
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1" +
// " AND target_local = true and forgotten = false"
selectStmt = s.selectLocalMembershipsFromRoomStmt selectStmt = s.selectLocalMembershipsFromRoomStmt
} else { } else {
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 and forgotten = false"
selectStmt = s.selectMembershipsFromRoomStmt selectStmt = s.selectMembershipsFromRoomStmt
} }
rows, err := selectStmt.QueryContext(ctx, roomNID) response, err := queryMembership(s, ctx, selectStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoom: rows.close() failed")
for rows.Next() { for _, item := range response {
var eNID types.EventNID eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID))
if err = rows.Scan(&eNID); err != nil {
return
}
eventNIDs = append(eventNIDs, eNID)
} }
return return
} }
@ -201,24 +372,31 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt var stmt string
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNID,
"@x3": membership,
}
if localOnly { if localOnly {
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND membership_nid = $2" +
// " AND target_local = true and forgotten = false"
stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt
} else { } else {
// "SELECT event_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
stmt = s.selectMembershipsFromRoomAndMembershipStmt stmt = s.selectMembershipsFromRoomAndMembershipStmt
} }
rows, err := stmt.QueryContext(ctx, roomNID, membership) response, err := queryMembership(s, ctx, stmt, params)
if err != nil { if err != nil {
return return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoomAndMembership: rows.close() failed")
for rows.Next() { for _, item := range response {
var eNID types.EventNID eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID))
if err = rows.Scan(&eNID); err != nil {
return
}
eventNIDs = append(eventNIDs, eNID)
} }
return return
} }
@ -228,28 +406,48 @@ func (s *membershipStatements) UpdateMembership(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool, eventNID types.EventNID, forgotten bool,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
_, err := stmt.ExecContext( // "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID, // " WHERE room_nid = $5 AND target_nid = $6"
)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
dbData, err := getMembership(s, ctx, pk, cosmosDocId)
if err != nil {
return err
}
dbData.Membership.SenderNID = int64(senderUserNID)
dbData.Membership.MembershipNID = int64(membership)
dbData.Membership.EventNID = int64(eventNID)
dbData.Membership.Forgotten = forgotten
_, err = setMembership(s, ctx, pk, *dbData)
return err return err
} }
func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) { ) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
// "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": membershipState,
"@x3": userID,
}
response, err := queryMembership(s, ctx, s.selectRoomsWithMembershipStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
var roomNIDs []types.RoomNID var roomNIDs []types.RoomNID
for rows.Next() { for _, item := range response {
var roomNID types.RoomNID roomNIDs = append(roomNIDs, types.RoomNID(item.Membership.RoomNID))
if err := rows.Scan(&roomNID); err != nil {
return nil, err
}
roomNIDs = append(roomNIDs, roomNID)
} }
return roomNIDs, nil return roomNIDs, nil
} }
@ -259,39 +457,136 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
} }
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) // "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
// " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
// " GROUP BY target_nid"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNIDs,
}
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []MembershipJoinedCountCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectJoinedUsersSetForRoomsSQL, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
result := make(map[types.EventStateKeyNID]int) result := make(map[types.EventStateKeyNID]int)
for rows.Next() { for _, item := range response {
var userID types.EventStateKeyNID userID := types.EventStateKeyNID(item.TargetNID)
var count int count := item.RoomCount
if err := rows.Scan(&userID, &count); err != nil {
return nil, err
}
result[userID] = count result[userID] = count
} }
return result, rows.Err() return result, nil
} }
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
// " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
// ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
"@x3": searchString,
"@x4": limit,
}
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var responseDistinctRoom []MembershipCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectKnownUsersSQLDistinctRoom, params) //
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&responseDistinctRoom,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rooms := []int64{}
for _, item := range responseDistinctRoom {
rooms = append(rooms, item.RoomNID)
}
// "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
// "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
// " WHERE room_nid IN (" +
params = map[string]interface{}{
"@x1": dbCollectionName,
"@x2": rooms,
}
var responseRooms []MembershipCosmos
query = cosmosdbapi.GetQuery(selectKnownUsersSQLRooms, params)
_, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&responseRooms,
optionsQry)
if err != nil {
return nil, err
}
targetNIDs := []int64{}
for _, item := range responseRooms {
targetNIDs = append(targetNIDs, item.TargetNID)
}
// HACK: Joined table
var dbCollectionNameEventStateKeys = cosmosdbapi.GetCollectionName(s.db.databaseName, "event_state_keys")
params = map[string]interface{}{
"@x1": dbCollectionNameEventStateKeys,
"@x2": targetNIDs,
}
bulkSelectEventStateKeyStmt := "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key_nid)"
var responseEventStateKeys []EventStateKeysCosmos
query = cosmosdbapi.GetQuery(bulkSelectEventStateKeyStmt, params)
_, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&responseEventStateKeys,
optionsQry)
if err != nil {
return nil, err
}
// SELECT DISTINCT event_state_key
result := []string{} result := []string{}
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") for _, item := range responseEventStateKeys {
for rows.Next() { userID := item.EventStateKey
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID) result = append(result, userID)
} }
return result, rows.Err() return result, nil
} }
func (s *membershipStatements) UpdateForgetMembership( func (s *membershipStatements) UpdateForgetMembership(
@ -299,8 +594,22 @@ func (s *membershipStatements) UpdateForgetMembership(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool, forget bool,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, forget, roomNID, targetUserNID, // "UPDATE roomserver_membership SET forgotten = $1" +
) // " WHERE room_nid = $2 AND target_nid = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
dbData, err := getMembership(s, ctx, pk, cosmosDocId)
if err != nil {
return err
}
dbData.Membership.Forgotten = forget
_, err = setMembership(s, ctx, pk, *dbData)
return err return err
} }

View file

@ -20,9 +20,12 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -32,59 +35,90 @@ import (
// In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it // In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it
// seems to care that it's empty and therefore hits a NOT NULL constraint on insert. // seems to care that it's empty and therefore hits a NOT NULL constraint on insert.
// We should really work out what the right thing to do here is. // We should really work out what the right thing to do here is.
const previousEventSchema = ` // const previousEventSchema = `
CREATE TABLE IF NOT EXISTS roomserver_previous_events ( // CREATE TABLE IF NOT EXISTS roomserver_previous_events (
previous_event_id TEXT NOT NULL, // previous_event_id TEXT NOT NULL,
previous_reference_sha256 BLOB, // previous_reference_sha256 BLOB,
event_nids TEXT NOT NULL, // event_nids TEXT NOT NULL,
UNIQUE (previous_event_id, previous_reference_sha256) // UNIQUE (previous_event_id, previous_reference_sha256)
); // );
` // `
type PreviousEventCosmos struct {
PreviousEventID string `json:"previous_event_id"`
PreviousReferenceSha256 []byte `json:"previous_reference_sha256"`
EventNIDs string `json:"event_nids"`
}
type PreviousEventCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
PreviousEvent PreviousEventCosmos `json:"mx_roomserver_previous_event"`
}
// Insert an entry into the previous_events table. // Insert an entry into the previous_events table.
// If there is already an entry indicating that an event references that previous event then // If there is already an entry indicating that an event references that previous event then
// add the event NID to the list to indicate that this event references that previous event as well. // add the event NID to the list to indicate that this event references that previous event as well.
// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room. // This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
// The lock is necessary to avoid data races when checking whether an event is already referenced by another event. // The lock is necessary to avoid data races when checking whether an event is already referenced by another event.
const insertPreviousEventSQL = ` // const insertPreviousEventSQL = `
INSERT OR REPLACE INTO roomserver_previous_events // INSERT OR REPLACE INTO roomserver_previous_events
(previous_event_id, previous_reference_sha256, event_nids) // (previous_event_id, previous_reference_sha256, event_nids)
VALUES ($1, $2, $3) // VALUES ($1, $2, $3)
` // `
const selectPreviousEventNIDsSQL = ` // const selectPreviousEventNIDsSQL = `
SELECT event_nids FROM roomserver_previous_events // SELECT event_nids FROM roomserver_previous_events
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 // WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
` // `
// Check if the event is referenced by another event in the table. // Check if the event is referenced by another event in the table.
// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
const selectPreviousEventExistsSQL = ` // const selectPreviousEventExistsSQL = `
SELECT 1 FROM roomserver_previous_events // SELECT 1 FROM roomserver_previous_events
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 // WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
` // `
type previousEventStatements struct { type previousEventStatements struct {
db *sql.DB db *Database
insertPreviousEventStmt *sql.Stmt // insertPreviousEventStmt *sql.Stmt
selectPreviousEventNIDsStmt *sql.Stmt // selectPreviousEventNIDsStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt // selectPreviousEventExistsStmt *sql.Stmt
tableName string
} }
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*PreviousEventCosmosData, error) {
response := PreviousEventCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func NewCosmosDBPrevEventsTable(db *Database) (tables.PreviousEvents, error) {
s := &previousEventStatements{ s := &previousEventStatements{
db: db, db: db,
} }
_, err := db.Exec(previousEventSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{ // return s, shared.StatementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL}, // {&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL}, // {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, // {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
}.Prepare(db) // }.Prepare(db)
s.tableName = "previous_events"
return s, nil
} }
func (s *previousEventStatements) InsertPreviousEvent( func (s *previousEventStatements) InsertPreviousEvent(
@ -94,28 +128,71 @@ func (s *previousEventStatements) InsertPreviousEvent(
previousEventReferenceSHA256 []byte, previousEventReferenceSHA256 []byte,
eventNID types.EventNID, eventNID types.EventNID,
) error { ) error {
var eventNIDs string
eventNIDAsString := fmt.Sprintf("%d", eventNID) eventNIDAsString := fmt.Sprintf("%d", eventNID)
selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err) // UNIQUE (previous_event_id, previous_reference_sha256)
// TODO: Check value
// docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256)
docId := previousEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// SELECT 1 FROM roomserver_previous_events
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
existing, err := getPreviousEvent(s, ctx, pk, cosmosDocId)
if err != nil {
if err != cosmosdbutil.ErrNoRows {
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
}
} }
var dbData PreviousEventCosmosData
// Doesnt exist, create a new one
if existing == nil {
data := PreviousEventCosmos{
EventNIDs: "",
PreviousEventID: previousEventID,
PreviousReferenceSha256: previousEventReferenceSHA256,
}
dbData = PreviousEventCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
PreviousEvent: data,
}
} else {
dbData = *existing
}
var nids []string var nids []string
if eventNIDs != "" { if dbData.PreviousEvent.EventNIDs != "" {
nids = strings.Split(eventNIDs, ",") nids = strings.Split(dbData.PreviousEvent.EventNIDs, ",")
for _, nid := range nids { for _, nid := range nids {
if nid == eventNIDAsString { if nid == eventNIDAsString {
return nil return nil
} }
} }
eventNIDs = strings.Join(append(nids, eventNIDAsString), ",") dbData.PreviousEvent.EventNIDs = strings.Join(append(nids, eventNIDAsString), ",")
} else { } else {
eventNIDs = eventNIDAsString dbData.PreviousEvent.EventNIDs = eventNIDAsString
} }
insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
_, err = insertStmt.ExecContext( // INSERT OR REPLACE INTO roomserver_previous_events
ctx, previousEventID, previousEventReferenceSHA256, eventNIDs, // (previous_event_id, previous_reference_sha256, event_nids)
// VALUES ($1, $2, $3)
var optionsReplace = cosmosdbapi.GetUpsertDocumentOptions(pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
optionsReplace,
) )
return err return err
} }
@ -125,7 +202,24 @@ func (s *previousEventStatements) InsertPreviousEvent(
func (s *previousEventStatements) SelectPreviousEventExists( func (s *previousEventStatements) SelectPreviousEventExists(
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
) error { ) error {
var ok int64 var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) // UNIQUE (previous_event_id, previous_reference_sha256)
// TODO: Check value
// docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256)
docId := eventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, string(docId))
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// SELECT 1 FROM roomserver_previous_events
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
dbData, err := getPreviousEvent(s, ctx, pk, cosmosDocId)
if err != nil {
return err
}
if dbData == nil {
return cosmosdbutil.ErrNoRows
}
return nil
} }

View file

@ -17,89 +17,199 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
const publishedSchema = ` // const publishedSchema = `
-- Stores which rooms are published in the room directory // -- Stores which rooms are published in the room directory
CREATE TABLE IF NOT EXISTS roomserver_published ( // CREATE TABLE IF NOT EXISTS roomserver_published (
-- The room ID of the room // -- The room ID of the room
room_id TEXT NOT NULL PRIMARY KEY, // room_id TEXT NOT NULL PRIMARY KEY,
-- Whether it is published or not // -- Whether it is published or not
published BOOLEAN NOT NULL DEFAULT false // published BOOLEAN NOT NULL DEFAULT false
); // );
` // `
const upsertPublishedSQL = "" + type PublishCosmos struct {
"INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)" RoomID string `json:"room_id"`
Published bool `json:"published"`
const selectAllPublishedSQL = "" +
"SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
const selectPublishedSQL = "" +
"SELECT published FROM roomserver_published WHERE room_id = $1"
type publishedStatements struct {
db *sql.DB
upsertPublishedStmt *sql.Stmt
selectAllPublishedStmt *sql.Stmt
selectPublishedStmt *sql.Stmt
} }
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { type PublishCosmosData struct {
s := &publishedStatements{ Id string `json:"id"`
db: db, Pk string `json:"_pk"`
} Cn string `json:"_cn"`
_, err := db.Exec(publishedSchema) ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Publish PublishCosmos `json:"mx_roomserver_publish"`
}
// const upsertPublishedSQL = "" +
// "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
// "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
const selectAllPublishedSQL = "" +
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_publish.published = @x2" +
" order by c.mx_roomserver_publish.room_id asc"
// const selectPublishedSQL = "" +
// "SELECT published FROM roomserver_published WHERE room_id = $1"
type publishedStatements struct {
db *Database
// upsertPublishedStmt *sql.Stmt
selectAllPublishedStmt string
// selectPublishedStmt *sql.Stmt
tableName string
}
func queryPublish(s *publishedStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PublishCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []PublishCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s, shared.StatementList{ return response, nil
{&s.upsertPublishedStmt, upsertPublishedSQL}, }
{&s.selectAllPublishedStmt, selectAllPublishedSQL},
{&s.selectPublishedStmt, selectPublishedSQL}, func getPublish(s *publishedStatements, ctx context.Context, pk string, docId string) (*PublishCosmosData, error) {
}.Prepare(db) response := PublishCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setPublish(s *publishedStatements, ctx context.Context, pk string, publish PublishCosmosData) (*PublishCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, publish.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
publish.Id,
&publish,
optionsReplace)
return &publish, ex
}
func NewCosmosDBPublishedTable(db *Database) (tables.Published, error) {
s := &publishedStatements{
db: db,
}
// _, err := db.Exec(publishedSchema)
// if err != nil {
// return nil, err
// }
// return s, shared.StatementList{
// {&s.upsertPublishedStmt, upsertPublishedSQL},
s.selectAllPublishedStmt = selectAllPublishedSQL
// {&s.selectPublishedStmt, selectPublishedSQL},
// }.Prepare(db)
s.tableName = "published"
return s, nil
} }
func (s *publishedStatements) UpsertRoomPublished( func (s *publishedStatements) UpsertRoomPublished(
ctx context.Context, txn *sql.Tx, roomID string, published bool, ctx context.Context, txn *sql.Tx, roomID string, published bool,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
_, err := stmt.ExecContext(ctx, roomID, published) // "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL PRIMARY KEY,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := PublishCosmos{
RoomID: roomID,
Published: false,
}
var dbData = PublishCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Publish: data,
}
// "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err return err
} }
func (s *publishedStatements) SelectPublishedFromRoomID( func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (published bool, err error) { ) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows { // "SELECT published FROM roomserver_published WHERE room_id = $1"
return false, nil var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL PRIMARY KEY,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getPublish(s, ctx, pk, cosmosDocId)
if err != nil {
return false, err
} }
return return response.Publish.Published, nil
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool, ctx context.Context, published bool,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
// "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": published,
}
response, err := queryPublish(s, ctx, s.selectAllPublishedStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectAllPublishedStmt: rows.close() failed")
var roomIDs []string var roomIDs []string
for rows.Next() { for _, item := range response {
var roomID string roomIDs = append(roomIDs, item.Publish.RoomID)
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
} }
return roomIDs, rows.Err() return roomIDs, nil
} }

View file

@ -17,84 +17,207 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
const redactionsSchema = ` // const redactionsSchema = `
-- Stores information about the redacted state of events. // -- Stores information about the redacted state of events.
-- We need to track redactions rather than blindly updating the event JSON table on receipt of a redaction // -- We need to track redactions rather than blindly updating the event JSON table on receipt of a redaction
-- because we might receive the redaction BEFORE we receive the event which it redacts (think backfill). // -- because we might receive the redaction BEFORE we receive the event which it redacts (think backfill).
CREATE TABLE IF NOT EXISTS roomserver_redactions ( // CREATE TABLE IF NOT EXISTS roomserver_redactions (
redaction_event_id TEXT PRIMARY KEY, // redaction_event_id TEXT PRIMARY KEY,
redacts_event_id TEXT NOT NULL, // redacts_event_id TEXT NOT NULL,
-- Initially FALSE, set to TRUE when the redaction has been validated according to rooms v3+ spec // -- Initially FALSE, set to TRUE when the redaction has been validated according to rooms v3+ spec
-- https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events // -- https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
validated BOOLEAN NOT NULL // validated BOOLEAN NOT NULL
); // );
` // `
const insertRedactionSQL = "" + type RedactionCosmos struct {
"INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" + RedactionEventID string `json:"redaction_event_id"`
" VALUES ($1, $2, $3)" RedactsEventID string `json:"redacts_event_id"`
Validated bool `json:"validated"`
const selectRedactionInfoByRedactionEventIDSQL = "" +
"SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
" WHERE redaction_event_id = $1"
const selectRedactionInfoByEventBeingRedactedSQL = "" +
"SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
" WHERE redacts_event_id = $1"
const markRedactionValidatedSQL = "" +
" UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
type redactionStatements struct {
db *sql.DB
insertRedactionStmt *sql.Stmt
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
markRedactionValidatedStmt *sql.Stmt
} }
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { type RedactionCosmosData struct {
s := &redactionStatements{ Id string `json:"id"`
db: db, Pk string `json:"_pk"`
} Cn string `json:"_cn"`
_, err := db.Exec(redactionsSchema) ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Redaction RedactionCosmos `json:"mx_roomserver_redaction"`
}
// const insertRedactionSQL = "" +
// "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
// " VALUES ($1, $2, $3)"
// const selectRedactionInfoByRedactionEventIDSQL = "" +
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
// " WHERE redaction_event_id = $1"
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
// " WHERE redacts_event_id = $1"
const selectRedactionInfoByEventBeingRedactedSQL = "" +
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_redaction.redacts_event_id = @x2"
// const markRedactionValidatedSQL = "" +
// " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
type redactionStatements struct {
db *Database
// insertRedactionStmt *sql.Stmt
// selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
selectRedactionInfoByEventBeingRedactedStmt string
// markRedactionValidatedStmt *sql.Stmt
tableName string
}
func queryRedaction(s *redactionStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RedactionCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []RedactionCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return response, nil
}
return s, shared.StatementList{ func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId string) (*RedactionCosmosData, error) {
{&s.insertRedactionStmt, insertRedactionSQL}, response := RedactionCosmosData{}
{&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL}, err := cosmosdbapi.GetDocumentOrNil(
{&s.selectRedactionInfoByEventBeingRedactedStmt, selectRedactionInfoByEventBeingRedactedSQL}, s.db.connection,
{&s.markRedactionValidatedStmt, markRedactionValidatedSQL}, s.db.cosmosConfig,
}.Prepare(db) ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setRedaction(s *redactionStatements, ctx context.Context, pk string, redaction RedactionCosmosData) (*RedactionCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, redaction.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
redaction.Id,
&redaction,
optionsReplace)
return &redaction, ex
}
func NewCosmosDBRedactionsTable(db *Database) (tables.Redactions, error) {
s := &redactionStatements{
db: db,
}
// return s, shared.StatementList{
// {&s.insertRedactionStmt, insertRedactionSQL},
// {&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL},
s.selectRedactionInfoByEventBeingRedactedStmt = selectRedactionInfoByEventBeingRedactedSQL
// {&s.markRedactionValidatedStmt, markRedactionValidatedSQL},
// }.Prepare(db)
s.tableName = "redactions"
return s, nil
} }
func (s *redactionStatements) InsertRedaction( func (s *redactionStatements) InsertRedaction(
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo, ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) // "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
return err // " VALUES ($1, $2, $3)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// redaction_event_id TEXT PRIMARY KEY,
docId := info.RedactionEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := RedactionCosmos{
RedactionEventID: info.RedactionEventID,
RedactsEventID: info.RedactsEventID,
Validated: info.Validated,
}
var dbData = RedactionCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Redaction: data,
}
// "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
// TODO: Just forDebug - Remove exception
if err != nil {
return err
}
//Ignore Error
return nil
} }
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
ctx context.Context, txn *sql.Tx, redactionEventID string, ctx context.Context, txn *sql.Tx, redactionEventID string,
) (info *tables.RedactionInfo, err error) { ) (info *tables.RedactionInfo, err error) {
info = &tables.RedactionInfo{} info = &tables.RedactionInfo{}
stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByRedactionEventIDStmt)
err = stmt.QueryRowContext(ctx, redactionEventID).Scan( // "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
&info.RedactionEventID, &info.RedactsEventID, &info.Validated, // " WHERE redaction_event_id = $1"
) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
if err == sql.ErrNoRows { // redaction_event_id TEXT PRIMARY KEY,
docId := redactionEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getRedaction(s, ctx, pk, cosmosDocId)
if err != nil {
info = nil
err = err
return
}
if response == nil {
info = nil info = nil
err = nil err = nil
return
}
info = &tables.RedactionInfo{
RedactionEventID: response.Redaction.RedactionEventID,
RedactsEventID: response.Redaction.RedactsEventID,
Validated: response.Redaction.Validated,
} }
return return
} }
@ -102,14 +225,31 @@ func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted( func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) (info *tables.RedactionInfo, err error) { ) (info *tables.RedactionInfo, err error) {
info = &tables.RedactionInfo{}
stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByEventBeingRedactedStmt) // "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
err = stmt.QueryRowContext(ctx, eventID).Scan( // " WHERE redacts_event_id = $1"
&info.RedactionEventID, &info.RedactsEventID, &info.Validated,
) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
if err == sql.ErrNoRows { params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventID,
}
response, err := queryRedaction(s, ctx, s.selectRedactionInfoByEventBeingRedactedStmt, params)
if err != nil {
return nil, err
}
if len(response) == 0 {
info = nil info = nil
err = nil err = nil
return
}
// TODO: Check this is ok to return the 1st one
*info = tables.RedactionInfo{
RedactionEventID: response[0].Redaction.RedactionEventID,
RedactsEventID: response[0].Redaction.RedactsEventID,
Validated: response[0].Redaction.Validated,
} }
return return
} }
@ -117,7 +257,22 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
func (s *redactionStatements) MarkRedactionValidated( func (s *redactionStatements) MarkRedactionValidated(
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool, ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
_, err := stmt.ExecContext(ctx, redactionEventID, validated) // " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// redaction_event_id TEXT PRIMARY KEY,
docId := redactionEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getRedaction(s, ctx, pk, cosmosDocId)
if err != nil {
return err
}
response.Redaction.Validated = validated
_, err = setRedaction(s, ctx, pk, *response)
return err return err
} }

View file

@ -18,84 +18,185 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
const roomAliasesSchema = ` // const roomAliasesSchema = `
CREATE TABLE IF NOT EXISTS roomserver_room_aliases ( // CREATE TABLE IF NOT EXISTS roomserver_room_aliases (
alias TEXT NOT NULL PRIMARY KEY, // alias TEXT NOT NULL PRIMARY KEY,
room_id TEXT NOT NULL, // room_id TEXT NOT NULL,
creator_id TEXT NOT NULL // creator_id TEXT NOT NULL
); // );
CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id); // CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
` // `
const insertRoomAliasSQL = ` type RoomAliasCosmos struct {
INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3) Alias string `json:"alias"`
` RoomID string `json:"room_id"`
CreatorID string `json:"creator_id"`
const selectRoomIDFromAliasSQL = `
SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
`
const selectAliasesFromRoomIDSQL = `
SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
`
const selectCreatorIDFromAliasSQL = `
SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
`
const deleteRoomAliasSQL = `
DELETE FROM roomserver_room_aliases WHERE alias = $1
`
type roomAliasesStatements struct {
db *sql.DB
insertRoomAliasStmt *sql.Stmt
selectRoomIDFromAliasStmt *sql.Stmt
selectAliasesFromRoomIDStmt *sql.Stmt
selectCreatorIDFromAliasStmt *sql.Stmt
deleteRoomAliasStmt *sql.Stmt
} }
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { type RoomAliasCosmosData struct {
s := &roomAliasesStatements{ Id string `json:"id"`
db: db, Pk string `json:"_pk"`
} Cn string `json:"_cn"`
_, err := db.Exec(roomAliasesSchema) ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
RoomAlias RoomAliasCosmos `json:"mx_roomserver_room_alias"`
}
// const insertRoomAliasSQL = `
// INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
// `
// const selectRoomIDFromAliasSQL = `
// SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
// `
// SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
const selectAliasesFromRoomIDSQL = `
select * from c where c._cn = @x1 and c.mx_roomserver_room_alias.room_id = @x2
`
// const selectCreatorIDFromAliasSQL = `
// SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
// `
// const deleteRoomAliasSQL = `
// DELETE FROM roomserver_room_aliases WHERE alias = $1
// `
type roomAliasesStatements struct {
db *Database
// insertRoomAliasStmt *sql.Stmt
// selectRoomIDFromAliasStmt string
selectAliasesFromRoomIDStmt string
// selectCreatorIDFromAliasStmt string
// deleteRoomAliasStmt *sql.Stmt
tableName string
}
func queryRoomAlias(s *roomAliasesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomAliasCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []RoomAliasCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s, shared.StatementList{ return response, nil
{&s.insertRoomAliasStmt, insertRoomAliasSQL}, }
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
{&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, func getRoomAlias(s *roomAliasesStatements, ctx context.Context, pk string, docId string) (*RoomAliasCosmosData, error) {
{&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, response := RoomAliasCosmosData{}
{&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, err := cosmosdbapi.GetDocumentOrNil(
}.Prepare(db) s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func NewCosmosDBRoomAliasesTable(db *Database) (tables.RoomAliases, error) {
s := &roomAliasesStatements{
db: db,
}
// _, err := db.Exec(roomAliasesSchema)
// if err != nil {
// return nil, err
// }
// return s, shared.StatementList{
// {&s.insertRoomAliasStmt, insertRoomAliasSQL},
// {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
s.selectAliasesFromRoomIDStmt = selectAliasesFromRoomIDSQL
// {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL},
// {&s.deleteRoomAliasStmt, deleteRoomAliasSQL},
// }.Prepare(db)
s.tableName = "room_aliases"
return s, nil
} }
func (s *roomAliasesStatements) InsertRoomAlias( func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt)
_, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID) // INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
data := RoomAliasCosmos{
Alias: alias,
CreatorID: creatorUserID,
RoomID: roomID,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// alias TEXT NOT NULL PRIMARY KEY,
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = RoomAliasCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
RoomAlias: data,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err return err
} }
func (s *roomAliasesStatements) SelectRoomIDFromAlias( func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (roomID string, err error) { ) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows { // SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// alias TEXT NOT NULL PRIMARY KEY,
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getRoomAlias(s, ctx, pk, cosmosDocId)
if err != nil {
return "", err
}
if response == nil {
return "", nil return "", nil
} }
roomID = response.RoomAlias.RoomID
return return
} }
@ -103,20 +204,23 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (aliases []string, err error) { ) (aliases []string, err error) {
aliases = []string{} aliases = []string{}
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
if err != nil { // SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
return
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") response, err := queryRoomAlias(s, ctx, s.selectAliasesFromRoomIDStmt, params)
for rows.Next() { if err != nil {
var alias string return nil, err
if err = rows.Scan(&alias); err != nil { }
return
}
aliases = append(aliases, alias) for _, item := range response {
aliases = append(aliases, item.RoomAlias.Alias)
} }
return return
@ -125,17 +229,48 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (creatorID string, err error) { ) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows { // SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// alias TEXT NOT NULL PRIMARY KEY,
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getRoomAlias(s, ctx, pk, cosmosDocId)
if err != nil {
return "", err
}
if response == nil {
return "", nil return "", nil
} }
creatorID = response.RoomAlias.CreatorID
return return
} }
func (s *roomAliasesStatements) DeleteRoomAlias( func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, txn *sql.Tx, alias string, ctx context.Context, txn *sql.Tx, alias string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt)
_, err := stmt.ExecContext(ctx, alias) // DELETE FROM roomserver_room_aliases WHERE alias = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
cosmosDocId,
options)
if err != nil {
return err
}
return err return err
} }

View file

@ -0,0 +1,12 @@
package cosmosdb
import (
"context"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
)
func GetNextRoomNID(s *roomStatements, ctx context.Context) (int64, error) {
const docId = "roomnid_seq"
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
}

View file

@ -18,128 +18,227 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"strings" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const roomsSchema = ` // const roomsSchema = `
CREATE TABLE IF NOT EXISTS roomserver_rooms ( // CREATE TABLE IF NOT EXISTS roomserver_rooms (
room_nid INTEGER PRIMARY KEY AUTOINCREMENT, // room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
room_id TEXT NOT NULL UNIQUE, // room_id TEXT NOT NULL UNIQUE,
latest_event_nids TEXT NOT NULL DEFAULT '[]', // latest_event_nids TEXT NOT NULL DEFAULT '[]',
last_event_sent_nid INTEGER NOT NULL DEFAULT 0, // last_event_sent_nid INTEGER NOT NULL DEFAULT 0,
state_snapshot_nid INTEGER NOT NULL DEFAULT 0, // state_snapshot_nid INTEGER NOT NULL DEFAULT 0,
room_version TEXT NOT NULL // room_version TEXT NOT NULL
); // );
` // `
// Same as insertEventTypeNIDSQL type RoomCosmosData struct {
const insertRoomNIDSQL = ` Id string `json:"id"`
INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2) Pk string `json:"_pk"`
ON CONFLICT DO NOTHING; Cn string `json:"_cn"`
` ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
const selectRoomNIDSQL = "" + Room RoomCosmos `json:"mx_roomserver_room"`
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
const selectLatestEventNIDsSQL = "" +
"SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
const selectLatestEventNIDsForUpdateSQL = "" +
"SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
const updateLatestEventNIDsSQL = "" +
"UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
const selectRoomVersionsForRoomNIDsSQL = "" +
"SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
const selectRoomIDsSQL = "" +
"SELECT room_id FROM roomserver_rooms"
const bulkSelectRoomIDsSQL = "" +
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
const bulkSelectRoomNIDsSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
type roomStatements struct {
db *sql.DB
insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
selectLatestEventNIDsStmt *sql.Stmt
selectLatestEventNIDsForUpdateStmt *sql.Stmt
updateLatestEventNIDsStmt *sql.Stmt
//selectRoomVersionForRoomNIDStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt
} }
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { type RoomCosmos struct {
RoomNID int64 `json:"room_nid"`
RoomID string `json:"room_id"`
LatestEventNIDs []int64 `json:"latest_event_nids"`
LastEventSentNID int64 `json:"last_event_sent_nid"`
StateSnapshotNID int64 `json:"state_snapshot_nid"`
RoomVersion string `json:"room_version"`
}
// Same as insertEventTypeNIDSQL
// const insertRoomNIDSQL = `
// INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
// ON CONFLICT DO NOTHING;
// `
// "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
// const selectRoomNIDSQL = "" +
// "select * from c where c._cn = @x1 and c.mx_roomserver_room.room_nid = @x1"
// "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
const selectLatestEventNIDsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_roomserver_room.room_nid = @x2"
// "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
const selectLatestEventNIDsForUpdateSQL = "" +
"select * from c where c._cn = @x1 " +
" and c.mx_roomserver_room.room_nid = @x2"
// const updateLatestEventNIDsSQL = "" +
// "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
// "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
const selectRoomVersionsForRoomNIDsSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)"
// "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
// const selectRoomInfoSQL = "" +
// "select * from c where c._cn = @x1 and c.mx_roomserver_room.room_id = @x2"
// "SELECT room_id FROM roomserver_rooms"
const selectRoomIDsSQL = "" +
"select * from c where c._cn = @x1"
// "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
const bulkSelectRoomIDsSQL = "" +
"select * from c where c._cn = @x1 " +
" and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)"
// "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
const bulkSelectRoomNIDsSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)"
type roomStatements struct {
db *Database
// insertRoomNIDStmt *sql.Stmt
// selectRoomNIDStmt string
selectLatestEventNIDsStmt string
selectLatestEventNIDsForUpdateStmt string
updateLatestEventNIDsStmt string
selectRoomVersionForRoomNIDStmt string
// selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt string
tableName string
}
func NewCosmosDBRoomsTable(db *Database) (tables.Rooms, error) {
s := &roomStatements{ s := &roomStatements{
db: db, db: db,
} }
_, err := db.Exec(roomsSchema) // return s, shared.StatementList{
// {&s.insertRoomNIDStmt, insertRoomNIDSQL},
// {&s.selectRoomNIDStmt, selectRoomNIDSQL},
s.selectLatestEventNIDsStmt = selectLatestEventNIDsSQL
s.selectLatestEventNIDsForUpdateStmt = selectLatestEventNIDsForUpdateSQL
// {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
// {&s.selectRoomInfoStmt, selectRoomInfoSQL},
s.selectRoomIDsStmt = selectRoomIDsSQL
// }.Prepare(db)
s.tableName = "rooms"
return s, nil
}
func mapToRoomEventNIDArray(eventNIDs []int64) []types.EventNID {
result := []types.EventNID{}
for i := 0; i < len(eventNIDs); i++ {
result = append(result, types.EventNID(eventNIDs[i]))
}
return result
}
func queryRoom(s *roomStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []RoomCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s, shared.StatementList{ return response, nil
{&s.insertRoomNIDStmt, insertRoomNIDSQL}, }
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (*RoomCosmosData, error) {
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, response := RoomCosmosData{}
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, err := cosmosdbapi.GetDocumentOrNil(
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL}, s.db.connection,
{&s.selectRoomInfoStmt, selectRoomInfoSQL}, s.db.cosmosConfig,
{&s.selectRoomIDsStmt, selectRoomIDsSQL}, ctx,
}.Prepare(db) pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setRoom(s *roomStatements, ctx context.Context, pk string, room RoomCosmosData) (*RoomCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, room.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
room.Id,
&room,
optionsReplace)
return &room, ex
} }
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
// "SELECT room_id FROM roomserver_rooms"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
response, err := queryRoom(s, ctx, s.selectRoomIDsStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string var roomIDs []string
for rows.Next() { for _, item := range response {
var roomID string roomIDs = append(roomIDs, item.Room.RoomID)
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
} }
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo info := types.RoomInfo{}
var latestNIDsJSON string
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( // "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL UNIQUE,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
room, err := getRoom(s, ctx, pk, cosmosDocId)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == cosmosdbutil.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
var latestNIDs []int64
if err = json.Unmarshal([]byte(latestNIDsJSON), &latestNIDs); err != nil { info.RoomVersion = gomatrixserverlib.RoomVersion(room.Room.RoomVersion)
return nil, err info.RoomNID = types.RoomNID(room.Room.RoomNID)
} info.StateSnapshotNID = types.StateSnapshotNID(room.Room.StateSnapshotNID)
info.IsStub = len(latestNIDs) == 0 info.IsStub = len(room.Room.LatestEventNIDs) == 0
return &info, err return &info, err
} }
@ -147,60 +246,135 @@ func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (roomNID types.RoomNID, err error) { ) (roomNID types.RoomNID, err error) {
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
_, err = insertStmt.ExecContext(ctx, roomID, roomVersion) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
if err != nil {
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) // INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
// ON CONFLICT DO NOTHING;
// room_id TEXT NOT NULL UNIQUE,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
dbData, errGet := getRoom(s, ctx, pk, cosmosDocId)
if errGet == cosmosdbutil.ErrNoRows {
// room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
roomNIDSeq, seqErr := GetNextRoomNID(s, ctx)
if seqErr != nil {
return 0, seqErr
}
data := RoomCosmos{
RoomNID: int64(roomNIDSeq),
RoomID: roomID,
RoomVersion: string(roomVersion),
}
dbData = &RoomCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Room: data,
}
} }
roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
// ON CONFLICT DO NOTHING; - Do Upsert
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
if err != nil { if err != nil {
return 0, fmt.Errorf("s.SelectRoomNID: %w", err) return 0, fmt.Errorf("s.SelectRoomNID: %w", err)
} }
roomNID = types.RoomNID(dbData.Room.RoomNID)
return return
} }
func (s *roomStatements) SelectRoomNID( func (s *roomStatements) SelectRoomNID(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) { ) (types.RoomNID, error) {
var roomNID int64
stmt := sqlutil.TxStmt(txn, s.selectRoomNIDStmt) // "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL UNIQUE,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
room, err := getRoom(s, ctx, pk, cosmosDocId)
if err != nil {
return 0, err
}
if room == nil {
return 0, nil
}
return types.RoomNID(room.Room.RoomNID), err
} }
func (s *roomStatements) SelectLatestEventNIDs( func (s *roomStatements) SelectLatestEventNIDs(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.StateSnapshotNID, error) { ) ([]types.EventNID, types.StateSnapshotNID, error) {
var eventNIDs []types.EventNID
var nidsJSON string // "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
var stateSnapshotNID int64
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &stateSnapshotNID) params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNID,
}
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsStmt, params)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil {
return nil, 0, err // TODO: Check the error handling
if len(response) == 0 {
return nil, 0, cosmosdbutil.ErrNoRows
} }
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
//Assume 1 per RoomNID
room := response[0]
return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil
} }
func (s *roomStatements) SelectLatestEventsNIDsForUpdate( func (s *roomStatements) SelectLatestEventsNIDsForUpdate(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
var eventNIDs []types.EventNID
var nidsJSON string // "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
var lastEventSentNID int64
var stateSnapshotNID int64 var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) params := map[string]interface{}{
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &lastEventSentNID, &stateSnapshotNID) "@x1": dbCollectionName,
"@x2": roomNID,
}
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil {
return nil, 0, 0, err // TODO: Check the error handling
if len(response) == 0 {
return nil, 0, 0, cosmosdbutil.ErrNoRows
} }
return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil
//Assume 1 per RoomNID
room := response[0]
return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.EventNID(room.Room.LastEventSentNID), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil
} }
func (s *roomStatements) UpdateLatestEventNIDs( func (s *roomStatements) UpdateLatestEventNIDs(
@ -211,86 +385,113 @@ func (s *roomStatements) UpdateLatestEventNIDs(
lastEventSentNID types.EventNID, lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID, stateSnapshotNID types.StateSnapshotNID,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
_, err := stmt.ExecContext( // "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
ctx,
eventNIDsAsArray(eventNIDs), var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
int64(lastEventSentNID), params := map[string]interface{}{
int64(stateSnapshotNID), "@x1": dbCollectionName,
roomNID, "@x2": roomNID,
) }
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params)
if err != nil {
return err
}
// TODO: Check the error handling
if len(response) == 0 {
return cosmosdbutil.ErrNoRows
}
//Assume 1 per RoomNID
room := response[0]
room.Room.LatestEventNIDs = mapFromEventNIDArray(eventNIDs)
room.Room.LastEventSentNID = int64(lastEventSentNID)
room.Room.StateSnapshotNID = int64(stateSnapshotNID)
_, err = setRoom(s, ctx, room.Pk, room)
return err return err
} }
func (s *roomStatements) SelectRoomVersionsForRoomNIDs( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID, ctx context.Context, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) if roomNIDs == nil || len(roomNIDs) == 0 {
sqlPrep, err := s.db.Prepare(sqlStr) return make(map[types.RoomNID]gomatrixserverlib.RoomVersion), nil
}
// "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNIDs,
}
response, err := queryRoom(s, ctx, selectRoomVersionsForRoomNIDsSQL, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion) result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for rows.Next() { for _, item := range response {
var roomNID types.RoomNID result[types.RoomNID(item.Room.RoomNID)] = gomatrixserverlib.RoomVersion(item.Room.RoomVersion)
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
result[roomNID] = roomVersion
} }
return result, nil return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs)) if roomNIDs == nil || len(roomNIDs) == 0 {
for i, v := range roomNIDs { return []string{}, nil
iRoomNIDs[i] = v
} }
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) // "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomNIDs,
}
response, err := queryRoom(s, ctx, bulkSelectRoomIDsSQL, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
var roomIDs []string var roomIDs []string
for rows.Next() { for _, item := range response {
var roomID string roomIDs = append(roomIDs, item.Room.RoomID)
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
} }
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
iRoomIDs := make([]interface{}, len(roomIDs)) if roomIDs == nil || len(roomIDs) == 0 {
for i, v := range roomIDs { return []types.RoomNID{}, nil
iRoomIDs[i] = v
} }
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) // "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomIDs,
}
response, err := queryRoom(s, ctx, bulkSelectRoomNIDsSQL, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID var roomNIDs []types.RoomNID
for rows.Next() { for _, item := range response {
var roomNID types.RoomNID roomNIDs = append(roomNIDs, types.RoomNID(item.Room.RoomNID))
if err = rows.Scan(&roomNID); err != nil {
return nil, err
}
roomNIDs = append(roomNIDs, roomNID)
} }
return roomNIDs, nil return roomNIDs, nil
} }

View file

@ -20,33 +20,54 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"sort" "sort"
"strings" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
const stateDataSchema = ` // const stateDataSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_block ( // CREATE TABLE IF NOT EXISTS roomserver_state_block (
state_block_nid INTEGER NOT NULL, // state_block_nid INTEGER NOT NULL,
event_type_nid INTEGER NOT NULL, // event_type_nid INTEGER NOT NULL,
event_state_key_nid INTEGER NOT NULL, // event_state_key_nid INTEGER NOT NULL,
event_nid INTEGER NOT NULL, // event_nid INTEGER NOT NULL,
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) // UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
); // );
` // `
const insertStateDataSQL = "" + type StateBlockCosmos struct {
"INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + StateBlockNID int64 `json:"state_block_nid"`
" VALUES ($1, $2, $3, $4)" EventTypeNID int64 `json:"event_type_nid"`
EventStateKeyNID int64 `json:"event_state_key_nid"`
EventNID int64 `json:"event_nid"`
}
const selectNextStateBlockNIDSQL = ` type StateBlockCosmosMaxNID struct {
SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block Max int64 `json:"maxstateblocknid"`
` }
type StateBlockCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
StateBlock StateBlockCosmos `json:"mx_roomserver_state_block"`
}
// const insertStateDataSQL = "" +
// "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
// " VALUES ($1, $2, $3, $4)"
// SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
const selectNextStateBlockNIDSQL = "" +
"select sub.maxinner != null ? sub.maxinner + 1 : 1 as maxstateblocknid " +
"from " +
"(select MAX(c.mx_roomserver_state_block.state_block_nid) maxinner from c where c._cn = @x1) as sub"
// Bulk state lookup by numeric state block ID. // Bulk state lookup by numeric state block ID.
// Sort by the state_block_nid, event_type_nid, event_state_key_nid // Sort by the state_block_nid, event_type_nid, event_state_key_nid
@ -54,10 +75,17 @@ SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
// together in the list and those entries will sorted by event_type_nid // together in the list and those entries will sorted by event_type_nid
// and event_state_key_nid. This property makes it easier to merge two // and event_state_key_nid. This property makes it easier to merge two
// state data blocks together. // state data blocks together.
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
const bulkSelectStateBlockEntriesSQL = "" + const bulkSelectStateBlockEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + "select * from c where c._cn = @x1 " +
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid" "order by c.mx_roomserver_state_block.state_block_nid " +
// Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from
// ", c.mx_roomserver_state_block.event_type_nid " +
// ", c.mx_roomserver_state_block.event_state_key_nid " +
" asc"
// Bulk state lookup by numeric state block ID. // Bulk state lookup by numeric state block ID.
// Filters the rows in each block to the requested types and state keys. // Filters the rows in each block to the requested types and state keys.
@ -66,35 +94,126 @@ const bulkSelectStateBlockEntriesSQL = "" +
// of types and a list state_keys. So we have to filter the result in the // of types and a list state_keys. So we have to filter the result in the
// application to restrict it to the list of event types and state keys we // application to restrict it to the list of event types and state keys we
// actually wanted. // actually wanted.
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
// " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
const bulkSelectFilteredStateBlockEntriesSQL = "" + const bulkSelectFilteredStateBlockEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + "select * from c where c._cn = @x1 " +
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " +
" AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + "and ARRAY_CONTAINS(@x3, c.mx_roomserver_state_block.event_type_nid) " +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid" "and ARRAY_CONTAINS(@x4, c.mx_roomserver_state_block.event_state_key_nid) " +
"order by c.mx_roomserver_state_block.state_block_nid " +
// Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from
// ", c.mx_roomserver_state_block.event_type_nid " +
// ", c.mx_roomserver_state_block.event_state_key_nid " +
"asc"
type stateBlockStatements struct { type stateBlockStatements struct {
db *sql.DB db *Database
insertStateDataStmt *sql.Stmt // insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt selectNextStateBlockNIDStmt string
bulkSelectStateBlockEntriesStmt *sql.Stmt bulkSelectStateBlockEntriesStmt string
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt bulkSelectFilteredStateBlockEntriesStmt string
tableName string
} }
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StateBlockCosmosData, error) {
s := &stateBlockStatements{ var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
db: db, var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
} var response []StateBlockCosmosData
_, err := db.Exec(stateDataSchema)
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return response, nil
}
return s, shared.StatementList{ func NewCosmosDBStateBlockTable(db *Database) (tables.StateBlock, error) {
{&s.insertStateDataStmt, insertStateDataSQL}, s := &stateBlockStatements{
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, db: db,
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, }
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
}.Prepare(db) // return s, shared.StatementList{
// {&s.insertStateDataStmt, insertStateDataSQL},
s.selectNextStateBlockNIDStmt = selectNextStateBlockNIDSQL
s.bulkSelectStateBlockEntriesStmt = bulkSelectStateBlockEntriesSQL
s.bulkSelectFilteredStateBlockEntriesStmt = bulkSelectFilteredStateBlockEntriesSQL
// }.Prepare(db)
s.tableName = "state_block"
return s, nil
}
func inertStateBlockCore(s *stateBlockStatements, ctx context.Context, stateBlockNID types.StateBlockNID, entry types.StateEntry) error {
// "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
// " VALUES ($1, $2, $3, $4)"
data := StateBlockCosmos{
EventNID: int64(entry.EventNID),
EventStateKeyNID: int64(entry.EventStateKeyNID),
EventTypeNID: int64(entry.EventTypeNID),
StateBlockNID: int64(stateBlockNID),
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
docId := fmt.Sprintf("%d_%d_%d", data.StateBlockNID, data.EventTypeNID, data.EventStateKeyNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = StateBlockCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
StateBlock: data,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
func getNextStateBlockNID(s *stateBlockStatements, ctx context.Context) (int64, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var stateBlockNext []StateBlockCosmosMaxNID
params := map[string]interface{}{
"@x1": dbCollectionName,
}
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(s.selectNextStateBlockNIDStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&stateBlockNext,
optionsQry)
if err != nil {
return 0, err
}
return stateBlockNext[0].Max, nil
} }
func (s *stateBlockStatements) BulkInsertStateData( func (s *stateBlockStatements) BulkInsertStateData(
@ -104,75 +223,64 @@ func (s *stateBlockStatements) BulkInsertStateData(
if len(entries) == 0 { if len(entries) == 0 {
return 0, nil return 0, nil
} }
var stateBlockNID types.StateBlockNID
err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) nextID, errNextID := getNextStateBlockNID(s, ctx)
if err != nil { if errNextID != nil {
return 0, err return 0, errNextID
} }
stateBlockNID := types.StateBlockNID(nextID)
for _, entry := range entries { for _, entry := range entries {
_, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext( err := inertStateBlockCore(s, ctx, stateBlockNID, entry)
ctx,
int64(stateBlockNID),
int64(entry.EventTypeNID),
int64(entry.EventStateKeyNID),
int64(entry.EventNID),
)
if err != nil { if err != nil {
return 0, err return 0, err
} }
} }
return stateBlockNID, err return stateBlockNID, nil
} }
func (s *stateBlockStatements) BulkSelectStateBlockEntries( func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID, ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
nids := make([]interface{}, len(stateBlockNIDs))
for k, v := range stateBlockNIDs { // "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
nids[k] = v // " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var response []StateBlockCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": stateBlockNIDs,
} }
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
selectStmt, err := s.db.Prepare(selectOrig) response, err := queryStateBlock(s, ctx, s.bulkSelectStateBlockEntriesStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
results := make([]types.StateEntryList, len(stateBlockNIDs)) results := make([]types.StateEntryList, len(stateBlockNIDs))
// current is a pointer to the StateEntryList to append the state entries to. // current is a pointer to the StateEntryList to append the state entries to.
var current *types.StateEntryList var current *types.StateEntryList
i := 0 i := 0
for rows.Next() { for _, item := range response {
var ( entry := types.StateEntry{}
stateBlockNID int64 entry.EventTypeNID = types.EventTypeNID(item.StateBlock.EventTypeNID)
eventTypeNID int64 entry.EventStateKeyNID = types.EventStateKeyNID(item.StateBlock.EventStateKeyNID)
eventStateKeyNID int64 entry.EventNID = types.EventNID(item.StateBlock.EventNID)
eventNID int64
entry types.StateEntry if current == nil || types.StateBlockNID(item.StateBlock.StateBlockNID) != current.StateBlockNID {
)
if err := rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
return nil, err
}
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID)
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
// The state entry row is for a different state data block to the current one. // The state entry row is for a different state data block to the current one.
// So we start appending to the next entry in the list. // So we start appending to the next entry in the list.
current = &results[i] current = &results[i]
current.StateBlockNID = types.StateBlockNID(stateBlockNID) current.StateBlockNID = types.StateBlockNID(item.StateBlock.StateBlockNID)
i++ i++
} }
current.StateEntries = append(current.StateEntries, entry) current.StateEntries = append(current.StateEntries, entry)
} }
if i != len(nids) { if i != len(stateBlockNIDs) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids)) return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
} }
return results, nil return results, nil
} }
@ -187,34 +295,33 @@ func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
sort.Sort(tuples) sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) // sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) // sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) // sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
var params []interface{} // "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
for _, val := range stateBlockNIDs { // " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
params = append(params, int64(val)) // " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
} // " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
for _, val := range eventTypeNIDArray {
params = append(params, val) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} var response []StateBlockCosmosData
for _, val := range eventStateKeyNIDArray { params := map[string]interface{}{
params = append(params, val) "@x1": dbCollectionName,
"@x2": stateBlockNIDs,
"@x3": eventTypeNIDArray,
"@x4": eventStateKeyNIDArray,
} }
rows, err := s.db.QueryContext( response, err := queryStateBlock(s, ctx, s.bulkSelectFilteredStateBlockEntriesStmt, params)
ctx,
sqlStatement,
params...,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
var results []types.StateEntryList var results []types.StateEntryList
var current types.StateEntryList var current types.StateEntryList
for rows.Next() { for _, item := range response {
var ( var (
stateBlockNID int64 stateBlockNID int64
eventTypeNID int64 eventTypeNID int64
@ -222,11 +329,10 @@ func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
eventNID int64 eventNID int64
entry types.StateEntry entry types.StateEntry
) )
if err := rows.Scan( stateBlockNID = item.StateBlock.StateBlockNID
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, eventTypeNID = item.StateBlock.EventTypeNID
); err != nil { eventStateKeyNID = item.StateBlock.EventStateKeyNID
return nil, err eventNID = item.StateBlock.EventNID
}
entry.EventTypeNID = types.EventTypeNID(eventTypeNID) entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID) entry.EventNID = types.EventNID(eventNID)

View file

@ -0,0 +1,12 @@
package cosmosdb
import (
"context"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
)
func GetNextStateSnapshotNID(s *stateSnapshotStatements, ctx context.Context) (int64, error) {
const docId = "statesnapshotnid_seq"
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
}

View file

@ -18,106 +18,169 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"strings" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
const stateSnapshotSchema = ` // const stateSnapshotSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( // CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, // state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
room_nid INTEGER NOT NULL, // room_nid INTEGER NOT NULL,
state_block_nids TEXT NOT NULL DEFAULT '[]' // state_block_nids TEXT NOT NULL DEFAULT '[]'
); // );
` // `
const insertStateSQL = ` type StateSnapshotCosmos struct {
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) StateSnapshotNID int64 `json:"state_snapshot_nid"`
VALUES ($1, $2);` RoomNID int64 `json:"room_nid"`
StateBlockNIDs []int64 `json:"state_block_nids"`
}
type StateSnapshotCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
StateSnapshot StateSnapshotCosmos `json:"mx_roomserver_state_snapshot"`
}
// const insertStateSQL = `
// INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
// VALUES ($1, $2);`
// Bulk state data NID lookup. // Bulk state data NID lookup.
// Sorting by state_snapshot_nid means we can use binary search over the result // Sorting by state_snapshot_nid means we can use binary search over the result
// to lookup the state data NIDs for a state snapshot NID. // to lookup the state data NIDs for a state snapshot NID.
// "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
// " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
const bulkSelectStateBlockNIDsSQL = "" + const bulkSelectStateBlockNIDsSQL = "" +
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + "select * from c where c._cn = @x1 " +
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" "and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_snapshot.state_snapshot_nid) " +
"order by c.mx_roomserver_state_snapshot.state_snapshot_nid asc"
type stateSnapshotStatements struct { type stateSnapshotStatements struct {
db *sql.DB db *Database
insertStateStmt *sql.Stmt // insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt bulkSelectStateBlockNIDsStmt string
tableName string
} }
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { func mapFromStateBlockNIDArray(stateBlockNIDs []types.StateBlockNID) []int64 {
result := []int64{}
for i := 0; i < len(stateBlockNIDs); i++ {
result = append(result, int64(stateBlockNIDs[i]))
}
return result
}
func mapToStateBlockNIDArray(stateBlockNIDs []int64) []types.StateBlockNID {
result := []types.StateBlockNID{}
for i := 0; i < len(stateBlockNIDs); i++ {
result = append(result, types.StateBlockNID(stateBlockNIDs[i]))
}
return result
}
func NewCosmosDBStateSnapshotTable(db *Database) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{ s := &stateSnapshotStatements{
db: db, db: db,
} }
_, err := db.Exec(stateSnapshotSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{ // return s, shared.StatementList{
{&s.insertStateStmt, insertStateSQL}, // {&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, s.bulkSelectStateBlockNIDsStmt = bulkSelectStateBlockNIDsSQL
}.Prepare(db) // }.Prepare(db)
s.tableName = "state_snapshots"
return s, nil
} }
func (s *stateSnapshotStatements) InsertState( func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {
stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs)
if err != nil { // INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
return // VALUES ($1, $2);`
stateSnapshotNIDSeq, seqErr := GetNextStateSnapshotNID(s, ctx)
if seqErr != nil {
return 0, seqErr
} }
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) data := StateSnapshotCosmos{
RoomNID: int64(roomNID),
StateBlockNIDs: mapFromStateBlockNIDArray(stateBlockNIDs),
StateSnapshotNID: int64(stateSnapshotNIDSeq),
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
docId := fmt.Sprintf("%d", stateSnapshotNIDSeq)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = StateSnapshotCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
StateSnapshot: data,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
if err != nil { if err != nil {
return 0, err return 0, err
} }
lastRowID, err := res.LastInsertId()
if err != nil { stateNID = types.StateSnapshotNID(stateSnapshotNIDSeq)
return 0, err
}
stateNID = types.StateSnapshotNID(lastRowID)
return return
} }
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
nids := make([]interface{}, len(stateNIDs))
for k, v := range stateNIDs { // "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
nids[k] = v // " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []StateSnapshotCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": stateNIDs,
} }
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
selectStmt, err := s.db.Prepare(selectOrig) var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.bulkSelectStateBlockNIDsStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
results := make([]types.StateBlockNIDList, len(stateNIDs)) results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0 i := 0
for ; rows.Next(); i++ { for _, item := range response {
result := &results[i] result := &results[i]
var stateBlockNIDsJSON string result.StateSnapshotNID = types.StateSnapshotNID(item.StateSnapshot.StateSnapshotNID)
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil { result.StateBlockNIDs = mapToStateBlockNIDArray(item.StateSnapshot.StateBlockNIDs)
return nil, err i++
}
if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil {
return nil, err
}
} }
if i != len(stateNIDs) { if i != len(stateNIDs) {
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))

View file

@ -17,14 +17,15 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -33,16 +34,26 @@ import (
// A Database is used to store room events and stream offsets. // A Database is used to store room events and stream offsets.
type Database struct { type Database struct {
shared.Database shared.Database
connection cosmosdbapi.CosmosConnection
databaseName string
cosmosConfig cosmosdbapi.CosmosConfig
serverName gomatrixserverlib.ServerName
} }
// Open a sqlite database. // Open a sqlite database.
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
var d Database conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
var db *sql.DB config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
var err error d := &Database{
if db, err = sqlutil.Open(dbProperties); err != nil { databaseName: "roomserver",
return nil, err connection: conn,
cosmosConfig: config,
} }
// var db *sql.DB
// var err error
// if db, err = sqlutil.Open(dbProperties); err != nil {
// return nil, err
// }
//db.Exec("PRAGMA journal_mode=WAL;") //db.Exec("PRAGMA journal_mode=WAL;")
//db.Exec("PRAGMA read_uncommitted = true;") //db.Exec("PRAGMA read_uncommitted = true;")
@ -51,89 +62,91 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
// cause the roomserver to be unresponsive to new events because something will // cause the roomserver to be unresponsive to new events because something will
// acquire the global mutex and never unlock it because it is waiting for a connection // acquire the global mutex and never unlock it because it is waiting for a connection
// which it will never obtain. // which it will never obtain.
db.SetMaxOpenConns(20) // db.SetMaxOpenConns(20)
// Create tables before executing migrations so we don't fail if the table is missing, // Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns // and THEN prepare statements so we don't fail due to referencing new columns
ms := membershipStatements{} // ms := membershipStatements{}
if err := ms.execSchema(db); err != nil { // if err := ms.execSchema(db); err != nil {
return nil, err // return nil, err
} // }
m := sqlutil.NewMigrations() // m := sqlutil.NewMigrations()
deltas.LoadAddForgottenColumn(m) // deltas.LoadAddForgottenColumn(m)
if err := m.RunDeltas(db, dbProperties); err != nil { // if err := m.RunDeltas(db, dbProperties); err != nil {
return nil, err // return nil, err
} // }
if err := d.prepare(db, cache); err != nil { if err := d.prepare(cache); err != nil {
return nil, err return nil, err
} }
return &d, nil return d, nil
} }
// nolint: gocyclo // nolint: gocyclo
func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { func (d *Database) prepare(cache caching.RoomServerCaches) error {
var err error var err error
eventStateKeys, err := NewSqliteEventStateKeysTable(db) d.databaseName = "roomserver"
eventStateKeys, err := NewCosmosDBEventStateKeysTable(d)
if err != nil { if err != nil {
return err return err
} }
eventTypes, err := NewSqliteEventTypesTable(db) eventTypes, err := NewCosmosDBEventTypesTable(d)
if err != nil { if err != nil {
return err return err
} }
eventJSON, err := NewSqliteEventJSONTable(db) eventJSON, err := NewCosmosDBEventJSONTable(d)
if err != nil { if err != nil {
return err return err
} }
events, err := NewSqliteEventsTable(db) events, err := NewCosmosDBEventsTable(d)
if err != nil { if err != nil {
return err return err
} }
rooms, err := NewSqliteRoomsTable(db) rooms, err := NewCosmosDBRoomsTable(d)
if err != nil { if err != nil {
return err return err
} }
transactions, err := NewSqliteTransactionsTable(db) transactions, err := NewCosmosDBTransactionsTable(d)
if err != nil { if err != nil {
return err return err
} }
stateBlock, err := NewSqliteStateBlockTable(db) stateBlock, err := NewCosmosDBStateBlockTable(d)
if err != nil { if err != nil {
return err return err
} }
stateSnapshot, err := NewSqliteStateSnapshotTable(db) stateSnapshot, err := NewCosmosDBStateSnapshotTable(d)
if err != nil { if err != nil {
return err return err
} }
prevEvents, err := NewSqlitePrevEventsTable(db) prevEvents, err := NewCosmosDBPrevEventsTable(d)
if err != nil { if err != nil {
return err return err
} }
roomAliases, err := NewSqliteRoomAliasesTable(db) roomAliases, err := NewCosmosDBRoomAliasesTable(d)
if err != nil { if err != nil {
return err return err
} }
invites, err := NewSqliteInvitesTable(db) invites, err := NewCosmosDBInvitesTable(d)
if err != nil { if err != nil {
return err return err
} }
membership, err := NewSqliteMembershipTable(db) membership, err := NewCosmosDBMembershipTable(d)
if err != nil { if err != nil {
return err return err
} }
published, err := NewSqlitePublishedTable(db) published, err := NewCosmosDBPublishedTable(d)
if err != nil { if err != nil {
return err return err
} }
redactions, err := NewSqliteRedactionsTable(db) redactions, err := NewCosmosDBRedactionsTable(d)
if err != nil { if err != nil {
return err return err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: db, DB: nil,
Cache: cache, Cache: cache,
Writer: sqlutil.NewExclusiveWriter(), //Use the Fake SQL Writer here
Writer: cosmosdbutil.NewExclusiveWriterFake(),
EventsTable: events, EventsTable: events,
EventTypesTable: eventTypes, EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys, EventStateKeysTable: eventStateKeys,

View file

@ -18,50 +18,84 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
const transactionsSchema = ` // const transactionsSchema = `
CREATE TABLE IF NOT EXISTS roomserver_transactions ( // CREATE TABLE IF NOT EXISTS roomserver_transactions (
transaction_id TEXT NOT NULL, // transaction_id TEXT NOT NULL,
session_id INTEGER NOT NULL, // session_id INTEGER NOT NULL,
user_id TEXT NOT NULL, // user_id TEXT NOT NULL,
event_id TEXT NOT NULL, // event_id TEXT NOT NULL,
PRIMARY KEY (transaction_id, session_id, user_id) // PRIMARY KEY (transaction_id, session_id, user_id)
); // );
` // `
const insertTransactionSQL = `
INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
VALUES ($1, $2, $3, $4)
`
const selectTransactionEventIDSQL = ` type TransactionCosmos struct {
SELECT event_id FROM roomserver_transactions TransactionID string `json:"transaction_id"`
WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3 SessionID int64 `json:"session_id"`
` UserID string `json:"user_id"`
EventID string `json:"event_id"`
type transactionStatements struct {
db *sql.DB
insertTransactionStmt *sql.Stmt
selectTransactionEventIDStmt *sql.Stmt
} }
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { type TransactionCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Transaction TransactionCosmos `json:"mx_roomserver_transaction"`
}
// const insertTransactionSQL = `
// INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
// VALUES ($1, $2, $3, $4)
// `
// const selectTransactionEventIDSQL = `
// SELECT event_id FROM roomserver_transactions
// WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
// `
type transactionStatements struct {
db *Database
// insertTransactionStmt *sql.Stmt
selectTransactionEventIDStmt *sql.Stmt
tableName string
}
func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*TransactionCosmosData, error) {
response := TransactionCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func NewCosmosDBTransactionsTable(db *Database) (tables.Transactions, error) {
s := &transactionStatements{ s := &transactionStatements{
db: db, db: db,
} }
_, err := db.Exec(transactionsSchema) // return s, shared.StatementList{
if err != nil { // {&s.insertTransactionStmt, insertTransactionSQL},
return nil, err // {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
} // }.Prepare(db)
s.tableName = "transactions"
return s, shared.StatementList{ return s, nil
{&s.insertTransactionStmt, insertTransactionSQL},
{&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
}.Prepare(db)
} }
func (s *transactionStatements) InsertTransaction( func (s *transactionStatements) InsertTransaction(
@ -71,10 +105,39 @@ func (s *transactionStatements) InsertTransaction(
userID string, userID string,
eventID string, eventID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
_, err := stmt.ExecContext( // INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
ctx, transactionID, sessionID, userID, eventID, // VALUES ($1, $2, $3, $4)
) data := TransactionCosmos{
EventID: eventID,
SessionID: sessionID,
TransactionID: transactionID,
UserID: userID,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// PRIMARY KEY (transaction_id, session_id, user_id)
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = TransactionCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Transaction: data,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err return err
} }
@ -84,8 +147,21 @@ func (s *transactionStatements) SelectTransactionEventID(
sessionID int64, sessionID int64,
userID string, userID string,
) (eventID string, err error) { ) (eventID string, err error) {
err = s.selectTransactionEventIDStmt.QueryRowContext(
ctx, transactionID, sessionID, userID, // SELECT event_id FROM roomserver_transactions
).Scan(&eventID) // WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
return
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// PRIMARY KEY (transaction_id, session_id, user_id)
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getTransaction(s, ctx, pk, cosmosDocId)
if err != nil {
return "", err
}
return response.Transaction.EventID, err
} }

View file

@ -71,6 +71,27 @@ func (s *accountDataStatements) prepare(db *Database) (err error) {
return return
} }
func queryAccountData(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []AccountDataCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
@ -92,10 +113,14 @@ func (s *accountDataStatements) insertAccountData(
id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type) id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type)
} }
docId := id
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = AccountDataCosmosData{ var dbData = AccountDataCosmosData{
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, id), Id: cosmosDocId,
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Pk: pk,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
AccountData: result, AccountData: result,
} }
@ -120,24 +145,15 @@ func (s *accountDataStatements) selectAccountData(
) { ) {
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1" // "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountDataCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": localpart, "@x2": localpart,
} }
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil { response, err := queryAccountData(s, ctx, s.selectAccountDataStmt, params)
return nil, nil, ex
if err != nil {
return nil, nil, err
} }
global := map[string]json.RawMessage{} global := map[string]json.RawMessage{}
@ -166,26 +182,17 @@ func (s *accountDataStatements) selectAccountDataByType(
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" // "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountDataCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": localpart, "@x2": localpart,
"@x3": roomID, "@x3": roomID,
"@x4": dataType, "@x4": dataType,
} }
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil { response, err := queryAccountData(s, ctx, s.selectAccountDataByTypeStmt, params)
return nil, ex
if err != nil {
return nil, err
} }
if len(response) == 0 { if len(response) == 0 {

View file

@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -87,17 +88,42 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv
return return
} }
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) { func queryAccount(s *accountsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountCosmosData, error) {
response := AccountCosmosData{} var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( var response []AccountCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) {
response := AccountCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId, docId,
optionsGet,
&response) &response)
return &response, ex
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
} }
func setAccount(s *accountsStatements, ctx context.Context, pk string, account AccountCosmosData) (*AccountCosmosData, error) { func setAccount(s *accountsStatements, ctx context.Context, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
@ -155,10 +181,14 @@ func (s *accountsStatements) insertAccount(
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
docId := result.Localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = AccountCosmosData{ var dbData = AccountCosmosData{
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart), Id: cosmosDocId,
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Pk: pk,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
Account: data, Account: data,
} }
@ -184,10 +214,11 @@ func (s *accountsStatements) updatePassword(
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" // "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) docId := localpart
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getAccount(s, ctx, pk, docId) var response, exGet = getAccount(s, ctx, pk, cosmosDocId)
if exGet != nil { if exGet != nil {
return exGet return exGet
} }
@ -207,10 +238,12 @@ func (s *accountsStatements) deactivateAccount(
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" // "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getAccount(s, ctx, pk, docId) docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getAccount(s, ctx, pk, cosmosDocId)
if exGet != nil { if exGet != nil {
return exGet return exGet
} }
@ -230,24 +263,15 @@ func (s *accountsStatements) selectPasswordHash(
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" // "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": localpart, "@x2": localpart,
} }
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil { response, err := queryAccount(s, ctx, s.selectPasswordHashStmt, params)
return "", ex
if err != nil {
return "", err
} }
if len(response) == 0 { if len(response) == 0 {
@ -268,24 +292,15 @@ func (s *accountsStatements) selectAccountByLocalpart(
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" // "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": localpart, "@x2": localpart,
} }
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil { response, err := queryAccount(s, ctx, s.selectAccountByLocalpartStmt, params)
return nil, ex
if err != nil {
return nil, err
} }
if len(response) == 0 { if len(response) == 0 {

View file

@ -62,6 +62,27 @@ func mapToToken(api api.OpenIDToken) OpenIDTokenCosmos {
} }
} }
func queryOpenIdToken(s *tokenStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OpenIdTokenCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []OpenIdTokenCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
s.selectTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_openidtoken.token = @x2" s.selectTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_openidtoken.token = @x2"
@ -87,10 +108,14 @@ func (s *tokenStatements) insertToken(
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
docId := result.Token
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = OpenIdTokenCosmosData{ var dbData = OpenIdTokenCosmosData{
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Token), Id: cosmosDocId,
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Pk: pk,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
OpenIdToken: mapToToken(*result), OpenIdToken: mapToToken(*result),
} }
@ -120,24 +145,14 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" // "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []OpenIdTokenCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": token, "@x2": token,
} }
var options = cosmosdbapi.GetQueryDocumentsOptions(pk) response, err := queryOpenIdToken(s, ctx, s.selectTokenStmt, params)
var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil { if err != nil {
return nil, ex return nil, err
} }
if len(response) == 0 { if len(response) == 0 {

View file

@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -87,17 +88,42 @@ func (s *profilesStatements) prepare(db *Database) (err error) {
return return
} }
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) { func queryProfile(s *profilesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ProfileCosmosData, error) {
response := ProfileCosmosData{} var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( var response []ProfileCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) {
response := ProfileCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId, docId,
optionsGet,
&response) &response)
return &response, ex
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
} }
func setProfile(s *profilesStatements, ctx context.Context, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) { func setProfile(s *profilesStatements, ctx context.Context, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
@ -123,10 +149,14 @@ func (s *profilesStatements) insertProfile(
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = ProfileCosmosData{ var dbData = ProfileCosmosData{
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart), Id: cosmosDocId,
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Pk: pk,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
Profile: mapToProfile(*result), Profile: mapToProfile(*result),
} }
@ -148,24 +178,15 @@ func (s *profilesStatements) selectProfileByLocalpart(
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ProfileCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": localpart, "@x2": localpart,
} }
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil { response, err := queryProfile(s, ctx, s.selectProfileByLocalpartStmt, params)
return nil, ex
if err != nil {
return nil, err
} }
if len(response) == 0 { if len(response) == 0 {
@ -186,10 +207,11 @@ func (s *profilesStatements) setAvatarURL(
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" // "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) docId := localpart
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getProfile(s, ctx, pk, docId) var response, exGet = getProfile(s, ctx, pk, cosmosDocId)
if exGet != nil { if exGet != nil {
return exGet return exGet
} }
@ -209,9 +231,10 @@ func (s *profilesStatements) setDisplayName(
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" // "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) docId := localpart
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var response, exGet = getProfile(s, ctx, pk, docId) pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getProfile(s, ctx, pk, cosmosDocId)
if exGet != nil { if exGet != nil {
return exGet return exGet
} }
@ -232,25 +255,16 @@ func (s *profilesStatements) selectProfilesBySearch(
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ProfileCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": searchString, "@x2": searchString,
"@x3": limit, "@x3": limit,
} }
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil { response, err := queryProfile(s, ctx, s.selectProfilesBySearchStmt, params)
return nil, ex
if err != nil {
return nil, err
} }
for i := 0; i < len(response); i++ { for i := 0; i < len(response); i++ {

View file

@ -20,6 +20,8 @@ import (
"errors" "errors"
"strconv" "strconv"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/internal/cosmosdbutil"
@ -27,7 +29,6 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -37,6 +38,7 @@ import (
// Database represents an account database // Database represents an account database
type Database struct { type Database struct {
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
writer sqlutil.Writer
accounts accountsStatements accounts accountsStatements
profiles profilesStatements profiles profilesStatements
accountDatas accountDataStatements accountDatas accountDataStatements
@ -62,7 +64,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
connection: conn, connection: conn,
cosmosConfig: config, cosmosConfig: config,
// db: db, // db: db,
// writer: sqlutil.NewExclusiveWriter(), writer: sqlutil.NewExclusiveWriter(),
// bcryptCost: bcryptCost, // bcryptCost: bcryptCost,
// openIDTokenLifetimeMS: openIDTokenLifetimeMS, // openIDTokenLifetimeMS: openIDTokenLifetimeMS,
} }

View file

@ -69,31 +69,43 @@ func (s *threepidStatements) prepare(db *Database) (err error) {
return return
} }
func queryThreePID(s *threepidStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ThreePIDCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []ThreePIDCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *threepidStatements) selectLocalpartForThreePID( func (s *threepidStatements) selectLocalpartForThreePID(
ctx context.Context, threepid string, medium string, ctx context.Context, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, err error) {
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" // "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ThreePIDCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": threepid, "@x2": threepid,
"@x3": medium, "@x3": medium,
} }
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil { response, err := queryThreePID(s, ctx, s.selectLocalpartForThreePIDStmt, params)
return "", ex
if err != nil {
return "", err
} }
if len(response) == 0 { if len(response) == 0 {
@ -109,24 +121,14 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" // "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ThreePIDCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": localpart, "@x2": localpart,
} }
var options = cosmosdbapi.GetQueryDocumentsOptions(pk) response, err := queryThreePID(s, ctx, s.selectThreePIDsForLocalpartStmt, params)
var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
if ex != nil { if err != nil {
return threepids, ex return threepids, err
} }
if len(response) == 0 { if len(response) == 0 {
@ -158,10 +160,11 @@ func (s *threepidStatements) insertThreePID(
docId := fmt.Sprintf("%s_%s", threepid, medium) docId := fmt.Sprintf("%s_%s", threepid, medium)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = ThreePIDCosmosData{ var dbData = ThreePIDCosmosData{
Id: cosmosDocId, Id: cosmosDocId,
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Pk: pk,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
ThreePID: result, ThreePID: result,
} }

View file

@ -16,10 +16,11 @@ package cosmosdb
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -82,15 +83,15 @@ type DeviceCosmosSessionCount struct {
} }
type devicesStatements struct { type devicesStatements struct {
db *Database db *Database
selectDevicesCountStmt string selectDevicesCountStmt string
selectDeviceByTokenStmt string selectDeviceByTokenStmt string
// selectDeviceByIDStmt *sql.Stmt // selectDeviceByIDStmt *sql.Stmt
selectDevicesByIDStmt string selectDevicesByIDStmt string
selectDevicesByLocalpartStmt string selectDevicesByLocalpartStmt string
selectDevicesByLocalpartExceptIDStmt string selectDevicesByLocalpartExceptIDStmt string
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
tableName string tableName string
} }
func mapFromDevice(db DeviceCosmos) api.Device { func mapFromDevice(db DeviceCosmos) api.Device {
@ -121,17 +122,42 @@ func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos {
} }
} }
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) { func queryDevice(s *devicesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceCosmosData, error) {
response := DeviceCosmosData{} var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( var response []DeviceCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
response := DeviceCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId, docId,
optionsGet,
&response) &response)
return &response, ex
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
} }
func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) { func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) {
@ -209,10 +235,11 @@ func (s *devicesStatements) insertDevice(
// HACK: check for duplicate PK as we are using the UNIQUE key for the DocId // HACK: check for duplicate PK as we are using the UNIQUE key for the DocId
docId := fmt.Sprintf("%s_%s", localpart, id) docId := fmt.Sprintf("%s_%s", localpart, id)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = DeviceCosmosData{ var dbData = DeviceCosmosData{
Id: cosmosDocId, Id: cosmosDocId,
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Pk: pk,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
Device: data, Device: data,
} }
@ -260,7 +287,6 @@ func (s *devicesStatements) deleteDevices(
) error { ) error {
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" // "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData var response []DeviceCosmosData
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -268,15 +294,8 @@ func (s *devicesStatements) deleteDevices(
"@x3": devices, "@x3": devices,
} }
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return err return err
} }
@ -291,8 +310,6 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
) error { ) error {
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" // "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
exceptDevices := []string{ exceptDevices := []string{
exceptDeviceID, exceptDeviceID,
} }
@ -302,15 +319,8 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
"@x3": exceptDevices, "@x3": exceptDevices,
} }
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return err return err
} }
@ -325,9 +335,9 @@ func (s *devicesStatements) updateDeviceName(
) error { ) error {
// "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" // "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID) docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId) var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil { if exGet != nil {
return exGet return exGet
@ -347,27 +357,19 @@ func (s *devicesStatements) selectDeviceByToken(
) (*api.Device, error) { ) (*api.Device, error) {
// "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" // "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData var response []DeviceCosmosData
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": accessToken, "@x2": accessToken,
} }
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) response, err := queryDevice(s, ctx, s.selectDeviceByTokenStmt, params)
var query = cosmosdbapi.GetQuery(s.selectDeviceByTokenStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(response) == 0 { if len(response) == 0 {
return nil, errors.New(fmt.Sprintf("No Devices found with AccessToken %s", accessToken)) return nil, cosmosdbutil.ErrNoRows
} }
if err == nil { if err == nil {
@ -384,9 +386,9 @@ func (s *devicesStatements) selectDeviceByID(
) (*api.Device, error) { ) (*api.Device, error) {
// "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" // "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID) docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId) var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil { if exGet != nil {
return nil, exGet return nil, exGet
@ -401,23 +403,14 @@ func (s *devicesStatements) selectDevicesByLocalpart(
devices := []api.Device{} devices := []api.Device{}
// "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" // "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": localpart, "@x2": localpart,
"@x3": exceptDeviceID, "@x3": exceptDeviceID,
} }
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartExceptIDStmt, params)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartExceptIDStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -435,22 +428,14 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
// "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" // "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
var devices []api.Device var devices []api.Device
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData var response []DeviceCosmosData
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
"@x2": deviceIDs, "@x2": deviceIDs,
} }
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) response, err := queryDevice(s, ctx, s.selectDevicesByIDStmt, params)
var query = cosmosdbapi.GetQuery(s.selectDevicesByIDStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -466,9 +451,9 @@ func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart,
// "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" // "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID) docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId) var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil { if exGet != nil {
return exGet return exGet

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -37,7 +36,6 @@ var deviceIDByteLength = 6
// Database represents a device database. // Database represents a device database.
type Database struct { type Database struct {
writer sqlutil.Writer
devices devicesStatements devices devicesStatements
connection cosmosdbapi.CosmosConnection connection cosmosdbapi.CosmosConnection
databaseName string databaseName string