- Implement the SyncAPI to use CosmosDB (#8)

- Update the Config to use Cosmos for the sync API
- Ensure Cosmos DocId does not contain escape chars
- Create a shared Cosmos PartitionOffset table and refactor to use it
- Hardcode the "naffka" connection string to use "file:naffka.db"
- Create seq documents for each of the nextXXXID methods
This commit is contained in:
alexfca 2021-05-27 18:45:53 +10:00 committed by GitHub
parent af4219f38e
commit 3ca96b13b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 3300 additions and 1048 deletions

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
// Import SQLite database driver
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
@ -31,7 +31,8 @@ import (
// Database stores events intended to be later sent to application services
type Database struct {
sqlutil.PartitionOffsetStatements
database cosmosdbutil.Database
cosmosdbutil.PartitionOffsetStatements
events eventsStatements
txnID txnStatements
writer cosmosdbutil.Writer
@ -44,14 +45,23 @@ type Database struct {
// NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
result := &Database{
databaseName: "appservice",
connection: conn,
cosmosConfig: config,
cosmosConfig: configCosmos,
}
result.database = cosmosdbutil.Database{
Connection: conn,
CosmosConfig: configCosmos,
DatabaseName: result.databaseName,
}
var err error
result.writer = cosmosdbutil.NewExclusiveWriterFake()
if err = result.PartitionOffsetStatements.Prepare(&result.database, result.writer, "appservice"); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {
return nil, err
}

View file

@ -6,7 +6,7 @@
#
# At a minimum, to get started, you will need to update the settings in the
# "global" section for your deployment, and you will need to check that the
# database "connection_string" line in each component section is correct.
# database "connection_string" line in each component section is correct.
#
# Each component with a "database" section can accept the following formats
# for "connection_string":
@ -23,13 +23,13 @@
# small number of users and likely will perform worse still with a higher volume
# of users.
#
# The "max_open_conns" and "max_idle_conns" settings configure the maximum
# The "max_open_conns" and "max_idle_conns" settings configure the maximum
# number of open/idle database connections. The value 0 will use the database
# engine default, and a negative value will use unlimited connections. The
# "conn_max_lifetime" option controls the maximum length of time a database
# connection can be idle in seconds - a negative value is unlimited.
# The version of the configuration file.
# The version of the configuration file.
version: 1
# Global Matrix configuration. This configuration applies to all components.
@ -154,13 +154,13 @@ client_api:
# Whether to require reCAPTCHA for registration.
enable_registration_captcha: false
# Settings for ReCAPTCHA.
# Settings for ReCAPTCHA.
recaptcha_public_key: ""
recaptcha_private_key: ""
recaptcha_bypass_secret: ""
recaptcha_siteverify_api: ""
# TURN server information that this homeserver should send to clients.
# TURN server information that this homeserver should send to clients.
turn:
turn_user_lifetime: ""
turn_uris: []
@ -169,7 +169,7 @@ client_api:
turn_password: ""
# Settings for rate-limited endpoints. Rate limiting will kick in after the
# threshold number of "slots" have been taken by requests from a specific
# threshold number of "slots" have been taken by requests from a specific
# host. Each "slot" will be released after the cooloff time in milliseconds.
rate_limiting:
enabled: true
@ -331,7 +331,7 @@ sync_api:
external_api:
listen: http://[::]:8073
database:
connection_string: file:syncapi.db
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -363,9 +363,9 @@ user_api:
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# The length of time that a token issued for a relying party from
# The length of time that a token issued for a relying party from
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
# is considered to be valid in milliseconds.
# is considered to be valid in milliseconds.
# The default lifetime is 3600000ms (60 minutes).
# openid_token_lifetime_ms: 3600000

View file

@ -3,10 +3,23 @@ package cosmosdbapi
import (
"context"
"fmt"
"strings"
)
// removeSpecialChars sanitises a value for use in a Cosmos DB document ID.
// The following characters are restricted and cannot be used in the Id
// property: '/', '\', '?', '#' — each occurrence is replaced with a comma.
func removeSpecialChars(docId string) string {
	sanitiser := strings.NewReplacer(
		"/", ",",
		"\\", ",",
		"?", ",",
		"#", ",",
	)
	return sanitiser.Replace(docId)
}
func GetDocumentId(tenantName string, collectionName string, id string) string {
return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, id)
safeId := removeSpecialChars(id)
return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, safeId)
}
func GetPartitionKey(tenantName string, collectionName string) string {

View file

@ -0,0 +1,227 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cosmosdbutil
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/gomatrixserverlib"
)
// // A PartitionOffset is the offset into a partition of the input log.
// type PartitionOffset struct {
// // The ID of the partition.
// Partition int32
// // The offset into the partition.
// Offset int64
// }
// const partitionOffsetsSchema = `
// -- The offsets that the server has processed up to.
// CREATE TABLE IF NOT EXISTS ${prefix}_partition_offsets (
// -- The name of the topic.
// topic TEXT NOT NULL,
// -- The 32-bit partition ID
// partition INTEGER NOT NULL,
// -- The 64-bit offset.
// partition_offset BIGINT NOT NULL,
// UNIQUE (topic, partition)
// );
// `
// PartitionOffsetCosmos is the stored payload for one (topic, partition)
// pair: the offset into the partition that the server has processed up to.
// It replaces a row of the former ${prefix}_partition_offsets SQL table.
type PartitionOffsetCosmos struct {
Topic string `json:"topic"`
Partition int32 `json:"partition"`
PartitionOffset int64 `json:"partition_offset"`
}
// PartitionOffsetCosmosData is the Cosmos DB document envelope that wraps a
// PartitionOffsetCosmos payload together with the common document metadata.
type PartitionOffsetCosmosData struct {
// Id is the unique document ID (built via cosmosdbapi.GetDocumentId).
Id string `json:"id"`
// Pk is the partition key the document is stored under.
Pk string `json:"_pk"`
// Cn is the logical collection (table) name used for filtering in queries.
Cn string `json:"_cn"`
// ETag is the Cosmos entity tag; not populated locally by this code.
ETag string `json:"_etag"`
// Timestamp is the write time in Unix seconds (set on upsert).
Timestamp int64 `json:"_ts"`
// PartitionOffset is the actual offset payload.
PartitionOffset PartitionOffsetCosmos `json:"mx_partition_offset"`
}
// "SELECT partition, partition_offset FROM ${prefix}_partition_offsets WHERE topic = $1"
const selectPartitionOffsetsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_partition_offset.topic = @x2 "
// const upsertPartitionOffsetsSQL = "" +
// "INSERT INTO ${prefix}_partition_offsets (topic, partition, partition_offset) VALUES ($1, $2, $3)" +
// " ON CONFLICT (topic, partition)" +
// " DO UPDATE SET partition_offset = $3"
// Database bundles the Cosmos DB connection handle with the parsed Cosmos
// configuration and the logical per-component database name. It is shared by
// the statement types in this package (see PartitionOffsetStatements).
type Database struct {
// Connection is the live Cosmos connection handle.
Connection cosmosdbapi.CosmosConnection
// DatabaseName is the logical component database name (e.g. "appservice").
DatabaseName string
// CosmosConfig carries the settings parsed from the connection string.
CosmosConfig cosmosdbapi.CosmosConfig
// ServerName is this homeserver's server name.
ServerName gomatrixserverlib.ServerName
}
// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table.
// It is the Cosmos DB replacement for sqlutil.PartitionOffsetStatements and
// implements the same PartitionStorer surface (PartitionOffsets /
// SetPartitionOffset).
type PartitionOffsetStatements struct {
// db is the shared Cosmos database handle.
db *Database
// writer serialises write operations (see upsertPartitionOffset).
writer Writer
// selectPartitionOffsetsStmt holds raw query text; unlike the SQL version,
// Cosmos queries are not prepared statements.
selectPartitionOffsetsStmt string
// upsertPartitionOffsetStmt *sql.Stmt
// prefix distinguishes components sharing the same schema (set in Prepare).
prefix string
// tableName is always "partition_offsets" (set in Prepare).
tableName string
}
// queryPartitionOffset executes qry with params against the partition-offset
// collection and decodes every matching document.
func queryPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PartitionOffsetCosmosData, error) {
	collection := getCollectionName(*s)
	pk := cosmosdbapi.GetPartitionKey(s.db.CosmosConfig.ContainerName, collection)
	query := cosmosdbapi.GetQuery(qry, params)
	queryOptions := cosmosdbapi.GetQueryDocumentsOptions(pk)
	var docs []PartitionOffsetCosmosData
	if _, err := cosmosdbapi.GetClient(s.db.Connection).QueryDocuments(
		ctx,
		s.db.CosmosConfig.DatabaseName,
		s.db.CosmosConfig.ContainerName,
		query,
		&docs,
		queryOptions,
	); err != nil {
		return nil, err
	}
	return docs, nil
}
// Prepare converts the raw SQL statements into prepared statements.
// Takes a prefix to prepend to the table name used to store the partition offsets.
// This allows multiple components to share the same database schema.
// Note: for Cosmos DB the "statement" is just raw query text; nothing is
// actually prepared, so this never fails and always returns nil.
func (s *PartitionOffsetStatements) Prepare(db *Database, writer Writer, prefix string) (err error) {
s.db = db
s.writer = writer
s.selectPartitionOffsetsStmt = selectPartitionOffsetsSQL
s.prefix = prefix
s.tableName = "partition_offsets"
return
}
// PartitionOffsets implements PartitionStorer
// It returns all stored offsets for the given topic by delegating to
// selectPartitionOffsets.
func (s *PartitionOffsetStatements) PartitionOffsets(
ctx context.Context, topic string,
) ([]sqlutil.PartitionOffset, error) {
return s.selectPartitionOffsets(ctx, topic)
}
// SetPartitionOffset implements PartitionStorer
// It records (upserting) the processed offset for the given topic/partition
// pair by delegating to upsertPartitionOffset.
func (s *PartitionOffsetStatements) SetPartitionOffset(
ctx context.Context, topic string, partition int32, offset int64,
) error {
return s.upsertPartitionOffset(ctx, topic, partition, offset)
}
// selectPartitionOffsets returns all the partition offsets for the given topic.
// Cosmos equivalent of:
// "SELECT partition, partition_offset FROM ${prefix}_partition_offsets WHERE topic = $1"
func (s *PartitionOffsetStatements) selectPartitionOffsets(
	ctx context.Context, topic string,
) (results []sqlutil.PartitionOffset, err error) {
	params := map[string]interface{}{
		"@x1": getCollectionName(*s),
		"@x2": topic,
	}
	rows, err := queryPartitionOffset(s, ctx, s.selectPartitionOffsetsStmt, params)
	if err != nil {
		return nil, err
	}
	for _, row := range rows {
		results = append(results, sqlutil.PartitionOffset{
			Partition: row.PartitionOffset.Partition,
			Offset:    row.PartitionOffset.PartitionOffset,
		})
	}
	return results, nil
}
// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
func checkNamedErr(fn func() error, err *error) {
if e := fn(); e != nil && *err == nil {
*err = e
}
}
// upsertPartitionOffset updates or inserts the stored offset for the given
// topic/partition pair. The document ID encodes the UNIQUE (topic, partition)
// constraint of the SQL table this replaces, so writing the same pair again
// overwrites the previous offset.
func (s *PartitionOffsetStatements) upsertPartitionOffset(
	ctx context.Context, topic string, partition int32, offset int64,
) error {
	return s.writer.Do(nil, nil, func(txn *sql.Tx) error {
		// Cosmos equivalent of:
		// "INSERT INTO ${prefix}_partition_offsets (topic, partition, partition_offset)
		//  VALUES ($1, $2, $3) ON CONFLICT (topic, partition) DO UPDATE SET partition_offset = $3"
		collection := getCollectionName(*s)
		docId := fmt.Sprintf("%s_%d", topic, partition)
		cosmosDocId := cosmosdbapi.GetDocumentId(s.db.CosmosConfig.ContainerName, collection, docId)
		pk := cosmosdbapi.GetPartitionKey(s.db.CosmosConfig.ContainerName, collection)
		dbData := &PartitionOffsetCosmosData{
			Id:        cosmosDocId,
			Cn:        collection,
			Pk:        pk,
			Timestamp: time.Now().Unix(),
			PartitionOffset: PartitionOffsetCosmos{
				Topic:           topic,
				Partition:       partition,
				PartitionOffset: offset,
			},
		}
		options := cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
		_, _, err := cosmosdbapi.GetClient(s.db.Connection).CreateDocument(
			ctx,
			s.db.CosmosConfig.DatabaseName,
			s.db.CosmosConfig.ContainerName,
			&dbData,
			options)
		return err
	})
}
// getCollectionName derives this component's logical collection name by
// joining the component prefix with the shared "partition_offsets" table name.
func getCollectionName(s PartitionOffsetStatements) string {
	prefixedTable := s.prefix + "_" + s.tableName
	return cosmosdbapi.GetCollectionName(s.db.DatabaseName, prefixedTable)
}

View file

@ -53,9 +53,9 @@ type OneTimeKeyCosmos struct {
KeyJSON []byte `json:"key_json"`
}
type OneTimeKeyAlgoCountCosmosData struct {
type OneTimeKeyAlgoNumberCosmosData struct {
Algorithm string `json:"algorithm"`
Count int `json:"count"`
Number int `json:"number"`
}
type OneTimeKeyCosmosData struct {
@ -81,7 +81,7 @@ const selectKeysSQL = "" +
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
const selectKeysCountSQL = "" +
"select c.mx_keyserver_one_time_key.algorithm as algorithm, count(c.mx_keyserver_one_time_key.key_id) as count " +
"select c.mx_keyserver_one_time_key.algorithm, count(c.mx_keyserver_one_time_key.key_id) as number " +
"from c where c._cn = @x1 " +
"and c.mx_keyserver_one_time_key.user_id = @x2 " +
"and c.mx_keyserver_one_time_key.device_id = @x3 " +
@ -110,7 +110,9 @@ type oneTimeKeysStatements struct {
func queryOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyCosmosData, error) {
var response []OneTimeKeyCosmosData
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
@ -127,18 +129,20 @@ func queryOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, qry string,
return response, nil
}
func queryOneTimeKeyAlgoCount(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyAlgoCountCosmosData, error) {
var response []OneTimeKeyAlgoCountCosmosData
var test interface{}
func queryOneTimeKeyAlgoCount(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyAlgoNumberCosmosData, error) {
var response []OneTimeKeyAlgoNumberCosmosData
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
// var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&test,
&response,
optionsQry)
// When there are no Rows we seem to get the generic Bad Req JSON error
@ -252,7 +256,7 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de
var algorithm string
var count int
algorithm = item.Algorithm
count = item.Count
count = item.Number
counts.KeyCount[algorithm] = count
}
return counts, nil
@ -324,7 +328,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
var algorithm string
var count int
algorithm = item.Algorithm
count = item.Count
count = item.Number
counts.KeyCount[algorithm] = count
}

View file

@ -49,6 +49,7 @@ func setupNaffka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) {
if cfg.Database.ConnectionString.IsCosmosDB() {
//TODO: What do we do for Nafka
// cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString)
cfg.Database.ConnectionString = "file:naffka.db"
}
naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString))

View file

@ -18,63 +18,127 @@ package cosmosdb
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
CREATE TABLE IF NOT EXISTS syncapi_account_data_type (
id INTEGER PRIMARY KEY,
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
UNIQUE (user_id, room_id, type)
);
`
// const accountDataSchema = `
// CREATE TABLE IF NOT EXISTS syncapi_account_data_type (
// id INTEGER PRIMARY KEY,
// user_id TEXT NOT NULL,
// room_id TEXT NOT NULL,
// type TEXT NOT NULL,
// UNIQUE (user_id, room_id, type)
// );
// `
const insertAccountDataSQL = "" +
"INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
" SET id = $5"
const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" +
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC"
const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type"
type accountDataStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt
selectAccountDataInRangeStmt *sql.Stmt
// AccountDataTypeCosmos is the stored payload: one (user, room, type) account
// data entry together with its stream position id. It replaces a row of the
// syncapi_account_data_type SQL table.
type AccountDataTypeCosmos struct {
ID int64 `json:"id"`
UserID string `json:"user_id"`
RoomID string `json:"room_id"`
DataType string `json:"type"`
}
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
// AccountDataTypeNumberCosmosData receives single-number aggregate query
// results (e.g. the max(id) projection aliased "as number").
type AccountDataTypeNumberCosmosData struct {
Number int64 `json:"number"`
}
// AccountDataTypeCosmosData is the Cosmos DB document envelope wrapping an
// AccountDataTypeCosmos payload with the common document metadata fields.
type AccountDataTypeCosmosData struct {
// Id is the unique document ID (built via cosmosdbapi.GetDocumentId).
Id string `json:"id"`
// Pk is the partition key the document is stored under.
Pk string `json:"_pk"`
// Cn is the logical collection (table) name used for filtering in queries.
Cn string `json:"_cn"`
// ETag is the Cosmos entity tag; not populated locally by this code.
ETag string `json:"_etag"`
// Timestamp is the write time in Unix seconds (set on insert).
Timestamp int64 `json:"_ts"`
// AccountDataType is the actual account-data payload.
AccountDataType AccountDataTypeCosmos `json:"mx_syncapi_account_data_type"`
}
// const insertAccountDataSQL = "" +
// 	"INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" +
// 	" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
// 	" SET id = $5"

// Cosmos equivalent of:
// "SELECT room_id, type FROM syncapi_account_data_type" +
// " WHERE user_id = $1 AND id > $2 AND id <= $3" +
// " ORDER BY id ASC"
// Note the upper bound is INCLUSIVE (<=), matching the SQL it replaces;
// using "<" here would silently drop the row at the top of the range.
const selectAccountDataInRangeSQL = "" +
	"select * from c where c._cn = @x1 " +
	"and c.mx_syncapi_account_data_type.user_id = @x2 " +
	"and c.mx_syncapi_account_data_type.id > @x3 " +
	"and c.mx_syncapi_account_data_type.id <= @x4 " +
	"order by c.mx_syncapi_account_data_type.id "

// Cosmos equivalent of: "SELECT MAX(id) FROM syncapi_account_data_type"
const selectMaxAccountDataIDSQL = "" +
	"select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 "
// accountDataStatements provides the syncapi account-data "table" on top of
// Cosmos DB documents.
type accountDataStatements struct {
// db is the owning datasource (connection, config, database name).
db *SyncServerDatasource
// streamIDStatements allocates the next account-data stream position.
streamIDStatements *streamIDStatements
// insertAccountDataStmt — NOTE(review): no prepared-statement use is visible
// in the Cosmos path; confirm whether this field can be removed.
insertAccountDataStmt *sql.Stmt
// selectMaxAccountDataIDStmt holds raw Cosmos query text (not prepared).
selectMaxAccountDataIDStmt string
// selectAccountDataInRangeStmt holds raw Cosmos query text (not prepared).
selectAccountDataInRangeStmt string
// tableName is the logical table name ("account_data_types").
tableName string
}
// queryAccountDataType executes qry with params against the account-data
// collection and decodes every matching document.
func queryAccountDataType(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataTypeCosmosData, error) {
	collection := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, collection)
	query := cosmosdbapi.GetQuery(qry, params)
	queryOptions := cosmosdbapi.GetQueryDocumentsOptions(pk)
	var docs []AccountDataTypeCosmosData
	if _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		query,
		&docs,
		queryOptions,
	); err != nil {
		return nil, err
	}
	return docs, nil
}
// queryAccountDataTypeNumber executes an aggregate query (e.g. max(id)) and
// decodes the single-number result rows. Any client failure is surfaced as
// cosmosdbutil.ErrNoRows so callers can treat an empty aggregate like a SQL
// NULL. NOTE(review): this also masks genuine query errors — confirm intended.
func queryAccountDataTypeNumber(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataTypeNumberCosmosData, error) {
	collection := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, collection)
	query := cosmosdbapi.GetQuery(qry, params)
	queryOptions := cosmosdbapi.GetQueryDocumentsOptions(pk)
	var docs []AccountDataTypeNumberCosmosData
	if _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		query,
		&docs,
		queryOptions,
	); err != nil {
		return nil, cosmosdbutil.ErrNoRows
	}
	return docs, nil
}
func NewCosmosDBAccountDataTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
db: db,
streamIDStatements: streamID,
}
_, err := db.Exec(accountDataSchema)
if err != nil {
return nil, err
}
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
return nil, err
}
if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil {
return nil, err
}
if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil {
return nil, err
}
s.selectMaxAccountDataIDStmt = selectMaxAccountDataIDSQL
s.selectAccountDataInRangeStmt = selectAccountDataInRangeSQL
s.tableName = "account_data_types"
return s, nil
}
@ -82,11 +146,46 @@ func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string,
) (pos types.StreamPosition, err error) {
// "INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" +
// " ON CONFLICT (user_id, room_id, type) DO UPDATE" +
// " SET id = $5"
pos, err = s.streamIDStatements.nextAccountDataID(ctx, txn)
if err != nil {
return
}
_, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType, pos)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (user_id, room_id, type)
docId := fmt.Sprintf("%s_%s_%s", userID, roomID, dataType)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := AccountDataTypeCosmos{
ID: int64(pos),
UserID: userID,
RoomID: roomID,
DataType: dataType,
}
dbData := &AccountDataTypeCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
AccountDataType: data,
}
// _, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType, pos)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return
}
@ -98,21 +197,32 @@ func (s *accountDataStatements) SelectAccountDataInRange(
) (data map[string][]string, err error) {
data = make(map[string][]string)
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
// "SELECT room_id, type FROM syncapi_account_data_type" +
// " WHERE user_id = $1 AND id > $2 AND id <= $3" +
// " ORDER BY id ASC"
// rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
"@x3": r.Low(),
"@x4": r.High(),
}
rows, err := queryAccountDataType(s, ctx, s.selectAccountDataInRangeStmt, params)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
var entries int
for rows.Next() {
for _, item := range rows {
var dataType string
var roomID string
if err = rows.Scan(&roomID, &dataType); err != nil {
return
}
roomID = item.AccountDataType.RoomID
dataType = item.AccountDataType.DataType
// check if we should add this by looking at the filter.
// It would be nice if we could do this in SQL-land, but the mix of variadic
@ -147,8 +257,22 @@ func (s *accountDataStatements) SelectAccountDataInRange(
func (s *accountDataStatements) SelectMaxAccountDataID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
// "SELECT MAX(id) FROM syncapi_account_data_type"
var nullableID sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
// err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
rows, err := queryAccountDataTypeNumber(s, ctx, s.selectMaxAccountDataIDStmt, params)
if err != cosmosdbutil.ErrNoRows && len(rows) == 1 {
nullableID.Int64 = rows[0].Number
}
if nullableID.Valid {
id = nullableID.Int64
}

View file

@ -17,109 +17,238 @@ package cosmosdb
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
)
const backwardExtremitiesSchema = `
-- Stores output room events received from the roomserver.
CREATE TABLE IF NOT EXISTS syncapi_backward_extremities (
-- The 'room_id' key for the event.
room_id TEXT NOT NULL,
-- The event ID for the last known event. This is the backwards extremity.
event_id TEXT NOT NULL,
-- The prev_events for the last known event. This is used to update extremities.
prev_event_id TEXT NOT NULL,
PRIMARY KEY(room_id, event_id, prev_event_id)
);
`
// const backwardExtremitiesSchema = `
// -- Stores output room events received from the roomserver.
// CREATE TABLE IF NOT EXISTS syncapi_backward_extremities (
// -- The 'room_id' key for the event.
// room_id TEXT NOT NULL,
// -- The event ID for the last known event. This is the backwards extremity.
// event_id TEXT NOT NULL,
// -- The prev_events for the last known event. This is used to update extremities.
// prev_event_id TEXT NOT NULL,
// PRIMARY KEY(room_id, event_id, prev_event_id)
// );
// `
const insertBackwardExtremitySQL = "" +
"INSERT INTO syncapi_backward_extremities (room_id, event_id, prev_event_id)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING"
const selectBackwardExtremitiesForRoomSQL = "" +
"SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1"
const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct {
db *sql.DB
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
// BackwardExtremityCosmos is the stored payload: the last known event in a
// room (the backwards extremity) together with one of its prev_events, used
// to update extremities during backfill.
type BackwardExtremityCosmos struct {
RoomID string `json:"room_id"`
EventID string `json:"event_id"`
PrevEventID string `json:"prev_event_id"`
}
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
s := &backwardExtremitiesStatements{
db: db,
}
_, err := db.Exec(backwardExtremitiesSchema)
// BackwardExtremityCosmosData is the Cosmos DB document envelope wrapping a
// BackwardExtremityCosmos payload with the common document metadata fields.
type BackwardExtremityCosmosData struct {
// Id is the unique document ID (built via cosmosdbapi.GetDocumentId).
Id string `json:"id"`
// Pk is the partition key the document is stored under.
Pk string `json:"_pk"`
// Cn is the logical collection (table) name used for filtering in queries.
Cn string `json:"_cn"`
// ETag is the Cosmos entity tag; not populated locally by this code.
ETag string `json:"_etag"`
// Timestamp is the write time in Unix seconds (set on insert).
Timestamp int64 `json:"_ts"`
// BackwardExtremity is the actual extremity payload.
BackwardExtremity BackwardExtremityCosmos `json:"mx_syncapi_backward_extremity"`
}
// const insertBackwardExtremitySQL = "" +
// "INSERT INTO syncapi_backward_extremities (room_id, event_id, prev_event_id)" +
// " VALUES ($1, $2, $3)" +
// " ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING"
// BUG FIX: these queries previously filtered on c.mx_syncapi_account_data_type,
// a copy-paste of the account-data queries. Backward-extremity documents are
// stored under the "mx_syncapi_backward_extremity" key (see the json tag on
// BackwardExtremityCosmosData), so the old filters could never match a row.

// Cosmos equivalent of:
// "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1"
const selectBackwardExtremitiesForRoomSQL = "" +
	"select * from c where c._cn = @x1 " +
	"and c.mx_syncapi_backward_extremity.room_id = @x2 "

// Cosmos equivalent of:
// "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
// (selects the documents to delete; the delete itself is issued per document)
const deleteBackwardExtremitySQL = "" +
	"select * from c where c._cn = @x1 " +
	"and c.mx_syncapi_backward_extremity.room_id = @x2 " +
	"and c.mx_syncapi_backward_extremity.prev_event_id = @x3"

// Cosmos equivalent of:
// "DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
// (selects the documents to delete; the delete itself is issued per document)
const deleteBackwardExtremitiesForRoomSQL = "" +
	"select * from c where c._cn = @x1 " +
	"and c.mx_syncapi_backward_extremity.room_id = @x2 "
// backwardExtremitiesStatements provides the syncapi backward-extremities
// "table" on top of Cosmos DB documents.
type backwardExtremitiesStatements struct {
// db is the owning datasource (connection, config, database name).
db *SyncServerDatasource
// insertBackwardExtremityStmt *sql.Stmt
// selectBackwardExtremitiesForRoomStmt holds raw Cosmos query text.
selectBackwardExtremitiesForRoomStmt string
// deleteBackwardExtremityStmt holds raw Cosmos query text (select-then-delete).
deleteBackwardExtremityStmt string
// deleteBackwardExtremitiesForRoomStmt holds raw Cosmos query text.
deleteBackwardExtremitiesForRoomStmt string
// tableName is the logical table name ("backward_extremities").
tableName string
}
func queryBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]BackwardExtremityCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []BackwardExtremityCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil {
return nil, err
return response, nil
}
func deleteBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, dbData BackwardExtremityCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err
}
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
return err
}
// NewCosmosDBBackwardsExtremitiesTable constructs the backwards-extremities
// table backed by Cosmos DB. It never fails: Cosmos queries are raw text, so
// there are no prepared statements to create.
func NewCosmosDBBackwardsExtremitiesTable(db *SyncServerDatasource) (tables.BackwardsExtremities, error) {
	return &backwardExtremitiesStatements{
		db:                                   db,
		selectBackwardExtremitiesForRoomStmt: selectBackwardExtremitiesForRoomSQL,
		deleteBackwardExtremityStmt:          deleteBackwardExtremitySQL,
		deleteBackwardExtremitiesForRoomStmt: deleteBackwardExtremitiesForRoomSQL,
		tableName:                            "backward_extremities",
	}, nil
}
// InsertsBackwardExtremity upserts one backward-extremity document for
// (roomID, eventID, prevEventID). The document id is derived from the former
// SQL primary key, so repeating the call overwrites the same document —
// mirroring the old "ON CONFLICT ... DO NOTHING" semantics.
// txn is unused by the Cosmos implementation and kept only for interface
// compatibility.
//
// Fix: removed a stale early `return` through the deleted SQLite prepared
// statement, which made the Cosmos code below unreachable.
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
	ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) {
	// "INSERT INTO syncapi_backward_extremities (room_id, event_id, prev_event_id)" +
	// " VALUES ($1, $2, $3)" +
	// " ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING"
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	// PRIMARY KEY(room_id, event_id, prev_event_id)
	docId := fmt.Sprintf("%s_%s_%s", roomID, eventID, prevEventID)
	cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
	pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
	data := BackwardExtremityCosmos{
		EventID:     eventID,
		PrevEventID: prevEventID,
		RoomID:      roomID,
	}
	dbData := &BackwardExtremityCosmosData{
		Id:                cosmosDocId,
		Cn:                dbCollectionName,
		Pk:                pk,
		Timestamp:         time.Now().Unix(),
		BackwardExtremity: data,
	}
	var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
	_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		&dbData,
		options)
	return
}
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
ctx context.Context, roomID string,
) (bwExtrems map[string][]string, err error) {
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
// "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1"
// rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
rows, err := queryBackwardExtremity(s, ctx, s.selectBackwardExtremitiesForRoomStmt, params)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed")
bwExtrems = make(map[string][]string)
for rows.Next() {
for _, item := range rows {
var eID string
var prevEventID string
if err = rows.Scan(&eID, &prevEventID); err != nil {
return
}
eID = item.BackwardExtremity.EventID
prevEventID = item.BackwardExtremity.PrevEventID
bwExtrems[eID] = append(bwExtrems[eID], prevEventID)
}
return bwExtrems, rows.Err()
return bwExtrems, err
}
// DeleteBackwardExtremity deletes the backward extremities for roomID whose
// prev_event_id equals knownEventID. txn is unused by the Cosmos
// implementation and kept for interface compatibility.
//
// Fixes: removed the stale early `return` via the deleted SQLite statement,
// and stopped silently overwriting a delete error on later loop iterations.
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
	ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) {
	// "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": roomID,
		"@x3": knownEventID,
	}
	rows, err := queryBackwardExtremity(s, ctx, s.deleteBackwardExtremityStmt, params)
	if err != nil {
		return
	}
	for _, item := range rows {
		if err = deleteBackwardExtremity(s, ctx, item); err != nil {
			return
		}
	}
	return
}
// DeleteBackwardExtremitiesForRoom deletes every backward-extremity document
// stored for roomID. txn is unused by the Cosmos implementation and kept for
// interface compatibility.
//
// Fixes: removed the stale early `return` via the deleted SQLite statement,
// and stopped silently overwriting a delete error on later loop iterations.
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
	ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
	// "DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": roomID,
	}
	rows, err := queryBackwardExtremity(s, ctx, s.deleteBackwardExtremitiesForRoomStmt, params)
	if err != nil {
		return
	}
	for _, item := range rows {
		if err = deleteBackwardExtremity(s, ctx, item); err != nil {
			return
		}
	}
	return
}

View file

@ -20,108 +20,208 @@ import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
// currentRoomStateSchema is the legacy SQLite DDL for the current-room-state
// table. NOTE(review): the Cosmos implementation stores documents and does
// not execute this DDL; a fully commented-out copy also exists below —
// confirm this constant is unused and consider removing it.
const currentRoomStateSchema = `
-- Stores the current room state for every room.
CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
type TEXT NOT NULL,
sender TEXT NOT NULL,
contains_url BOOL NOT NULL DEFAULT false,
state_key TEXT NOT NULL,
headered_event_json TEXT NOT NULL,
membership TEXT,
added_at BIGINT,
UNIQUE (room_id, type, state_key)
);
-- for event deletion
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
-- for querying membership states of users
-- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
-- for querying state by event IDs
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
`
// const currentRoomStateSchema = `
// -- Stores the current room state for every room.
// CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
// room_id TEXT NOT NULL,
// event_id TEXT NOT NULL,
// type TEXT NOT NULL,
// sender TEXT NOT NULL,
// contains_url BOOL NOT NULL DEFAULT false,
// state_key TEXT NOT NULL,
// headered_event_json TEXT NOT NULL,
// membership TEXT,
// added_at BIGINT,
// UNIQUE (room_id, type, state_key)
// );
// -- for event deletion
// CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
// -- for querying membership states of users
// -- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
// -- for querying state by event IDs
// CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
// `
// upsertRoomStateSQL is the legacy SQLite upsert statement.
// NOTE(review): the Cosmos implementation performs the upsert via
// CreateDocument with upsert options; this constant looks like unused
// residue — confirm and consider removing.
const upsertRoomStateSQL = "" +
	"INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" +
	" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
	" ON CONFLICT (room_id, type, state_key)" +
	" DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9"

// CurrentRoomStateCosmos is the JSON payload holding one row of current room
// state, one document per (room_id, type, state_key).
type CurrentRoomStateCosmos struct {
	RoomID            string `json:"room_id"`
	EventID           string `json:"event_id"`
	Type              string `json:"type"`
	Sender            string `json:"sender"`
	ContainsUrl       bool   `json:"contains_url"`
	StateKey          string `json:"state_key"`
	HeaderedEventJSON []byte `json:"headered_event_json"`
	Membership        string `json:"membership"`
	AddedAt           int64  `json:"added_at"`
}

// CurrentRoomStateCosmosData wraps the payload with the standard Cosmos
// document envelope (document id, partition key, collection name, etag,
// server timestamp).
type CurrentRoomStateCosmosData struct {
	Id               string                 `json:"id"`
	Pk               string                 `json:"_pk"`
	Cn               string                 `json:"_cn"`
	ETag             string                 `json:"_etag"`
	Timestamp        int64                  `json:"_ts"`
	CurrentRoomState CurrentRoomStateCosmos `json:"mx_syncapi_current_room_state"`
}
// const upsertRoomStateSQL = "" +
// "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" +
// " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
// " ON CONFLICT (room_id, type, state_key)" +
// " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9"
// Cosmos SQL query strings for the current-room-state table. The original
// SQLite statement each one replaces is kept as a comment above it.
//
// Fix: the migration diff left the old live SQLite string literals
// concatenated into these constants without a joining "+" (syntax errors);
// the stale literals are removed here. selectStateEventSQL was duplicated by
// its commented-out copy and is dropped (SelectStateEvent now reads the
// document directly via getEvent).

// "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
const deleteRoomStateByEventIDSQL = "" +
	"select * from c where c._cn = @x1 " +
	"and c.mx_syncapi_current_room_state.event_id = @x2 "

// TODO: Check the SQL is correct here
// "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
const DeleteRoomStateForRoomSQL = "" +
	"select * from c where c._cn = @x1 " +
	"and c.mx_syncapi_current_room_state.room_id = @x2 "

// "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectRoomIDsWithMembershipSQL = "" +
	"select distinct c.mx_syncapi_current_room_state.room_id from c where c._cn = @x1 " +
	"and c.mx_syncapi_current_room_state.type = \"m.room.member\" " +
	"and c.mx_syncapi_current_room_state.state_key = @x2 " +
	"and c.mx_syncapi_current_room_state.membership = @x3 "

// "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"
// WHEN, ORDER BY and LIMIT will be added by prepareWithFilter
const selectCurrentStateSQL = "" +
	"select top @x3 * from c where c._cn = @x1 " +
	"and c.mx_syncapi_current_room_state.room_id = @x2 "

// "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
const selectJoinedUsersSQL = "" +
	"select * from c where c._cn = @x1 " +
	"and c.mx_syncapi_current_room_state.type = \"m.room.member\" " +
	"and c.mx_syncapi_current_room_state.membership = \"join\" "

// const selectStateEventSQL = "" +
// "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"

// "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
// " FROM syncapi_current_room_state WHERE event_id IN ($1)"
// TODO: The session_id and transaction_id blanks are here because otherwise
// the rowsToStreamEvents expects there to be exactly six columns. We need to
// figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020
const selectEventsWithEventIDsSQL = "" +
	"select * from c where c._cn = @x1 " +
	"and ARRAY_CONTAINS(@x2, c.mx_syncapi_current_room_state.event_id) "
type currentRoomStateStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
DeleteRoomStateForRoomStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt
selectStateEventStmt *sql.Stmt
db *SyncServerDatasource
streamIDStatements *streamIDStatements
// upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt string
DeleteRoomStateForRoomStmt string
selectRoomIDsWithMembershipStmt string
selectJoinedUsersStmt string
// selectStateEventStmt *sql.Stmt
tableName string
jsonPropertyName string
}
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
func queryCurrentRoomState(s *currentRoomStateStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CurrentRoomStateCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []CurrentRoomStateCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
// queryCurrentRoomStateDistinct runs a projecting Cosmos SQL query (e.g. a
// SELECT DISTINCT on one field) and decodes the results straight into the
// payload type rather than the full document envelope.
func queryCurrentRoomStateDistinct(s *currentRoomStateStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CurrentRoomStateCosmos, error) {
	collection := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	partitionKey := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, collection)
	queryOptions := cosmosdbapi.GetQueryDocumentsOptions(partitionKey)
	query := cosmosdbapi.GetQuery(qry, params)
	var docs []CurrentRoomStateCosmos
	if _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		query,
		&docs,
		queryOptions); err != nil {
		return nil, err
	}
	return docs, nil
}
// getEvent fetches one current-room-state document by partition key and
// document id. It returns cosmosdbutil.ErrNoRows when the document does not
// exist, mirroring sql.ErrNoRows for callers.
//
// Fix: the lookup error is now checked before the not-found check;
// previously a real transport/API error was masked and reported as
// ErrNoRows whenever the response stayed empty.
func getEvent(s *currentRoomStateStatements, ctx context.Context, pk string, docId string) (*CurrentRoomStateCosmosData, error) {
	response := CurrentRoomStateCosmosData{}
	err := cosmosdbapi.GetDocumentOrNil(
		s.db.connection,
		s.db.cosmosConfig,
		ctx,
		pk,
		docId,
		&response)
	if err != nil {
		return nil, err
	}
	if response.Id == "" {
		return nil, cosmosdbutil.ErrNoRows
	}
	return &response, nil
}
// deleteCurrentRoomState deletes a single current-room-state document by its
// document id within the document's partition.
func deleteCurrentRoomState(s *currentRoomStateStatements, ctx context.Context, dbData CurrentRoomStateCosmosData) error {
	deleteOptions := cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
	_, err := cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		dbData.Id,
		deleteOptions)
	return err
}
// NewCosmosDBCurrentRoomStateTable constructs the CosmosDB-backed
// implementation of tables.CurrentRoomState.
//
// Fix: removed the stale SQLite residue (db.Exec of the schema and the
// db.Prepare calls) the diff left interleaved here; those calls referenced
// *sql.Stmt fields that no longer exist on currentRoomStateStatements and
// methods SyncServerDatasource does not provide.
func NewCosmosDBCurrentRoomStateTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
	s := &currentRoomStateStatements{
		db:                 db,
		streamIDStatements: streamID,
	}
	s.deleteRoomStateByEventIDStmt = deleteRoomStateByEventIDSQL
	s.DeleteRoomStateForRoomStmt = DeleteRoomStateForRoomSQL
	s.selectRoomIDsWithMembershipStmt = selectRoomIDsWithMembershipSQL
	s.selectJoinedUsersStmt = selectJoinedUsersSQL
	s.tableName = "current_room_states"
	s.jsonPropertyName = "mx_syncapi_current_room_state"
	return s, nil
}
@ -129,19 +229,27 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t
func (s *currentRoomStateStatements) SelectJoinedUsers(
ctx context.Context,
) (map[string][]string, error) {
rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
// "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
// rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
rows, err := queryCurrentRoomState(s, ctx, s.selectJoinedUsersStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed")
result := make(map[string][]string)
for rows.Next() {
for _, item := range rows {
var roomID string
var userID string
if err := rows.Scan(&roomID, &userID); err != nil {
return nil, err
}
roomID = item.CurrentRoomState.RoomID
userID = item.CurrentRoomState.StateKey //StateKey and Not UserID - See the SQL above
users := result[roomID]
users = append(users, userID)
result[roomID] = users
@ -156,19 +264,28 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
userID string,
membership string, // nolint: unparam
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, userID, membership)
// "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
// stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt)
// rows, err := stmt.QueryContext(ctx, userID, membership)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
"@x3": membership,
}
rows, err := queryCurrentRoomStateDistinct(s, ctx, s.selectRoomIDsWithMembershipStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithMembership: rows.close() failed")
var result []string
for rows.Next() {
for _, item := range rows {
var roomID string
if err := rows.Scan(&roomID); err != nil {
return nil, err
}
roomID = item.RoomID
result = append(result, roomID)
}
return result, nil
@ -180,41 +297,74 @@ func (s *currentRoomStateStatements) SelectCurrentState(
stateFilter *gomatrixserverlib.StateFilter,
excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt, params, err := prepareWithFilters(
s.db, txn, selectCurrentStateSQL,
[]interface{}{
roomID,
},
// "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"
// // WHEN, ORDER BY and LIMIT will be added by prepareWithFilter
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
"@x3": stateFilter.Limit,
}
stmt, params := prepareWithFilters(
s.jsonPropertyName, selectCurrentStateSQL, params,
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
excludeEventIDs, stateFilter.Limit, FilterOrderNone,
)
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := queryCurrentRoomState(s, ctx, stmt, params)
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCurrentState: rows.close() failed")
return rowsToEvents(rows)
return rowsToEvents(&rows)
}
// DeleteRoomStateByEventID deletes the current-room-state document(s) whose
// event_id matches eventID. txn is unused by the Cosmos implementation and
// kept for interface compatibility.
//
// Fixes: removed the stale SQLite stmt residue, checked the query error
// before iterating, and stopped overwriting a delete error on later
// iterations.
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
	ctx context.Context, txn *sql.Tx, eventID string,
) error {
	// "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": eventID,
	}
	rows, err := queryCurrentRoomState(s, ctx, s.deleteRoomStateByEventIDStmt, params)
	if err != nil {
		return err
	}
	for _, item := range rows {
		if err = deleteCurrentRoomState(s, ctx, item); err != nil {
			return err
		}
	}
	return err
}
// DeleteRoomStateForRoom deletes every current-room-state document for
// roomID. txn is unused by the Cosmos implementation and kept for interface
// compatibility.
//
// Fixes: removed the stale SQLite stmt residue, checked the query error
// before iterating, and stopped overwriting a delete error on later
// iterations.
func (s *currentRoomStateStatements) DeleteRoomStateForRoom(
	ctx context.Context, txn *sql.Tx, roomID string,
) error {
	// TODO: Check the SQL is correct here
	// "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": roomID,
	}
	rows, err := queryCurrentRoomState(s, ctx, s.DeleteRoomStateForRoomStmt, params)
	if err != nil {
		return err
	}
	for _, item := range rows {
		if err = deleteCurrentRoomState(s, ctx, item); err != nil {
			return err
		}
	}
	return err
}
@ -235,20 +385,73 @@ func (s *currentRoomStateStatements) UpsertRoomState(
return err
}
// "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" +
// " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
// " ON CONFLICT (room_id, type, state_key)" +
// " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9"
// TODO: Not sure how we can enfore these extra unique indexes
// CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
// -- for querying membership states of users
// -- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
// -- for querying state by event IDs
// CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
// upsert state event
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
_, err = stmt.ExecContext(
// stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
// _, err = stmt.ExecContext(
// ctx,
// event.RoomID(),
// event.EventID(),
// event.Type(),
// event.Sender(),
// containsURL,
// *event.StateKey(),
// headeredJSON,
// membership,
// addedAt,
// )
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// " ON CONFLICT (room_id, type, state_key)" +
docId := fmt.Sprintf("%s_%s_%s", event.RoomID(), event.Type(), *event.StateKey())
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
membershipData := ""
if membership != nil {
membershipData = *membership
}
data := CurrentRoomStateCosmos{
RoomID: event.RoomID(),
EventID: event.EventID(),
Type: event.Type(),
Sender: event.Sender(),
ContainsUrl: containsURL,
StateKey: *event.StateKey(),
HeaderedEventJSON: headeredJSON,
Membership: membershipData,
AddedAt: int64(addedAt),
}
dbData := &CurrentRoomStateCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
CurrentRoomState: data,
}
// _, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType, pos)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
event.RoomID(),
event.EventID(),
event.Type(),
event.Sender(),
containsURL,
*event.StateKey(),
headeredJSON,
membership,
addedAt,
)
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
@ -262,22 +465,33 @@ func minOfInts(a, b int) int {
func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StreamEvent, error) {
iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs {
iEventIDs[k] = v
}
// iEventIDs := make([]interface{}, len(eventIDs))
// for k, v := range eventIDs {
// iEventIDs[k] = v
// }
res := make([]types.StreamEvent, 0, len(eventIDs))
var start int
for start < len(eventIDs) {
n := minOfInts(len(eventIDs)-start, 999)
query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1)
rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
// "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
// " FROM syncapi_current_room_state WHERE event_id IN ($1)"
// query := strings.Replace(selectEventsWithEventIDsSQL, "@x2", sql.QueryVariadic(n), 1)
// rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventIDs,
}
rows, err := queryCurrentRoomState(s, ctx, s.DeleteRoomStateForRoomStmt, params)
if err != nil {
return nil, err
}
start = start + n
events, err := rowsToStreamEvents(rows)
internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed")
events, err := rowsToStreamEventsFromCurrentRoomState(&rows)
if err != nil {
return nil, err
}
@ -286,14 +500,58 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
return res, nil
}
func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
result := []*gomatrixserverlib.HeaderedEvent{}
for rows.Next() {
var eventID string
var eventBytes []byte
if err := rows.Scan(&eventID, &eventBytes); err != nil {
// Copied from output_room_events_table
func rowsToStreamEventsFromCurrentRoomState(rows *[]CurrentRoomStateCosmosData) ([]types.StreamEvent, error) {
var result []types.StreamEvent
for _, item := range *rows {
var (
eventID string
streamPos types.StreamPosition
eventBytes []byte
excludeFromSync bool
// Not required for this call, see output_room_events_table
// sessionID *int64
// txnID *string
// transactionID *api.TransactionID
)
// if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
// return nil, err
// }
// Taken from the SQL above
eventID = item.CurrentRoomState.EventID
streamPos = types.StreamPosition(item.CurrentRoomState.AddedAt)
// TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent
if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err
}
// Always null for this use-case
// if sessionID != nil && txnID != nil {
// transactionID = &api.TransactionID{
// SessionID: *sessionID,
// TransactionID: *txnID,
// }
// }
result = append(result, types.StreamEvent{
HeaderedEvent: &ev,
StreamPosition: streamPos,
TransactionID: nil,
ExcludeFromSync: excludeFromSync,
})
}
return result, nil
}
func rowsToEvents(rows *[]CurrentRoomStateCosmosData) ([]*gomatrixserverlib.HeaderedEvent, error) {
result := []*gomatrixserverlib.HeaderedEvent{}
for _, item := range *rows {
var eventID string
var eventBytes []byte
eventID = item.CurrentRoomState.EventID
eventBytes = item.CurrentRoomState.HeaderedEventJSON
// TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent
if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
@ -307,15 +565,25 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
func (s *currentRoomStateStatements) SelectStateEvent(
ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) {
stmt := s.selectStateEventStmt
// stmt := s.selectStateEventStmt
var res []byte
err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
if err == sql.ErrNoRows {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// " ON CONFLICT (room_id, type, state_key)" +
docId := fmt.Sprintf("%s_%s_%s", roomID, evType, stateKey)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var response, err = getEvent(s, ctx, pk, cosmosDocId)
// err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
if err == cosmosdbutil.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
res = response.CurrentRoomState.HeaderedEventJSON
var ev gomatrixserverlib.HeaderedEvent
if err = json.Unmarshal(res, &ev); err != nil {
return nil, err

View file

@ -0,0 +1,59 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
// LoadFromGoose registers this package's migrations with goose. Registration
// order matters: FixSequences is registered before
// RemoveSendToDeviceSentColumn.
func LoadFromGoose() {
	goose.AddMigration(UpFixSequences, DownFixSequences)
	goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
// LoadFixSequences registers the FixSequences up/down migration pair with
// the internal migration runner.
func LoadFixSequences(m *sqlutil.Migrations) {
	m.AddMigration(UpFixSequences, DownFixSequences)
}
// UpFixSequences resets the receipt stream sequence to 1. All existing
// receipts are deleted first: their stream IDs came from a different
// sequence, so reusing them would cause primary key violations.
func UpFixSequences(tx *sql.Tx) error {
	const stmt = `
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
UPDATE syncapi_stream_id SET stream_id=1 WHERE stream_name="receipt";
`
	if _, err := tx.Exec(stmt); err != nil {
		return fmt.Errorf("failed to execute upgrade: %w", err)
	}
	return nil
}
// DownFixSequences reverses UpFixSequences. The receipts are deleted again
// for the same reason: stream IDs from the other sequence cannot be reused
// without primary key violations.
func DownFixSequences(tx *sql.Tx) error {
	const stmt = `
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
DELETE FROM syncapi_receipts;
`
	if _, err := tx.Exec(stmt); err != nil {
		return fmt.Errorf("failed to execute downgrade: %w", err)
	}
	return nil
}

View file

@ -0,0 +1,67 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
// LoadRemoveSendToDeviceSentColumn registers the RemoveSendToDeviceSentColumn
// up/down migration pair with the internal migration runner.
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
	m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
// UpRemoveSendToDeviceSentColumn drops the sent_by_token column from
// syncapi_send_to_device. SQLite cannot drop a column directly, so the table
// is rebuilt: copy rows to a temporary backup, recreate the table without
// the column, copy the rows back, then drop the backup.
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
	const stmt = `
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device;
CREATE TABLE syncapi_send_to_device(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL
);
INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup;
DROP TABLE syncapi_send_to_device_backup;
`
	if _, err := tx.Exec(stmt); err != nil {
		return fmt.Errorf("failed to execute upgrade: %w", err)
	}
	return nil
}
// DownRemoveSendToDeviceSentColumn restores the sent_by_token column on
// syncapi_send_to_device by rebuilding the table through a temporary backup
// copy (SQLite cannot add the column back in place with the same layout).
//
// Fix: the error message previously said "upgrade" on this downgrade path;
// corrected to match DownFixSequences.
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
	_, err := tx.Exec(`
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device;
CREATE TABLE syncapi_send_to_device(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL,
sent_by_token TEXT
);
INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup;
DROP TABLE syncapi_send_to_device_backup;
`)
	if err != nil {
		return fmt.Errorf("failed to execute downgrade: %w", err)
	}
	return nil
}

View file

@ -16,80 +16,147 @@ package cosmosdb
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
const filterSchema = `
-- Stores data about filters
CREATE TABLE IF NOT EXISTS syncapi_filter (
-- The filter
filter TEXT NOT NULL,
-- The ID
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The localpart of the Matrix user ID associated to this filter
localpart TEXT NOT NULL,
// const filterSchema = `
// -- Stores data about filters
// CREATE TABLE IF NOT EXISTS syncapi_filter (
// -- The filter
// filter TEXT NOT NULL,
// -- The ID
// id INTEGER PRIMARY KEY AUTOINCREMENT,
// -- The localpart of the Matrix user ID associated to this filter
// localpart TEXT NOT NULL,
UNIQUE (id, localpart)
);
// UNIQUE (id, localpart)
// );
CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart);
`
// CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart);
// `
const selectFilterSQL = "" +
"SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2"
const selectFilterIDByContentSQL = "" +
"SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2"
const insertFilterSQL = "" +
"INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
type filterStatements struct {
db *sql.DB
selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
type FilterCosmos struct {
ID int64 `json:"id"`
Filter []byte `json:"filter"`
Localpart string `json:"localpart"`
}
func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
_, err := db.Exec(filterSchema)
type FilterCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Filter FilterCosmos `json:"mx_syncapi_filter"`
}
// const selectFilterSQL = "" +
// "SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2"
// "SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2"
const selectFilterIDByContentSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_filter.localpart = @x2 " +
"and c.mx_syncapi_filter.filter = @x3 "
// const insertFilterSQL = "" +
// "INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
type filterStatements struct {
db *SyncServerDatasource
// selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt string
// insertFilterStmt *sql.Stmt
tableName string
}
func queryFilter(s *filterStatements, ctx context.Context, qry string, params map[string]interface{}) ([]FilterCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []FilterCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
if len(response) == 0 {
return nil, cosmosdbutil.ErrNoRows
}
return response, nil
}
func getFilter(s *filterStatements, ctx context.Context, pk string, docId string) (*FilterCosmosData, error) {
response := FilterCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, nil
}
return &response, err
}
func NewCosmosDBFilterTable(db *SyncServerDatasource) (tables.Filter, error) {
s := &filterStatements{
db: db,
}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return nil, err
}
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return nil, err
}
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
return nil, err
}
s.selectFilterIDByContentStmt = selectFilterIDByContentSQL
s.tableName = "filters"
return s, nil
}
func (s *filterStatements) SelectFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
// "SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2"
// Retrieve filter from database (stored as canonical JSON)
var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
// err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (id, localpart)
docId := fmt.Sprintf("%s_%s", localpart, filterID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, err = getFilter(s, ctx, pk, cosmosDocId)
if err != nil {
return nil, err
}
// Unmarshal JSON into Filter struct
filter := gomatrixserverlib.DefaultFilter()
if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err
if response != nil {
filterData = response.Filter.Filter
if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err
}
}
return &filter, nil
}
@ -97,6 +164,9 @@ func (s *filterStatements) SelectFilter(
func (s *filterStatements) InsertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) {
// "INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
var existingFilterID string
// Serialise json
@ -116,25 +186,73 @@ func (s *filterStatements) InsertFilter(
// This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time, however this is not a
// problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows {
// err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
// localpart, filterJSON).Scan(&existingFilterID)
// TODO: See if we can avoid the search by Content []byte
// "SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": filterJSON,
}
response, err := queryFilter(s, ctx, s.selectFilterIDByContentStmt, params)
if err != nil && err != cosmosdbutil.ErrNoRows {
return "", err
}
if response != nil {
existingFilterID = fmt.Sprintf("%d", response[0].Filter.ID)
}
// If it does, return the existing ID
if existingFilterID != "" {
return existingFilterID, nil
}
// Otherwise insert the filter and return the new ID
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
if err != nil {
return "", err
}
rowid, err := res.LastInsertId()
// res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
// id INTEGER PRIMARY KEY AUTOINCREMENT,
seqID, seqErr := GetNextFilterID(s, ctx)
if seqErr != nil {
return "", seqErr
}
data := FilterCosmos{
ID: seqID,
Localpart: localpart,
Filter: filterJSON,
}
// UNIQUE (id, localpart)
docId := fmt.Sprintf("%s_%d", localpart, seqID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var dbData = FilterCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Filter: data,
}
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
optionsCreate)
if err != nil {
return "", err
}
rowid := seqID
filterID = fmt.Sprintf("%d", rowid)
return
}

View file

@ -0,0 +1,12 @@
package cosmosdb
import (
"context"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
)
func GetNextFilterID(s *filterStatements, ctx context.Context) (int64, error) {
const docId = "id_seq"
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
}

View file

@ -1,10 +1,7 @@
package cosmosdb
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
type FilterOrder int
@ -15,6 +12,10 @@ const (
FilterOrderDesc
)
func getParamName(offset int) string {
return fmt.Sprintf("@x%d", offset)
}
// prepareWithFilters returns a prepared statement with the
// relevant filters included. It also includes an []interface{}
// list of all the relevant parameters to pass straight to
@ -24,59 +25,54 @@ const (
// and it's easier just to have the caller extract the relevant
// parts.
func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{},
collectionName string, query string, params map[string]interface{},
senders, notsenders, types, nottypes []string, excludeEventIDs []string,
limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) {
) (sql string, paramsResult map[string]interface{}) {
offset := len(params)
if count := len(senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders {
params, offset = append(params, v), offset+1
}
sql = query
paramsResult = params
// "and (@x4 = null OR ARRAY_CONTAINS(@x4, c.mx_syncapi_current_room_state.sender)) " +
if len(senders) > 0 {
offset++
paramName := getParamName(offset)
sql += fmt.Sprintf("and ARRAY_CONTAINS(%s, c.%s.sender) ", paramName, collectionName)
paramsResult[paramName] = senders
}
if count := len(notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders {
params, offset = append(params, v), offset+1
}
// "and (@x5 = null OR NOT ARRAY_CONTAINS(@x5, c.mx_syncapi_current_room_state.sender)) " +
if len(notsenders) > 0 {
offset++
paramName := getParamName(offset)
sql += fmt.Sprintf("and NOT ARRAY_CONTAINS(%s, c.%s.sender) ", paramName, collectionName)
paramsResult[getParamName(offset)] = notsenders
}
if count := len(types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types {
params, offset = append(params, v), offset+1
}
// "and (@x6 = null OR ARRAY_CONTAINS(@x6, c.mx_syncapi_current_room_state.type)) " +
if len(types) > 0 {
offset++
paramName := getParamName(offset)
sql += fmt.Sprintf("and ARRAY_CONTAINS(%s, c.%s.type) ", paramName, collectionName)
paramsResult[paramName] = types
}
if count := len(nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes {
params, offset = append(params, v), offset+1
}
// "and (@x7 = null OR NOT ARRAY_CONTAINS(@x7, c.mx_syncapi_current_room_state.type)) " +
if len(nottypes) > 0 {
offset++
paramName := getParamName(offset)
sql += fmt.Sprintf("and NOT ARRAY_CONTAINS(%s, c.%s.type) ", paramName, collectionName)
paramsResult[getParamName(offset)] = nottypes
}
if count := len(excludeEventIDs); count > 0 {
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range excludeEventIDs {
params, offset = append(params, v), offset+1
}
// "and (NOT ARRAY_CONTAINS(@x9, c.mx_syncapi_current_room_state.event_id)) "
if len(excludeEventIDs) > 0 {
offset++
paramName := getParamName(offset)
sql += fmt.Sprintf("and NOT ARRAY_CONTAINS(%s, c.%s.event_id) ", paramName, collectionName)
paramsResult[getParamName(offset)] = excludeEventIDs
}
switch order {
case FilterOrderAsc:
query += " ORDER BY id ASC"
sql += fmt.Sprintf("order by c.%s.event_id asc ", collectionName)
case FilterOrderDesc:
query += " ORDER BY id DESC"
sql += fmt.Sprintf("order by c.%s.event_id desc ", collectionName)
}
query += fmt.Sprintf(" LIMIT $%d", offset+1)
params = append(params, limit)
var stmt *sql.Stmt
var err error
if txn != nil {
stmt, err = txn.Prepare(query)
} else {
stmt, err = db.Prepare(query)
}
if err != nil {
return nil, nil, fmt.Errorf("s.db.Prepare: %w", err)
}
return stmt, params, nil
// query += fmt.Sprintf(" LIMIT $%d", offset+1)
return
}

View file

@ -19,80 +19,179 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
const inviteEventsSchema = `
CREATE TABLE IF NOT EXISTS syncapi_invite_events (
id INTEGER PRIMARY KEY,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
headered_event_json TEXT NOT NULL,
deleted BOOL NOT NULL
);
// const inviteEventsSchema = `
// CREATE TABLE IF NOT EXISTS syncapi_invite_events (
// id INTEGER PRIMARY KEY,
// event_id TEXT NOT NULL,
// room_id TEXT NOT NULL,
// target_user_id TEXT NOT NULL,
// headered_event_json TEXT NOT NULL,
// deleted BOOL NOT NULL
// );
CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id);
CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id);
`
// CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id);
// CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id);
// `
const insertInviteEventSQL = "" +
"INSERT INTO syncapi_invite_events" +
" (id, room_id, event_id, target_user_id, headered_event_json, deleted)" +
" VALUES ($1, $2, $3, $4, $5, false)"
const deleteInviteEventSQL = "" +
"UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2"
const selectInviteEventsInRangeSQL = "" +
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC"
const selectMaxInviteIDSQL = "" +
"SELECT MAX(id) FROM syncapi_invite_events"
type inviteEventsStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
deleteInviteEventStmt *sql.Stmt
selectMaxInviteIDStmt *sql.Stmt
type InviteEventCosmos struct {
ID int64 `json:"id"`
EventID string `json:"event_id"`
RoomID string `json:"room_id"`
TargetUserID string `json:"target_user_id"`
HeaderedEventJSON []byte `json:"headered_event_json"`
Deleted bool `json:"deleted"`
}
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
type InviteEventCosmosMaxNumber struct {
Max int64 `json:"number"`
}
type InviteEventCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
InviteEvent InviteEventCosmos `json:"mx_syncapi_invite_event"`
}
// const insertInviteEventSQL = "" +
// "INSERT INTO syncapi_invite_events" +
// " (id, room_id, event_id, target_user_id, headered_event_json, deleted)" +
// " VALUES ($1, $2, $3, $4, $5, false)"
// "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2"
const deleteInviteEventSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_invite_event.event_id = @x2 "
// "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
// " WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
// " ORDER BY id DESC"
const selectInviteEventsInRangeSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_invite_event.target_user_id = @x2 " +
"and c.mx_syncapi_invite_event.id > @x3 " +
"and c.mx_syncapi_invite_event.id <= @x4 " +
"order by c.mx_syncapi_invite_event.id desc "
// "SELECT MAX(id) FROM syncapi_invite_events"
const selectMaxInviteIDSQL = "" +
"select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 "
type inviteEventsStatements struct {
db *SyncServerDatasource
streamIDStatements *streamIDStatements
// insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt string
deleteInviteEventStmt string
selectMaxInviteIDStmt string
tableName string
}
func queryInviteEvent(s *inviteEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteEventCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []InviteEventCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryInviteEventMaxNumber(s *inviteEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteEventCosmosMaxNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []InviteEventCosmosMaxNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, nil
}
return response, nil
}
func getInviteEvent(s *inviteEventsStatements, ctx context.Context, pk string, docId string) (*InviteEventCosmosData, error) {
response := InviteEventCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
}
return &response, err
}
func setInviteEvent(s *inviteEventsStatements, ctx context.Context, invite InviteEventCosmosData) (*InviteEventCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
invite.Id,
&invite,
optionsReplace)
return &invite, ex
}
func NewCosmosDBInvitesTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{
db: db,
streamIDStatements: streamID,
}
_, err := db.Exec(inviteEventsSchema)
if err != nil {
return nil, err
}
if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil {
return nil, err
}
if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil {
return nil, err
}
if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil {
return nil, err
}
if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil {
return nil, err
}
s.selectInviteEventsInRangeStmt = selectInviteEventsInRangeSQL
s.deleteInviteEventStmt = deleteInviteEventSQL
s.selectMaxInviteIDStmt = selectMaxInviteIDSQL
s.tableName = "invite_events"
return s, nil
}
func (s *inviteEventsStatements) InsertInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent,
) (streamPos types.StreamPosition, err error) {
// "INSERT INTO syncapi_invite_events" +
// " (id, room_id, event_id, target_user_id, headered_event_json, deleted)" +
// " VALUES ($1, $2, $3, $4, $5, false)"
streamPos, err = s.streamIDStatements.nextInviteID(ctx, txn)
if err != nil {
return
@ -104,15 +203,45 @@ func (s *inviteEventsStatements) InsertInviteEvent(
return
}
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
_, err = stmt.ExecContext(
// stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
// _, err = stmt.ExecContext(
// ctx,
// streamPos,
// inviteEvent.RoomID(),
// inviteEvent.EventID(),
// *inviteEvent.StateKey(),
// headeredJSON,
// )
data := InviteEventCosmos{
ID: int64(streamPos),
RoomID: inviteEvent.RoomID(),
EventID: inviteEvent.EventID(),
TargetUserID: *inviteEvent.StateKey(),
HeaderedEventJSON: headeredJSON,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// id INTEGER PRIMARY KEY,
docId := fmt.Sprintf("%d", streamPos)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = InviteEventCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
InviteEvent: data,
}
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
streamPos,
inviteEvent.RoomID(),
inviteEvent.EventID(),
*inviteEvent.StateKey(),
headeredJSON,
)
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
optionsCreate)
return
}
@ -123,8 +252,23 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
if err != nil {
return streamPos, err
}
stmt := sqlutil.TxStmt(txn, s.deleteInviteEventStmt)
_, err = stmt.ExecContext(ctx, streamPos, inviteEventID)
// "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2"
// stmt := sqlutil.TxStmt(txn, s.deleteInviteEventStmt)
// _, err = stmt.ExecContext(ctx, streamPos, inviteEventID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": inviteEventID,
}
response, err := queryInviteEvent(s, ctx, s.deleteInviteEventStmt, params)
for _, item := range response {
item.InviteEvent.Deleted = true
item.InviteEvent.ID = int64(streamPos)
setInviteEvent(s, ctx, item)
}
return streamPos, err
}
@ -133,23 +277,39 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
func (s *inviteEventsStatements) SelectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
// "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
// " WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
// " ORDER BY id DESC"
// stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
// rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": targetUserID,
"@x3": r.Low(),
"@x4": r.High(),
}
rows, err := queryInviteEvent(s, ctx, s.selectInviteEventsInRangeStmt, params)
if err != nil {
return nil, nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
result := map[string]*gomatrixserverlib.HeaderedEvent{}
retired := map[string]*gomatrixserverlib.HeaderedEvent{}
for rows.Next() {
for _, item := range rows {
var (
roomID string
eventJSON []byte
deleted bool
)
if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
return nil, nil, err
}
roomID = item.InviteEvent.RoomID
eventJSON = item.InviteEvent.HeaderedEventJSON
deleted = item.InviteEvent.Deleted
// if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
// return nil, nil, err
// }
// if we have seen this room before, it has a higher stream position and hence takes priority
// because the query is ORDER BY id DESC so drop them
@ -176,8 +336,21 @@ func (s *inviteEventsStatements) SelectMaxInviteID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
// "SELECT MAX(id) FROM syncapi_invite_events"
// stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt)
// err = stmt.QueryRowContext(ctx).Scan(&nullableID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
response, err := queryInviteEventMaxNumber(s, ctx, s.selectMaxInviteIDStmt, params)
if response != nil {
nullableID.Int64 = response[0].Max
}
if nullableID.Valid {
id = nullableID.Int64
}

View file

@ -18,9 +18,10 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@ -32,53 +33,92 @@ import (
// a room, either by choice or otherwise. This is important for
// building history visibility.
const membershipsSchema = `
CREATE TABLE IF NOT EXISTS syncapi_memberships (
-- The 'room_id' key for the state event.
room_id TEXT NOT NULL,
-- The state event ID
user_id TEXT NOT NULL,
-- The status of the membership
membership TEXT NOT NULL,
-- The event ID that last changed the membership
event_id TEXT NOT NULL,
-- The stream position of the change
stream_pos BIGINT NOT NULL,
-- The topological position of the change in the room
topological_pos BIGINT NOT NULL,
-- Unique index
UNIQUE (room_id, user_id, membership)
);
`
// const membershipsSchema = `
// CREATE TABLE IF NOT EXISTS syncapi_memberships (
// -- The 'room_id' key for the state event.
// room_id TEXT NOT NULL,
// -- The state event ID
// user_id TEXT NOT NULL,
// -- The status of the membership
// membership TEXT NOT NULL,
// -- The event ID that last changed the membership
// event_id TEXT NOT NULL,
// -- The stream position of the change
// stream_pos BIGINT NOT NULL,
// -- The topological position of the change in the room
// topological_pos BIGINT NOT NULL,
// -- Unique index
// UNIQUE (room_id, user_id, membership)
// );
// `
const upsertMembershipSQL = "" +
"INSERT INTO syncapi_memberships (room_id, user_id, membership, event_id, stream_pos, topological_pos)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (room_id, user_id, membership)" +
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"
type membershipsStatements struct {
db *sql.DB
upsertMembershipStmt *sql.Stmt
type MembershipCosmos struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
Membership string `json:"membership"`
EventID string `json:"event_id"`
StreamPos int64 `json:"stream_pos"`
TopologicalPos int64 `json:"topological_pos"`
}
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
s := &membershipsStatements{
db: db,
}
_, err := db.Exec(membershipsSchema)
type MembershipCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Membership MembershipCosmos `json:"mx_syncapi_membership"`
}
// const upsertMembershipSQL = "" +
// "INSERT INTO syncapi_memberships (room_id, user_id, membership, event_id, stream_pos, topological_pos)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (room_id, user_id, membership)" +
// " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
// "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
// " WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
// " ORDER BY stream_pos DESC" +
// " LIMIT 1"
const selectMembershipSQL = "" +
"select top 1 * from c where c._cn = @x1 " +
"and c.mx_syncapi_membership.room_id = @x2 " +
"and c.mx_syncapi_membership.user_id = @x3 " +
"and ARRAY_CONTAINS(@x4, c.mx_syncapi_membership.membership) " +
"order by c.mx_syncapi_membership.stream_pos desc "
type membershipsStatements struct {
db *SyncServerDatasource
// upsertMembershipStmt *sql.Stmt
tableName string
}
func queryMembership(s *membershipsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MembershipCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []MembershipCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
return nil, err
return response, nil
}
func NewCosmosDBMembershipsTable(db *SyncServerDatasource) (tables.Memberships, error) {
s := &membershipsStatements{
db: db,
}
s.tableName = "memberships"
return s, nil
}
@ -90,30 +130,86 @@ func (s *membershipsStatements) UpsertMembership(
if err != nil {
return fmt.Errorf("event.Membership: %w", err)
}
_, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext(
// "INSERT INTO syncapi_memberships (room_id, user_id, membership, event_id, stream_pos, topological_pos)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (room_id, user_id, membership)" +
// " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
// _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext(
// ctx,
// event.RoomID(),
// *event.StateKey(),
// membership,
// event.EventID(),
// streamPos,
// topologicalPos,
// )
data := MembershipCosmos{
RoomID: event.RoomID(),
UserID: *event.StateKey(),
Membership: membership,
EventID: event.EventID(),
StreamPos: int64(streamPos),
TopologicalPos: int64(topologicalPos),
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// UNIQUE (room_id, user_id, membership)
docId := fmt.Sprintf("%s_%s_%s", event.RoomID(), *event.StateKey(), membership)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = MembershipCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Membership: data,
}
var optionsCreate = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
event.RoomID(),
*event.StateKey(),
membership,
event.EventID(),
streamPos,
topologicalPos,
)
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
optionsCreate)
return err
}
func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
params := []interface{}{roomID, userID}
for _, membership := range memberships {
params = append(params, membership)
// params := []interface{}{roomID, userID}
// for _, membership := range memberships {
// params = append(params, membership)
// }
// "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
// " WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
// " ORDER BY stream_pos DESC" +
// " LIMIT 1"
// err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
"@x3": userID,
"@x4": memberships,
}
orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
stmt, err := s.db.Prepare(orig)
if err != nil {
// orig := strings.Replace(selectMembershipSQL, "@x4", cosmosdbutil.QueryVariadicOffset(len(memberships), 2), 1)
rows, err := queryMembership(s, ctx, selectMembershipSQL, params)
if err != nil || len(rows) == 0 {
return "", 0, 0, err
}
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
// err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
eventID = rows[0].Membership.EventID
streamPos = types.StreamPosition(rows[0].Membership.StreamPos)
topologyPos = types.StreamPosition(rows[0].Membership.TopologicalPos)
return
}

View file

@ -21,109 +21,222 @@ import (
"encoding/json"
"fmt"
"sort"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const outputRoomEventsSchema = `
-- Stores output room events received from the roomserver.
CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
event_id TEXT NOT NULL UNIQUE,
room_id TEXT NOT NULL,
headered_event_json TEXT NOT NULL,
type TEXT NOT NULL,
sender TEXT NOT NULL,
contains_url BOOL NOT NULL,
add_state_ids TEXT, -- JSON encoded string array
remove_state_ids TEXT, -- JSON encoded string array
session_id BIGINT,
transaction_id TEXT,
exclude_from_sync BOOL NOT NULL DEFAULT FALSE
);
`
// const outputRoomEventsSchema = `
// -- Stores output room events received from the roomserver.
// CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
// id INTEGER PRIMARY KEY AUTOINCREMENT,
// event_id TEXT NOT NULL UNIQUE,
// room_id TEXT NOT NULL,
// headered_event_json TEXT NOT NULL,
// type TEXT NOT NULL,
// sender TEXT NOT NULL,
// contains_url BOOL NOT NULL,
// add_state_ids TEXT, -- JSON encoded string array
// remove_state_ids TEXT, -- JSON encoded string array
// session_id BIGINT,
// transaction_id TEXT,
// exclude_from_sync BOOL NOT NULL DEFAULT FALSE
// );
// `
const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" +
"id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events"
const updateEventJSONSQL = "" +
"UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
const selectStateInRangeSQL = "" +
"SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2)" +
" AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
type outputRoomEventsStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
// OutputRoomEventCosmos mirrors one row of the original
// syncapi_output_room_events SQLite table, embedded in a Cosmos document
// under the "mx_syncapi_output_room_event" property.
type OutputRoomEventCosmos struct {
	ID                int64  `json:"id"`                  // stream position, assigned by streamIDStatements.nextPDUID (no AUTOINCREMENT in Cosmos)
	EventID           string `json:"event_id"`
	RoomID            string `json:"room_id"`
	HeaderedEventJSON []byte `json:"headered_event_json"` // marshalled gomatrixserverlib.HeaderedEvent
	Type              string `json:"type"`
	Sender            string `json:"sender"`
	ContainsUrl       bool   `json:"contains_url"`
	AddStateIDs       string `json:"add_state_ids"`    // JSON-encoded string array
	RemoveStateIDs    string `json:"remove_state_ids"` // JSON-encoded string array
	SessionID         int64  `json:"session_id"`       // only set when a TransactionID was supplied
	TransactionID     string `json:"transaction_id"`   // only set when a TransactionID was supplied
	ExcludeFromSync   bool   `json:"exclude_from_sync"`
}
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
// OutputRoomEventCosmosMaxNumber receives the result of selectMaxEventIDSQL;
// the "number" JSON key matches the "as number" alias in that query.
type OutputRoomEventCosmosMaxNumber struct {
	Max int64 `json:"number"`
}
// OutputRoomEventCosmosData is the full Cosmos document envelope: system
// metadata plus the event payload nested under "mx_syncapi_output_room_event".
type OutputRoomEventCosmosData struct {
	Id              string                `json:"id"`    // document id, built via cosmosdbapi.GetDocumentId
	Pk              string                `json:"_pk"`   // partition key
	Cn              string                `json:"_cn"`   // collection name; matched by the "c._cn = @x1" filters
	ETag            string                `json:"_etag"` // used for optimistic concurrency on replace
	Timestamp       int64                 `json:"_ts"`
	OutputRoomEvent OutputRoomEventCosmos `json:"mx_syncapi_output_room_event"`
}
// const insertEventSQL = "" +
// "INSERT INTO syncapi_output_room_events (" +
// "id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
// ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
// "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
const selectEventsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_output_room_event.event_id = @x2 "
// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
// " WHERE room_id = $1 AND id > $2 AND id <= $3"
// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectRecentEventsSQL = "" +
"select top @x5 * from c where c._cn = @x1 " +
"and c.mx_syncapi_output_room_event.room_id = @x2 " +
"and c.mx_syncapi_output_room_event.id > @x3 " +
"and c.mx_syncapi_output_room_event.id <= @x4 "
// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
// " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectRecentEventsForSyncSQL = "" +
"select top @x5 * from c where c._cn = @x1 " +
"and c.mx_syncapi_output_room_event.room_id = @x2 " +
"and c.mx_syncapi_output_room_event.id > @x3 " +
"and c.mx_syncapi_output_room_event.id <= @x4 " +
"and c.mx_syncapi_output_room_event.exclude_from_sync = false "
// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
// " WHERE room_id = $1 AND id > $2 AND id <= $3"
// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectEarlyEventsSQL = "" +
"select top @x5 * from c where c._cn = @x1 " +
"and c.mx_syncapi_output_room_event.room_id = @x2 " +
"and c.mx_syncapi_output_room_event.id > @x3 " +
"and c.mx_syncapi_output_room_event.id <= @x4 "
// "SELECT MAX(id) FROM syncapi_output_room_events"
const selectMaxEventIDSQL = "" +
"select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 "
// "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
const updateEventJSONSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_output_room_event.event_id = @x2 "
// "SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
// " FROM syncapi_output_room_events" +
// " WHERE (id > $1 AND id <= $2)" +
// " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
// selectStateInRangeSQL selects state-change events in the half-open id range
// (@x2, @x3], capped at @x4 rows. The != '' checks mirror the original SQLite
// predicate ("IS NOT NULL AND != ''"): the Go struct always marshals
// add_state_ids/remove_state_ids as strings, so a null check alone would
// match every row.
const selectStateInRangeSQL = "" +
	"select top @x4 * from c where c._cn = @x1 " +
	"and c.mx_syncapi_output_room_event.id > @x2 " +
	"and c.mx_syncapi_output_room_event.id <= @x3 " +
	"and ( " +
	"(c.mx_syncapi_output_room_event.add_state_ids != null and c.mx_syncapi_output_room_event.add_state_ids != '') " +
	"or " +
	"(c.mx_syncapi_output_room_event.remove_state_ids != null and c.mx_syncapi_output_room_event.remove_state_ids != '') " +
	") "
// "DELETE FROM syncapi_output_room_events WHERE room_id = $1"
const deleteEventsForRoomSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_output_room_event.room_id = @x2 "
// outputRoomEventsStatements holds the Cosmos query strings for the output
// room events table. There are no prepared statements with CosmosDB — the
// former *sql.Stmt fields now store raw query text — plus the names used to
// derive collection names and document ids.
type outputRoomEventsStatements struct {
	db                 *SyncServerDatasource
	streamIDStatements *streamIDStatements // allocates the stream positions used as document ids
	// insertEventStmt *sql.Stmt
	selectEventsStmt        string
	selectMaxEventIDStmt    string
	updateEventJSONStmt     string
	deleteEventsForRoomStmt string
	tableName               string // "output_room_events"
	jsonPropertyName        string // "mx_syncapi_output_room_event"
}
// queryOutputRoomEvent runs the given Cosmos SQL query with params against
// this table's partition and decodes the matching documents.
func queryOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventCosmosData, error) {
	collectionName := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	partitionKey := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, collectionName)
	queryOptions := cosmosdbapi.GetQueryDocumentsOptions(partitionKey)
	cosmosQuery := cosmosdbapi.GetQuery(qry, params)

	var docs []OutputRoomEventCosmosData
	if _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		cosmosQuery,
		&docs,
		queryOptions,
	); err != nil {
		return nil, err
	}
	return docs, nil
}
// queryOutputRoomEventNumber runs an aggregate query (e.g. MAX(id)) and
// decodes the single-number result rows.
//
// NOTE(review): any client error is collapsed into cosmosdbutil.ErrNoRows,
// hiding the underlying cause. Callers appear to depend on the sentinel, so
// the behaviour is preserved here — confirm before surfacing the raw error.
func queryOutputRoomEventNumber(s *outputRoomEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventCosmosMaxNumber, error) {
	collectionName := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	partitionKey := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, collectionName)
	queryOptions := cosmosdbapi.GetQueryDocumentsOptions(partitionKey)
	cosmosQuery := cosmosdbapi.GetQuery(qry, params)

	var numbers []OutputRoomEventCosmosMaxNumber
	if _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		cosmosQuery,
		&numbers,
		queryOptions,
	); err != nil {
		return nil, cosmosdbutil.ErrNoRows
	}
	return numbers, nil
}
// setOutputRoomEvent replaces the stored document with the supplied copy,
// using its ETag for optimistic concurrency, and returns the written value.
func setOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, outputRoomEvent OutputRoomEventCosmosData) (*OutputRoomEventCosmosData, error) {
	replaceOptions := cosmosdbapi.GetReplaceDocumentOptions(outputRoomEvent.Pk, outputRoomEvent.ETag)
	_, _, err := cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		outputRoomEvent.Id,
		&outputRoomEvent,
		replaceOptions,
	)
	return &outputRoomEvent, err
}
// deleteOutputRoomEvent removes the given document from the container,
// addressing it by document id within its partition.
//
// Returns any error reported by the Cosmos client.
// (Fix: the original had a redundant "if err != nil { return err }"
// immediately followed by "return err" — collapsed to a single return.)
func deleteOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, dbData OutputRoomEventCosmosData) error {
	options := cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
	_, err := cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		dbData.Id,
		options)
	return err
}
// NewCosmosDBEventsTable constructs the tables.Events implementation backed
// by CosmosDB. Query strings are stored directly on the struct — there are
// no prepared statements with Cosmos — alongside the table and json-property
// names used to build collection names and document ids.
//
// (Fix: removed stale SQLite code — db.Exec(outputRoomEventsSchema) and
// db.Prepare(...) into removed *sql.Stmt fields — left interleaved in this
// block, which did not compile.)
func NewCosmosDBEventsTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Events, error) {
	s := &outputRoomEventsStatements{
		db:                 db,
		streamIDStatements: streamID,
	}
	s.selectEventsStmt = selectEventsSQL
	s.selectMaxEventIDStmt = selectMaxEventIDSQL
	s.updateEventJSONStmt = updateEventJSONSQL
	s.deleteEventsForRoomStmt = deleteEventsForRoomSQL
	s.tableName = "output_room_events"
	s.jsonPropertyName = "mx_syncapi_output_room_event"
	return s, nil
}
@ -132,7 +245,27 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
if err != nil {
return err
}
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
// "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": event.EventID(),
}
// _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
rows, err := queryOutputRoomEvent(s, ctx, s.deleteEventsForRoomStmt, params)
if err != nil {
return err
}
for _, item := range rows {
item.OutputRoomEvent.HeaderedEventJSON = headeredJSON
_, err = setOutputRoomEvent(s, ctx, item)
}
return err
return err
}
@ -143,24 +276,31 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, r types.Range,
stateFilter *gomatrixserverlib.StateFilter,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt, params, err := prepareWithFilters(
s.db, txn, selectStateInRangeSQL,
[]interface{}{
r.Low(), r.High(),
},
// "SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
// " FROM syncapi_output_room_events" +
// " WHERE (id > $1 AND id <= $2)" +
// " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": r.Low(),
"@x3": r.High(),
"@x4": stateFilter.Limit,
}
query, params := prepareWithFilters(
s.jsonPropertyName, selectStateInRangeSQL, params,
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
nil, stateFilter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
// rows, err := stmt.QueryContext(ctx, params...)
rows, err := queryOutputRoomEvent(s, ctx, query, params)
if err != nil {
return nil, nil, err
}
defer rows.Close() // nolint: errcheck
// Fetch all the state change events for all rooms between the two positions then loop each event and:
// - Keep a cache of the event by ID (99% of state change events are for the event itself)
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes
@ -171,7 +311,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
// RoomID => A set (map[string]bool) of state event IDs which are between the two positions
stateNeeded := make(map[string]map[string]bool)
for rows.Next() {
for _, item := range rows {
var (
streamPos types.StreamPosition
eventBytes []byte
@ -179,10 +319,15 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
addIDsJSON string
delIDsJSON string
)
if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDsJSON, &delIDsJSON); err != nil {
return nil, nil, err
}
// SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids
// if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDsJSON, &delIDsJSON); err != nil {
// return nil, nil, err
// }
streamPos = types.StreamPosition(item.OutputRoomEvent.ID)
eventBytes = item.OutputRoomEvent.HeaderedEventJSON
excludeFromSync = item.OutputRoomEvent.ExcludeFromSync
addIDsJSON = item.OutputRoomEvent.AddStateIDs
delIDsJSON = item.OutputRoomEvent.RemoveStateIDs
addIDs, delIDs, err := unmarshalStateIDs(addIDsJSON, delIDsJSON)
if err != nil {
return nil, nil, err
@ -233,8 +378,20 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
// stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt)
rows, err := queryOutputRoomEventNumber(s, ctx, s.selectMaxEventIDStmt, params)
// err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if rows != nil {
nullableID.Int64 = rows[0].Max
}
if nullableID.Valid {
id = nullableID.Int64
}
@ -248,6 +405,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
transactionID *api.TransactionID, excludeFromSync bool,
) (types.StreamPosition, error) {
var txnID *string
var sessionID *int64
if transactionID != nil {
@ -283,27 +441,74 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, fmt.Errorf("json.Marshal(removeState): %w", err)
}
// id INTEGER PRIMARY KEY AUTOINCREMENT,
streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn)
if err != nil {
return 0, err
}
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
_, err = insertStmt.ExecContext(
// "INSERT INTO syncapi_output_room_events (" +
// "id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
// ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
// "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
// insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
// _, err = insertStmt.ExecContext(
// ctx,
// streamPos,
// event.RoomID(),
// event.EventID(),
// headeredJSON,
// event.Type(),
// event.Sender(),
// containsURL,
// string(addStateJSON),
// string(removeStateJSON),
// sessionID,
// txnID,
// excludeFromSync,
// excludeFromSync,
// )
data := OutputRoomEventCosmos{
ID: int64(streamPos),
RoomID: event.RoomID(),
EventID: event.EventID(),
HeaderedEventJSON: headeredJSON,
Type: event.Type(),
Sender: event.Sender(),
ContainsUrl: containsURL,
AddStateIDs: string(addStateJSON),
RemoveStateIDs: string(removeStateJSON),
ExcludeFromSync: excludeFromSync,
}
if transactionID != nil {
data.SessionID = *sessionID
data.TransactionID = *txnID
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// id INTEGER PRIMARY KEY,
docId := fmt.Sprintf("%d", streamPos)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = OutputRoomEventCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
OutputRoomEvent: data,
}
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
streamPos,
event.RoomID(),
event.EventID(),
headeredJSON,
event.Type(),
event.Sender(),
containsURL,
string(addStateJSON),
string(removeStateJSON),
sessionID,
txnID,
excludeFromSync,
excludeFromSync,
)
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
optionsCreate)
return streamPos, err
}
@ -314,30 +519,39 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
) ([]types.StreamEvent, bool, error) {
var query string
if onlySyncEvents {
// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
// " WHERE room_id = $1 AND id > $2 AND id <= $3"
// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
query = selectRecentEventsForSyncSQL
} else {
// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
// " WHERE room_id = $1 AND id > $2 AND id <= $3" +
query = selectRecentEventsSQL
}
stmt, params, err := prepareWithFilters(
s.db, txn, query,
[]interface{}{
roomID, r.Low(), r.High(),
},
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
"@x3": r.Low(),
"@x4": r.High(),
"@x5": eventFilter.Limit + 1,
}
query, params = prepareWithFilters(
s.jsonPropertyName, query, params,
eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes,
nil, eventFilter.Limit+1, FilterOrderDesc,
)
if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
// rows, err := stmt.QueryContext(ctx, params...)
rows, err := queryOutputRoomEvent(s, ctx, query, params)
if err != nil {
return nil, false, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed")
events, err := rowsToStreamEvents(rows)
events, err := rowsToStreamEvents(&rows)
if err != nil {
return nil, false, err
}
@ -367,24 +581,31 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) {
stmt, params, err := prepareWithFilters(
s.db, txn, selectEarlyEventsSQL,
[]interface{}{
roomID, r.Low(), r.High(),
},
// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
// " WHERE room_id = $1 AND id > $2 AND id <= $3"
// // WHEN, ORDER BY (and not LIMIT) are appended by prepareWithFilters
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
"@x3": r.Low(),
"@x4": r.High(),
"@x5": eventFilter.Limit,
}
stmt, params := prepareWithFilters(
s.jsonPropertyName, selectEarlyEventsSQL, params,
eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes,
nil, eventFilter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
// rows, err := stmt.QueryContext(ctx, params...)
rows, err := queryOutputRoomEvent(s, ctx, stmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed")
events, err := rowsToStreamEvents(rows)
events, err := rowsToStreamEvents(&rows)
if err != nil {
return nil, err
}
@ -402,17 +623,27 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StreamEvent, error) {
// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
var returnEvents []types.StreamEvent
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
// stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
for _, eventID := range eventIDs {
rows, err := stmt.QueryContext(ctx, eventID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventID,
}
// rows, err := stmt.QueryContext(ctx, eventID)
rows, err := queryOutputRoomEvent(s, ctx, s.selectEventsStmt, params)
if err != nil {
return nil, err
}
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
if streamEvents, err := rowsToStreamEvents(&rows); err == nil {
returnEvents = append(returnEvents, streamEvents...)
}
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
}
return returnEvents, nil
}
@ -420,13 +651,30 @@ func (s *outputRoomEventsStatements) SelectEvents(
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID)
// "DELETE FROM syncapi_output_room_events WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// _, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID)
rows, err := queryOutputRoomEvent(s, ctx, s.deleteEventsForRoomStmt, params)
if err != nil {
return err
}
for _, item := range rows {
err = deleteOutputRoomEvent(s, ctx, item)
}
return err
}
func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
func rowsToStreamEvents(rows *[]OutputRoomEventCosmosData) ([]types.StreamEvent, error) {
var result []types.StreamEvent
for rows.Next() {
for _, item := range *rows {
var (
eventID string
streamPos types.StreamPosition
@ -436,9 +684,17 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
txnID *string
transactionID *api.TransactionID
)
if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
return nil, err
}
// SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id
// if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
// return nil, err
// }
eventID = item.OutputRoomEvent.EventID
streamPos = types.StreamPosition(item.OutputRoomEvent.ID)
eventBytes = item.OutputRoomEvent.HeaderedEventJSON
sessionID = &item.OutputRoomEvent.SessionID
excludeFromSync = item.OutputRoomEvent.ExcludeFromSync
txnID = &item.OutputRoomEvent.TransactionID
// TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent
if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {

View file

@ -17,93 +17,167 @@ package cosmosdb
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
const outputRoomEventsTopologySchema = `
-- Stores output room events received from the roomserver.
CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology (
event_id TEXT PRIMARY KEY,
topological_position BIGINT NOT NULL,
stream_position BIGINT NOT NULL,
room_id TEXT NOT NULL,
// const outputRoomEventsTopologySchema = `
// -- Stores output room events received from the roomserver.
// CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology (
// event_id TEXT PRIMARY KEY,
// topological_position BIGINT NOT NULL,
// stream_position BIGINT NOT NULL,
// room_id TEXT NOT NULL,
UNIQUE(topological_position, room_id, stream_position)
);
-- The topological order will be used in events selection and ordering
-- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id);
`
// UNIQUE(topological_position, room_id, stream_position)
// );
// -- The topological order will be used in events selection and ordering
// -- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id);
// `
const insertEventInTopologySQL = "" +
"INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" +
" VALUES ($1, $2, $3, $4)" +
" ON CONFLICT DO NOTHING"
const selectEventIDsInRangeASCSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND (" +
"(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position <= $5)" +
") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
const selectEventIDsInRangeDESCSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND (" +
"(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position <= $5)" +
") ORDER BY topological_position DESC, stream_position DESC LIMIT $6"
const selectPositionInTopologySQL = "" +
"SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE event_id = $1"
const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 ORDER BY stream_position DESC"
const deleteTopologyForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
type outputRoomEventsTopologyStatements struct {
db *sql.DB
insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
// OutputRoomEventTopologyCosmos mirrors one row of the original
// syncapi_output_room_events_topology SQLite table, embedded in a Cosmos
// document under "mx_syncapi_output_room_event_topology".
type OutputRoomEventTopologyCosmos struct {
	EventID             string `json:"event_id"`
	TopologicalPosition int64  `json:"topological_position"`
	StreamPosition      int64  `json:"stream_position"`
	RoomID              string `json:"room_id"`
}
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
s := &outputRoomEventsTopologyStatements{
db: db,
}
_, err := db.Exec(outputRoomEventsTopologySchema)
// OutputRoomEventTopologyCosmosData is the full Cosmos document envelope:
// system metadata plus the topology payload nested under
// "mx_syncapi_output_room_event_topology".
type OutputRoomEventTopologyCosmosData struct {
	Id              string                        `json:"id"`    // document id
	Pk              string                        `json:"_pk"`   // partition key
	Cn              string                        `json:"_cn"`   // collection name; matched by the "c._cn = @x1" filters
	ETag            string                        `json:"_etag"`
	Timestamp       int64                         `json:"_ts"`
	OutputRoomEventTopology OutputRoomEventTopologyCosmos `json:"mx_syncapi_output_room_event_topology"`
}
// const insertEventInTopologySQL = "" +
// "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" +
// " VALUES ($1, $2, $3, $4)" +
// " ON CONFLICT DO NOTHING"
// "SELECT event_id FROM syncapi_output_room_events_topology" +
// " WHERE room_id = $1 AND (" +
// "(topological_position > $2 AND topological_position < $3) OR" +
// "(topological_position = $4 AND stream_position <= $5)" +
// ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
// selectEventIDsInRangeASCSQL selects events strictly between topological
// positions @x3 and @x4, plus events at boundary position @x5 up to and
// including stream position @x6, oldest first; TOP @x7 carries the LIMIT.
// (Fix: the stream_position bound was "<" but the original SQL uses "<=".)
const selectEventIDsInRangeASCSQL = "" +
	"select top @x7 * from c where c._cn = @x1 " +
	"and c.mx_syncapi_output_room_event_topology.room_id = @x2 " +
	"and ( " +
	"(c.mx_syncapi_output_room_event_topology.topological_position > @x3 and c.mx_syncapi_output_room_event_topology.topological_position < @x4) " +
	"OR " +
	"(c.mx_syncapi_output_room_event_topology.topological_position = @x5 and c.mx_syncapi_output_room_event_topology.stream_position <= @x6) " +
	") " +
	"order by c.mx_syncapi_output_room_event_topology.topological_position asc "
// ", c.mx_syncapi_output_room_event_topology.stream_position asc "
// "SELECT event_id FROM syncapi_output_room_events_topology" +
// " WHERE room_id = $1 AND (" +
// "(topological_position > $2 AND topological_position < $3) OR" +
// "(topological_position = $4 AND stream_position <= $5)" +
// ") ORDER BY topological_position DESC, stream_position DESC LIMIT $6"
// selectEventIDsInRangeDESCSQL selects events strictly between topological
// positions @x3 and @x4, plus events at boundary position @x5 up to and
// including stream position @x6, newest first; TOP @x7 carries the LIMIT.
// (Fix: the stream_position bound was "<" but the original SQL uses "<=".)
const selectEventIDsInRangeDESCSQL = "" +
	"select top @x7 * from c where c._cn = @x1 " +
	"and c.mx_syncapi_output_room_event_topology.room_id = @x2 " +
	"and ( " +
	"(c.mx_syncapi_output_room_event_topology.topological_position > @x3 and c.mx_syncapi_output_room_event_topology.topological_position < @x4) " +
	"OR " +
	"(c.mx_syncapi_output_room_event_topology.topological_position = @x5 and c.mx_syncapi_output_room_event_topology.stream_position <= @x6) " +
	") " +
	"order by c.mx_syncapi_output_room_event_topology.topological_position desc "
// ", c.mx_syncapi_output_room_event_topology.stream_position desc "
// "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
// " WHERE event_id = $1"
const selectPositionInTopologySQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_output_room_event_topology.event_id = @x2 "
// "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" +
// " WHERE room_id = $1 ORDER BY stream_position DESC"
// "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
// " WHERE topological_position=(" +
// "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
// ") ORDER BY stream_position DESC LIMIT 1"
const selectMaxPositionInTopologySQL = "" +
"select top 1 * from c where c._cn = @x1 " +
"and c.mx_syncapi_output_room_event_topology.topological_position = " +
"( " +
"select max(c.mx_syncapi_output_room_event_topology.topological_position) from c where c._cn = @x1 " +
"and c.mx_syncapi_output_room_event_topology.room_id = @x2" +
") " +
"order by c.mx_syncapi_output_room_event_topology.stream_position desc "
// "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
const deleteTopologyForRoomSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_output_room_event_topology.room_id = @x2 "
// outputRoomEventsTopologyStatements stores the Cosmos query strings for the
// topology table (no prepared statements with CosmosDB — the former
// *sql.Stmt fields now hold raw query text) and the table name used to
// derive collection names.
type outputRoomEventsTopologyStatements struct {
	db *SyncServerDatasource
	// insertEventInTopologyStmt *sql.Stmt
	selectEventIDsInRangeASCStmt    string
	selectEventIDsInRangeDESCStmt   string
	selectPositionInTopologyStmt    string
	selectMaxPositionInTopologyStmt string
	deleteTopologyForRoomStmt       string
	tableName                       string
}
func queryOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventTopologyCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []OutputRoomEventTopologyCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil {
return nil, err
return response, nil
}
func deleteOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, dbData OutputRoomEventTopologyCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil {
return nil, err
}
if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil {
return nil, err
}
if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil {
return nil, err
}
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return nil, err
}
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
return nil, err
return err
}
// NewCosmosDBTopologyTable wires up a tables.Topology implementation backed
// by CosmosDB document queries rather than prepared SQL statements.
func NewCosmosDBTopologyTable(db *SyncServerDatasource) (tables.Topology, error) {
	return &outputRoomEventsTopologyStatements{
		db:                              db,
		selectEventIDsInRangeASCStmt:    selectEventIDsInRangeASCSQL,
		selectEventIDsInRangeDESCStmt:   selectEventIDsInRangeDESCSQL,
		selectPositionInTopologyStmt:    selectPositionInTopologySQL,
		selectMaxPositionInTopologyStmt: selectMaxPositionInTopologySQL,
		deleteTopologyForRoomStmt:       deleteTopologyForRoomSQL,
		tableName:                       "output_room_events_topology",
	}, nil
}
@ -112,9 +186,44 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
// InsertEventInTopology upserts the topology document for an event and
// returns the event's depth as the topological stream position.
// (Fixed: removed leftover uncommented sqlutil.TxStmt call that duplicated
// the commented-out original and redeclared err.)
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
	ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
) (types.StreamPosition, error) {
	// "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" +
	// " VALUES ($1, $2, $3, $4)" +
	// " ON CONFLICT DO NOTHING"
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	// Document id mirrors UNIQUE(topological_position, room_id, stream_position).
	docId := fmt.Sprintf("%d_%s_%d", event.Depth(), event.RoomID(), pos)
	cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
	pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
	data := OutputRoomEventTopologyCosmos{
		EventID:             event.EventID(),
		TopologicalPosition: event.Depth(),
		RoomID:              event.RoomID(),
		StreamPosition:      int64(pos),
	}
	dbData := OutputRoomEventTopologyCosmosData{
		Id:                      cosmosDocId,
		Cn:                      dbCollectionName,
		Pk:                      pk,
		Timestamp:               time.Now().Unix(),
		OutputRoomEventTopology: data,
	}
	// Upsert options give the ON CONFLICT-style semantics of the old SQL.
	var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
	_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		&dbData,
		options)
	return types.StreamPosition(event.Depth()), err
}
@ -125,15 +234,38 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
) (eventIDs []string, err error) {
	// Decide on the selection's order according to whether chronological order
	// is requested or not. Both queries share the same parameter layout and
	// differ only in sort direction.
	var stmt string
	if chronologicalOrder {
		stmt = s.selectEventIDsInRangeASCStmt
	} else {
		stmt = s.selectEventIDsInRangeDESCStmt
	}
	// Query the event IDs.
	// (Fixed: removed leftover *sql.Stmt declarations and the duplicated
	// rows.Next loop left behind by the SQLite implementation, and restored
	// the function's missing tail.)
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": roomID,
		"@x3": minDepth,
		"@x4": maxDepth,
		"@x5": maxDepth,
		"@x6": maxStreamPos,
		"@x7": limit,
	}
	rows, err := queryOutputRoomEventTopology(s, ctx, stmt, params)
	if err == sql.ErrNoRows {
		// If no event matched the request, return an empty slice.
		return []string{}, nil
	} else if err != nil {
		return
	}
	// Return the IDs.
	for _, item := range rows {
		eventIDs = append(eventIDs, item.OutputRoomEventTopology.EventID)
	}
	return eventIDs, nil
}
@ -158,22 +291,89 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
// SelectPositionInTopology returns the topological and stream positions of
// the given event, or zero values when the event is unknown.
// (Fixed: removed leftover SQLite stmt/QueryRowContext residue.)
func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology(
	ctx context.Context, txn *sql.Tx, eventID string,
) (pos types.StreamPosition, spos types.StreamPosition, err error) {
	// "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
	// " WHERE event_id = $1"
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": eventID,
	}
	rows, err := queryOutputRoomEventTopology(s, ctx, s.selectPositionInTopologyStmt, params)
	if err != nil {
		return
	}
	// No match: leave the zero values in place, as the SQLite path did not.
	if len(rows) == 0 {
		return
	}
	pos = types.StreamPosition(rows[0].OutputRoomEventTopology.TopologicalPosition)
	spos = types.StreamPosition(rows[0].OutputRoomEventTopology.StreamPosition)
	return
}
// SelectMaxPositionInTopology returns the highest topological position in the
// room and, of the rows at that depth, the highest stream position (the query
// sorts stream_position descending and takes the top document).
// (Fixed: removed leftover SQLite stmt/QueryRowContext residue.)
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
	ctx context.Context, txn *sql.Tx, roomID string,
) (pos types.StreamPosition, spos types.StreamPosition, err error) {
	// "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
	// " WHERE topological_position=(" +
	// "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
	// ") ORDER BY stream_position DESC LIMIT 1"
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": roomID,
	}
	rows, err := queryOutputRoomEventTopology(s, ctx, s.selectMaxPositionInTopologyStmt, params)
	if err != nil {
		return
	}
	// Empty room: report zero positions rather than indexing an empty slice.
	if len(rows) == 0 {
		return
	}
	pos = types.StreamPosition(rows[0].OutputRoomEventTopology.TopologicalPosition)
	spos = types.StreamPosition(rows[0].OutputRoomEventTopology.StreamPosition)
	return
}
// DeleteTopologyForRoom removes every topology document for a room. Cosmos
// has no bulk DELETE, so the documents are selected first and then deleted
// one by one; the first failure aborts the loop.
// (Fixed: removed leftover sqlutil.TxStmt residue.)
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
	ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
	// "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": roomID,
	}
	rows, err := queryOutputRoomEventTopology(s, ctx, s.deleteTopologyForRoomStmt, params)
	if err != nil {
		return err
	}
	for _, item := range rows {
		if err = deleteOutputRoomEventTopology(s, ctx, item); err != nil {
			return err
		}
	}
	return nil
}

View file

@ -17,91 +17,175 @@ package cosmosdb
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
// Legacy SQLite DDL, kept commented out for reference only — CosmosDB needs
// no schema. (Fixed: the live `const peeksSchema` raw string was interleaved
// with its commented-out copy, leaving stray comment lines inside the string
// literal; only the commented copy is retained.)
// const peeksSchema = `
// CREATE TABLE IF NOT EXISTS syncapi_peeks (
// 	id INTEGER,
// 	room_id TEXT NOT NULL,
// 	user_id TEXT NOT NULL,
// 	device_id TEXT NOT NULL,
// 	deleted BOOL NOT NULL DEFAULT false,
// 	-- When the peek was created in UNIX epoch ms.
// 	creation_ts INTEGER NOT NULL,
// 	UNIQUE(room_id, user_id, device_id)
// );
//
// CREATE INDEX IF NOT EXISTS syncapi_peeks_room_id_idx ON syncapi_peeks(room_id);
// CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_id, device_id);
// `
const insertPeekSQL = "" +
"INSERT OR REPLACE INTO syncapi_peeks" +
" (id, room_id, user_id, device_id, creation_ts, deleted)" +
" VALUES ($1, $2, $3, $4, $5, false)"
type PeekCosmos struct {
ID int64 `json:"id"`
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
Deleted bool `json:"deleted"`
// Use the CosmosDB.Timestamp for this one
// creation_ts int64 `json:"creation_ts"`
}
type PeekCosmosMaxNumber struct {
Max int64 `json:"number"`
}
type PeekCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Peek PeekCosmos `json:"mx_syncapi_peek"`
}
// const insertPeekSQL = "" +
// "INSERT OR REPLACE INTO syncapi_peeks" +
// " (id, room_id, user_id, device_id, creation_ts, deleted)" +
// " VALUES ($1, $2, $3, $4, $5, false)"
// "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4"
const deletePeekSQL = "" +
"UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4"
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_peek.room_id = @x2 " +
"and c.mx_syncapi_peek.user_id = @x3 " +
"and c.mx_syncapi_peek.device_id = @x4 "
// "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3"
const deletePeeksSQL = "" +
"UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3"
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_peek.room_id = @x2 " +
"and c.mx_syncapi_peek.user_id = @x3 "
// we care about all the peeks which were created in this range, deleted in this range,
// or were created before this range but haven't been deleted yet.
// BEWARE: sqlite chokes on out of order substitution strings.
// "SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))"
const selectPeeksInRangeSQL = "" +
"SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))"
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_peek.user_id = @x2 " +
"and c.mx_syncapi_peek.device_id = @x3 " +
"and ( " +
"(c.mx_syncapi_peek.id <= @x4 and c.mx_syncapi_peek.deleted = false)" +
"or " +
"(c.mx_syncapi_peek.id > @x4 and c.mx_syncapi_peek.id <= @x5)" +
") "
// "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false"
const selectPeekingDevicesSQL = "" +
"SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false"
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_peek.deleted = false "
// "SELECT MAX(id) FROM syncapi_peeks"
const selectMaxPeekIDSQL = "" +
"SELECT MAX(id) FROM syncapi_peeks"
"select max(c.mx_syncapi_peek.id) from c where c._cn = @x1 "
type peekStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
insertPeekStmt *sql.Stmt
deletePeekStmt *sql.Stmt
deletePeeksStmt *sql.Stmt
selectPeeksInRangeStmt *sql.Stmt
selectPeekingDevicesStmt *sql.Stmt
selectMaxPeekIDStmt *sql.Stmt
db *SyncServerDatasource
streamIDStatements *streamIDStatements
// insertPeekStmt *sql.Stmt
deletePeekStmt string
deletePeeksStmt string
selectPeeksInRangeStmt string
selectPeekingDevicesStmt string
selectMaxPeekIDStmt string
tableName string
}
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
_, err := db.Exec(peeksSchema)
func queryPeek(s *peekStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PeekCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []PeekCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
// queryPeekMaxNumber runs an aggregate (MAX) query against the peeks
// collection and decodes the single-row result.
// (Fixed: the error path returned (nil, nil), silently swallowing query
// failures; it now propagates the error like queryPeek does.)
func queryPeekMaxNumber(s *peekStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PeekCosmosMaxNumber, error) {
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
	var response []PeekCosmosMaxNumber
	var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
	var query = cosmosdbapi.GetQuery(qry, params)
	_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		query,
		&response,
		optionsQry)
	if err != nil {
		return nil, err
	}
	return response, nil
}
// setPeek replaces an existing peek document in CosmosDB, using the document
// ETag for optimistic concurrency, and returns the document that was written.
func setPeek(s *peekStatements, ctx context.Context, peek PeekCosmosData) (*PeekCosmosData, error) {
	opts := cosmosdbapi.GetReplaceDocumentOptions(peek.Pk, peek.ETag)
	client := cosmosdbapi.GetClient(s.db.connection)
	_, _, err := client.ReplaceDocument(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		peek.Id,
		&peek,
		opts)
	return &peek, err
}
// NewCosmosDBPeeksTable wires up a tables.Peeks implementation backed by
// CosmosDB document queries.
// (Fixed: removed leftover db.Prepare blocks that referenced an undefined
// err and fields that no longer exist on peekStatements.)
func NewCosmosDBPeeksTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Peeks, error) {
	s := &peekStatements{
		db:                 db,
		streamIDStatements: streamID,
	}
	s.deletePeekStmt = deletePeekSQL
	s.deletePeeksStmt = deletePeeksSQL
	s.selectPeeksInRangeStmt = selectPeeksInRangeSQL
	s.selectPeekingDevicesStmt = selectPeekingDevicesSQL
	s.selectMaxPeekIDStmt = selectMaxPeekIDSQL
	s.tableName = "peeks"
	return s, nil
}
@ -112,39 +196,120 @@ func (s *peekStatements) InsertPeek(
if err != nil {
return
}
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
_, err = sqlutil.TxStmt(txn, s.insertPeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID, nowMilli)
// "INSERT OR REPLACE INTO syncapi_peeks" +
// " (id, room_id, user_id, device_id, creation_ts, deleted)" +
// " VALUES ($1, $2, $3, $4, $5, false)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE(room_id, user_id, device_id)
docId := fmt.Sprintf("%d_%s_%d", roomID, userID, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := PeekCosmos{
ID: int64(streamPos),
RoomID: roomID,
UserID: userID,
DeviceID: deviceID,
}
dbData := &PeekCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
// nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
Timestamp: time.Now().Unix(),
Peek: data,
}
// _, err = sqlutil.TxStmt(txn, s.insertPeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID, nowMilli)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return
}
// DeletePeek soft-deletes the peek for one (room, user, device): it selects
// the matching documents, allocates a fresh PDU stream ID, and rewrites each
// document with deleted=true and the new ID. Returns cosmosdbutil.ErrNoRows
// when nothing matched.
// (Fixed: the query error was previously ignored before len(rows) was used,
// and leftover sqlutil residue followed the loop.)
func (s *peekStatements) DeletePeek(
	ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
) (streamPos types.StreamPosition, err error) {
	// "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4"
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": roomID,
		"@x3": userID,
		"@x4": deviceID,
	}
	rows, err := queryPeek(s, ctx, s.deletePeekStmt, params)
	if err != nil {
		return 0, err
	}
	if len(rows) == 0 {
		return 0, cosmosdbutil.ErrNoRows
	}
	// Only create a new ID if there are rows to mark as deleted. This is handled in an SQL TX for DBs
	streamPos, err = s.streamIDStatements.nextPDUID(ctx, txn)
	if err != nil {
		return 0, err
	}
	for _, item := range rows {
		item.Peek.Deleted = true
		item.Peek.ID = int64(streamPos)
		if _, err = setPeek(s, ctx, item); err != nil {
			return
		}
	}
	return streamPos, nil
}
func (s *peekStatements) DeletePeeks(
ctx context.Context, txn *sql.Tx, roomID, userID string,
) (types.StreamPosition, error) {
// "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
"@x3": userID,
}
rows, err := queryPeek(s, ctx, s.deletePeekStmt, params)
// result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID)
if err != nil {
return 0, err
}
numAffected := len(rows)
if numAffected == 0 {
return 0, cosmosdbutil.ErrNoRows
}
// Only create a new ID if there are rows to mark as deleted. This is handled in an SQL TX for DBs
streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn)
if err != nil {
return 0, err
}
result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID)
if err != nil {
return 0, err
}
numAffected, err := result.RowsAffected()
if err != nil {
return 0, err
}
if numAffected == 0 {
return 0, sql.ErrNoRows
for _, item := range rows {
item.Peek.Deleted = true
item.Peek.ID = int64(streamPos)
_, err = setPeek(s, ctx, item)
if err != nil {
return 0, err
}
}
return streamPos, nil
}
@ -152,40 +317,65 @@ func (s *peekStatements) DeletePeeks(
func (s *peekStatements) SelectPeeksInRange(
ctx context.Context, txn *sql.Tx, userID, deviceID string, r types.Range,
) (peeks []types.Peek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High())
// "SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
"@x3": deviceID,
"@x4": r.Low(),
"@x5": r.High(),
}
rows, err := queryPeek(s, ctx, s.selectPeeksInRangeStmt, params)
// rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High())
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeksInRange: rows.close() failed")
for rows.Next() {
for _, item := range rows {
peek := types.Peek{}
var id types.StreamPosition
if err = rows.Scan(&id, &peek.RoomID, &peek.Deleted); err != nil {
return
}
// if err = rows.Scan(&id, &peek.RoomID, &peek.Deleted); err != nil {
// return
// }
id = types.StreamPosition(item.Peek.ID)
peek.RoomID = item.Peek.RoomID
peek.Deleted = item.Peek.Deleted
peek.New = (id > r.Low() && id <= r.High()) && !peek.Deleted
peeks = append(peeks, peek)
}
return peeks, rows.Err()
return peeks, nil
}
func (s *peekStatements) SelectPeekingDevices(
ctx context.Context,
) (peekingDevices map[string][]types.PeekingDevice, err error) {
rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx)
// "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
rows, err := queryPeek(s, ctx, s.selectPeekingDevicesStmt, params)
// rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectPeekingDevices: rows.close() failed")
result := make(map[string][]types.PeekingDevice)
for rows.Next() {
for _, item := range rows {
var roomID, userID, deviceID string
if err := rows.Scan(&roomID, &userID, &deviceID); err != nil {
return nil, err
}
// if err := rows.Scan(&roomID, &userID, &deviceID); err != nil {
// return nil, err
// }
roomID = item.Peek.RoomID
userID = item.Peek.UserID
deviceID = item.Peek.DeviceID
devices := result[roomID]
devices = append(devices, types.PeekingDevice{UserID: userID, DeviceID: deviceID})
result[roomID] = devices
@ -196,9 +386,22 @@ func (s *peekStatements) SelectPeekingDevices(
func (s *peekStatements) SelectMaxPeekID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
// "SELECT MAX(id) FROM syncapi_peeks"
// stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt)
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
rows, err := queryPeekMaxNumber(s, ctx, s.selectMaxPeekIDStmt, params)
// err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if rows != nil {
nullableID.Int64 = rows[0].Max
}
if nullableID.Valid {
id = nullableID.Int64
}

View file

@ -18,72 +18,129 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
// Legacy SQLite DDL, kept commented out for reference only — CosmosDB needs
// no schema. (Fixed: both the live `const receiptsSchema` and its
// commented-out copy were present; only the commented copy is retained.)
// const receiptsSchema = `
// -- Stores data about receipts
// CREATE TABLE IF NOT EXISTS syncapi_receipts (
// 	-- The ID
// 	id BIGINT,
// 	room_id TEXT NOT NULL,
// 	receipt_type TEXT NOT NULL,
// 	user_id TEXT NOT NULL,
// 	event_id TEXT NOT NULL,
// 	receipt_ts BIGINT NOT NULL,
// 	CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id)
// );
// CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id_idx ON syncapi_receipts(room_id);
// `
const upsertReceipt = "" +
"INSERT INTO syncapi_receipts" +
" (id, room_id, receipt_type, user_id, event_id, receipt_ts)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (room_id, receipt_type, user_id)" +
" DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9"
const selectRoomReceipts = "" +
"SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" +
" WHERE id > $1 and room_id in ($2)"
const selectMaxReceiptIDSQL = "" +
"SELECT MAX(id) FROM syncapi_receipts"
type receiptStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
type ReceiptCosmos struct {
ID int64 `json:"id"`
RoomID string `json:"room_id"`
ReceiptType string `json:"receipt_type"`
UserID string `json:"user_id"`
EventID string `json:"event_id"`
ReceiptTS int64 `json:"receipt_ts"`
}
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
_, err := db.Exec(receiptsSchema)
type ReceiptCosmosMaxNumber struct {
Max int64 `json:"number"`
}
type ReceiptCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Receipt ReceiptCosmos `json:"mx_syncapi_receipt"`
}
// const upsertReceipt = "" +
// "INSERT INTO syncapi_receipts" +
// " (id, room_id, receipt_type, user_id, event_id, receipt_ts)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (room_id, receipt_type, user_id)" +
// " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9"
// "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
// " FROM syncapi_receipts" +
// " WHERE id > $1 and room_id in ($2)"
const selectRoomReceipts = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_syncapi_receipt.id > @x2 " +
"and ARRAY_CONTAINS(@x3, c.mx_syncapi_receipt.room_id)"
// "SELECT MAX(id) FROM syncapi_receipts"
const selectMaxReceiptIDSQL = "" +
"select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 "
type receiptStatements struct {
db *SyncServerDatasource
streamIDStatements *streamIDStatements
// upsertReceipt *sql.Stmt
// selectRoomReceipts *sql.Stmt
selectMaxReceiptID string
tableName string
}
// queryReceipt runs a parameterised Cosmos document query against the
// receipts collection and decodes the matching documents.
func queryReceipt(s *receiptStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ReceiptCosmosData, error) {
	collectionName := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	partitionKey := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, collectionName)
	queryOptions := cosmosdbapi.GetQueryDocumentsOptions(partitionKey)
	docQuery := cosmosdbapi.GetQuery(qry, params)
	var docs []ReceiptCosmosData
	if _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		docQuery,
		&docs,
		queryOptions); err != nil {
		return nil, err
	}
	return docs, nil
}
// queryReceiptNumber runs an aggregate (MAX) query against the receipts
// collection and decodes the single-row result.
// (Fixed: the error path returned (nil, nil), silently swallowing query
// failures; it now propagates the error like queryReceipt does.)
func queryReceiptNumber(s *receiptStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ReceiptCosmosMaxNumber, error) {
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
	var response []ReceiptCosmosMaxNumber
	var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
	var query = cosmosdbapi.GetQuery(qry, params)
	_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		query,
		&response,
		optionsQry)
	if err != nil {
		return nil, err
	}
	return response, nil
}
// NewCosmosDBReceiptsTable wires up a tables.Receipts implementation backed
// by CosmosDB document queries.
// (Fixed: removed leftover db.Prepare blocks that referenced an undefined
// err and fields that no longer exist on receiptStatements.)
func NewCosmosDBReceiptsTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Receipts, error) {
	r := &receiptStatements{
		db:                 db,
		streamIDStatements: streamID,
	}
	r.selectMaxReceiptID = selectMaxReceiptIDSQL
	r.tableName = "receipts"
	return r, nil
}
@ -93,47 +150,115 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
if err != nil {
return
}
stmt := sqlutil.TxStmt(txn, r.upsertReceipt)
_, err = stmt.ExecContext(ctx, pos, roomId, receiptType, userId, eventId, timestamp, pos, eventId, timestamp)
// "INSERT INTO syncapi_receipts" +
// " (id, room_id, receipt_type, user_id, event_id, receipt_ts)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (room_id, receipt_type, user_id)" +
// " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9"
data := ReceiptCosmos{
ID: int64(pos),
RoomID: roomId,
ReceiptType: receiptType,
UserID: userId,
EventID: eventId,
ReceiptTS: int64(timestamp),
}
var dbCollectionName = cosmosdbapi.GetCollectionName(r.db.databaseName, r.tableName)
var pk = cosmosdbapi.GetPartitionKey(r.db.cosmosConfig.ContainerName, dbCollectionName)
// CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id)
docId := fmt.Sprintf("%s_%s_%s", roomId, receiptType, userId)
cosmosDocId := cosmosdbapi.GetDocumentId(r.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = ReceiptCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Receipt: data,
}
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(r.db.connection).CreateDocument(
ctx,
r.db.cosmosConfig.DatabaseName,
r.db.cosmosConfig.ContainerName,
dbData,
optionsCreate)
// _, err = stmt.ExecContext(ctx, pos, roomId, receiptType, userId, eventId, timestamp, pos, eventId, timestamp)
return
}
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
// "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
// " FROM syncapi_receipts" +
// " WHERE id > $1 and room_id in ($2)"
// selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
lastPos := streamPos
params := make([]interface{}, len(roomIDs)+1)
params[0] = streamPos
for k, v := range roomIDs {
params[k+1] = v
// params := make([]interface{}, len(roomIDs)+1)
// params[0] = streamPos
// for k, v := range roomIDs {
// params[k+1] = v
var dbCollectionName = cosmosdbapi.GetCollectionName(r.db.databaseName, r.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": streamPos,
"@x3": roomIDs,
}
rows, err := r.db.QueryContext(ctx, selectSQL, params...)
rows, err := queryReceipt(r, ctx, selectRoomReceipts, params)
// rows, err := r.db.QueryContext(ctx, selectSQL, params...)
if err != nil {
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
var res []api.OutputReceiptEvent
for rows.Next() {
for _, item := range rows {
r := api.OutputReceiptEvent{}
var id types.StreamPosition
err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
if err != nil {
return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
}
// err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
// if err != nil {
// return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
// }
id = types.StreamPosition(item.Receipt.ID)
r.RoomID = item.Receipt.RoomID
r.Type = item.Receipt.ReceiptType
r.UserID = item.Receipt.UserID
r.EventID = item.Receipt.EventID
r.Timestamp = gomatrixserverlib.Timestamp(item.Receipt.ReceiptTS)
res = append(res, r)
if id > lastPos {
lastPos = id
}
}
return lastPos, res, rows.Err()
return lastPos, res, nil
}
func (s *receiptStatements) SelectMaxReceiptID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
// "SELECT MAX(id) FROM syncapi_receipts"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
rows, err := queryReceiptNumber(s, ctx, s.selectMaxReceiptID, params)
// stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)
if rows != nil {
nullableID.Int64 = rows[0].Max
}
// err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}

View file

@ -18,108 +18,223 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus"
)
const sendToDeviceSchema = `
-- Stores send-to-device messages.
CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The ID that uniquely identifies this message.
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The user ID to send the message to.
user_id TEXT NOT NULL,
-- The device ID to send the message to.
device_id TEXT NOT NULL,
-- The event content JSON.
content TEXT NOT NULL
);
`
// const sendToDeviceSchema = `
// -- Stores send-to-device messages.
// CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
// -- The ID that uniquely identifies this message.
// id INTEGER PRIMARY KEY AUTOINCREMENT,
// -- The user ID to send the message to.
// user_id TEXT NOT NULL,
// -- The device ID to send the message to.
// device_id TEXT NOT NULL,
// -- The event content JSON.
// content TEXT NOT NULL
// );
// `
const insertSendToDeviceMessageSQL = `
INSERT INTO syncapi_send_to_device (user_id, device_id, content)
VALUES ($1, $2, $3)
`
// SendToDeviceCosmos is the send-to-device message payload embedded in a
// CosmosDB document. Its fields mirror the columns of the former
// syncapi_send_to_device SQLite table.
type SendToDeviceCosmos struct {
	ID       int64  `json:"id"`        // message ID allocated from the send-to-device sequence (stream position)
	UserID   string `json:"user_id"`   // user the message is addressed to
	DeviceID string `json:"device_id"` // device the message is addressed to
	Content  string `json:"content"`   // event content JSON, stored as a string and unmarshalled on read
}
const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC
`
// SendToDeviceCosmosMaxNumber receives the result of the MAX(id) Cosmos
// query; the json tag must match the "as number" projection alias used in
// selectMaxSendToDeviceIDSQL.
type SendToDeviceCosmosMaxNumber struct {
	Max int64 `json:"number"`
}
// SendToDeviceCosmosData is the full CosmosDB document wrapper: Cosmos
// bookkeeping fields plus the send-to-device payload nested under the
// "mx_syncapi_send_to_device" property (queries filter on that path).
type SendToDeviceCosmosData struct {
	Id           string             `json:"id"`    // document id, unique within the container
	Pk           string             `json:"_pk"`   // partition key
	Cn           string             `json:"_cn"`   // collection name; queries scope on c._cn
	ETag         string             `json:"_etag"` // Cosmos ETag — presumably for optimistic replace; confirm against writers
	Timestamp    int64              `json:"_ts"`   // unix seconds at creation time
	SendToDevice SendToDeviceCosmos `json:"mx_syncapi_send_to_device"`
}
// const insertSendToDeviceMessageSQL = `
// INSERT INTO syncapi_send_to_device (user_id, device_id, content)
// VALUES ($1, $2, $3)
// `
// SELECT id, user_id, device_id, content
// FROM syncapi_send_to_device
// WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
// ORDER BY id DESC
// selectSendToDeviceMessagesSQL is the CosmosDB SQL equivalent of the old
// SQLite query:
//   SELECT id, user_id, device_id, content FROM syncapi_send_to_device
//   WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
//   ORDER BY id DESC
// @x1 is the collection name; @x2-@x5 are userID, deviceID, from, to.
const selectSendToDeviceMessagesSQL = "" +
	"select * from c where c._cn = @x1 " +
	"and c.mx_syncapi_send_to_device.user_id = @x2 " +
	"and c.mx_syncapi_send_to_device.device_id = @x3 " +
	"and c.mx_syncapi_send_to_device.id > @x4 " +
	"and c.mx_syncapi_send_to_device.id <= @x5 " +
	"order by c.mx_syncapi_send_to_device.id desc "
const deleteSendToDeviceMessagesSQL = `
DELETE FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id < $3
`
// "SELECT MAX(id) FROM syncapi_send_to_device"
const selectMaxSendToDeviceIDSQL = "" +
"SELECT MAX(id) FROM syncapi_send_to_device"
"select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 "
type sendToDeviceStatements struct {
db *sql.DB
insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
db *SyncServerDatasource
// insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt string
deleteSendToDeviceMessagesStmt *sql.Stmt
selectMaxSendToDeviceIDStmt *sql.Stmt
selectMaxSendToDeviceIDStmt string
tableName string
}
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{
db: db,
}
_, err := db.Exec(sendToDeviceSchema)
func querySendToDevice(s *sendToDeviceStatements, ctx context.Context, qry string, params map[string]interface{}) ([]SendToDeviceCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []SendToDeviceCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err
return response, nil
}
func querySendToDeviceNumber(s *sendToDeviceStatements, ctx context.Context, qry string, params map[string]interface{}) ([]SendToDeviceCosmosMaxNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []SendToDeviceCosmosMaxNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, nil
}
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
return nil, err
return response, nil
}
// NewCosmosDBSendToDeviceTable prepares the CosmosDB-backed send-to-device
// table. Unlike the SQLite implementation there are no statements to
// prepare against a connection: the Cosmos query strings are stored
// directly on the struct, so construction cannot fail and the error
// return exists only to satisfy the shared constructor shape.
func NewCosmosDBSendToDeviceTable(db *SyncServerDatasource) (tables.SendToDevice, error) {
	s := &sendToDeviceStatements{
		db:                             db,
		selectSendToDeviceMessagesStmt: selectSendToDeviceMessagesSQL,
		selectMaxSendToDeviceIDStmt:    selectMaxSendToDeviceIDSQL,
		tableName:                      "send_to_device",
	}
	return s, nil
}
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (pos types.StreamPosition, err error) {
var result sql.Result
result, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
if p, err := result.LastInsertId(); err != nil {
// id INTEGER PRIMARY KEY AUTOINCREMENT,
id, err := GetNextSendToDeviceID(s, ctx)
if err != nil {
return 0, err
} else {
pos = types.StreamPosition(p)
}
pos = types.StreamPosition(id)
// INSERT INTO syncapi_send_to_device (user_id, device_id, content)
// VALUES ($1, $2, $3)
data := SendToDeviceCosmos{
ID: int64(pos),
UserID: userID,
DeviceID: deviceID,
Content: content,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// NO CONSTRAINT
docId := fmt.Sprintf("%d", pos)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = SendToDeviceCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
SendToDevice: data,
}
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
optionsCreate)
return
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
// SELECT id, user_id, device_id, content
// FROM syncapi_send_to_device
// WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
// ORDER BY id DESC
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
"@x3": deviceID,
"@x4": from,
"@x5": to,
}
rows, err := querySendToDevice(s, ctx, s.selectSendToDeviceMessagesStmt, params)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() {
for _, item := range rows {
var id types.StreamPosition
var userID, deviceID, content string
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
return
}
// if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
// logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
// return
// }
id = types.StreamPosition(item.SendToDevice.ID)
userID = item.SendToDevice.UserID
deviceID = item.SendToDevice.DeviceID
content = item.SendToDevice.Content
if id > lastPos {
lastPos = id
}
@ -128,8 +243,8 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
UserID: userID,
DeviceID: deviceID,
}
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
if jsonErr := json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
logrus.WithError(jsonErr).Errorf("Failed to unmarshal send-to-device message")
continue
}
events = append(events, event)
@ -137,7 +252,7 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
if lastPos == 0 {
lastPos = to
}
return lastPos, events, rows.Err()
return lastPos, events, err
}
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
@ -151,8 +266,21 @@ func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
// "SELECT MAX(id) FROM syncapi_send_to_device"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
rows, err := querySendToDeviceNumber(s, ctx, s.selectMaxSendToDeviceIDStmt, params)
// stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
// err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if rows != nil {
nullableID.Int64 = rows[0].Max
}
if nullableID.Valid {
id = nullableID.Int64
}

View file

@ -0,0 +1,12 @@
package cosmosdb
import (
"context"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
)
// GetNextSendToDeviceID returns the next value of the send-to-device ID
// sequence, which is kept in a dedicated Cosmos sequence document
// ("sendtodevice_seq") and replaces the SQLite AUTOINCREMENT column.
// The trailing 1 is assumed to be the increment step — TODO confirm
// against cosmosdbutil.GetNextSequence.
// NOTE(review): ctx should conventionally be the first parameter; left
// as-is to avoid breaking existing callers.
func GetNextSendToDeviceID(s *sendToDeviceStatements, ctx context.Context) (int64, error) {
	const docId = "sendtodevice_seq"
	return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
}

View file

@ -4,91 +4,108 @@ import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/syncapi/types"
)
const streamIDTableSchema = `
-- Global stream ID counter, used by other tables.
CREATE TABLE IF NOT EXISTS syncapi_stream_id (
stream_name TEXT NOT NULL PRIMARY KEY,
stream_id INT DEFAULT 0,
// const streamIDTableSchema = `
// -- Global stream ID counter, used by other tables.
// CREATE TABLE IF NOT EXISTS syncapi_stream_id (
// stream_name TEXT NOT NULL PRIMARY KEY,
// stream_id INT DEFAULT 0,
UNIQUE(stream_name)
);
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("accountdata", 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0)
ON CONFLICT DO NOTHING;
`
// UNIQUE(stream_name)
// );
// INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0)
// ON CONFLICT DO NOTHING;
// INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0)
// ON CONFLICT DO NOTHING;
// INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("accountdata", 0)
// ON CONFLICT DO NOTHING;
// INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0)
// ON CONFLICT DO NOTHING;
// `
const increaseStreamIDStmt = "" +
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1"
// const increaseStreamIDStmt = "" +
// "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1"
const selectStreamIDStmt = "" +
"SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
// const selectStreamIDStmt = "" +
// "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
type streamIDStatements struct {
db *sql.DB
increaseStreamIDStmt *sql.Stmt
selectStreamIDStmt *sql.Stmt
db *SyncServerDatasource
// increaseStreamIDStmt *sql.Stmt
// selectStreamIDStmt *sql.Stmt
tableName string
}
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
func (s *streamIDStatements) prepare(db *SyncServerDatasource) (err error) {
s.db = db
_, err = db.Exec(streamIDTableSchema)
if err != nil {
return
}
if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil {
return
}
if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil {
return
}
s.tableName = "stream_id"
return
}
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
return
const docId = "global_seq"
result, err := cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
// increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
// selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
// if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
// return
// }
// err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
if err != nil {
return -1, err
}
err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
pos = types.StreamPosition(result)
return
}
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil {
return
const docId = "receipt_seq"
result, err := cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
// increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
// selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
// if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil {
// return
// }
// err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
if err != nil {
return -1, err
}
err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
pos = types.StreamPosition(result)
return
}
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil {
return
const docId = "invite_seq"
result, err := cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
// increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
// selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
// if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil {
// return
// }
// err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos)
if err != nil {
return -1, err
}
err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos)
pos = types.StreamPosition(result)
return
}
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil {
return
const docId = "accountdata_seq"
result, err := cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
// increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
// selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
// if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil {
// return
// }
// err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
if err != nil {
return -1, err
}
err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
pos = types.StreamPosition(result)
return
}

View file

@ -16,101 +16,104 @@
package cosmosdb
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
// Import the sqlite3 package
_ "github.com/mattn/go-sqlite3"
// _ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
)
// SyncServerDatasource represents a sync server datasource which manages
// both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
streamID streamIDStatements
// db *sql.DB
writer cosmosdbutil.Writer
database cosmosdbutil.Database
cosmosdbutil.PartitionOffsetStatements
streamID streamIDStatements
connection cosmosdbapi.CosmosConnection
databaseName string
cosmosConfig cosmosdbapi.CosmosConfig
}
// NewDatabase creates a new sync server database
// nolint: gocyclo
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
var d SyncServerDatasource
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
d.writer = cosmosdbutil.NewExclusiveWriterFake()
if err := d.prepare(dbProperties); err != nil {
return nil, err
}
d.writer = sqlutil.NewExclusiveWriter()
if err = d.prepare(dbProperties); err != nil {
return nil, err
d.connection = conn
d.cosmosConfig = configCosmos
d.databaseName = "syncapi"
d.database = cosmosdbutil.Database{
Connection: conn,
CosmosConfig: configCosmos,
DatabaseName: d.databaseName,
}
return &d, nil
}
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
if err = d.PartitionOffsetStatements.Prepare(&d.database, d.writer, "syncapi"); err != nil {
return err
}
if err = d.streamID.prepare(d.db); err != nil {
if err = d.streamID.prepare(d); err != nil {
return err
}
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
accountData, err := NewCosmosDBAccountDataTable(d, &d.streamID)
if err != nil {
return err
}
events, err := NewSqliteEventsTable(d.db, &d.streamID)
events, err := NewCosmosDBEventsTable(d, &d.streamID)
if err != nil {
return err
}
roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID)
roomState, err := NewCosmosDBCurrentRoomStateTable(d, &d.streamID)
if err != nil {
return err
}
invites, err := NewSqliteInvitesTable(d.db, &d.streamID)
invites, err := NewCosmosDBInvitesTable(d, &d.streamID)
if err != nil {
return err
}
peeks, err := NewSqlitePeeksTable(d.db, &d.streamID)
peeks, err := NewCosmosDBPeeksTable(d, &d.streamID)
if err != nil {
return err
}
topology, err := NewSqliteTopologyTable(d.db)
topology, err := NewCosmosDBTopologyTable(d)
if err != nil {
return err
}
bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db)
bwExtrem, err := NewCosmosDBBackwardsExtremitiesTable(d)
if err != nil {
return err
}
sendToDevice, err := NewSqliteSendToDeviceTable(d.db)
sendToDevice, err := NewCosmosDBSendToDeviceTable(d)
if err != nil {
return err
}
filter, err := NewSqliteFilterTable(d.db)
filter, err := NewCosmosDBFilterTable(d)
if err != nil {
return err
}
receipts, err := NewSqliteReceiptsTable(d.db, &d.streamID)
receipts, err := NewCosmosDBReceiptsTable(d, &d.streamID)
if err != nil {
return err
}
memberships, err := NewSqliteMembershipsTable(d.db)
memberships, err := NewCosmosDBMembershipsTable(d)
if err != nil {
return err
}
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return err
}
d.Database = shared.Database{
DB: d.db,
DB: nil,
Writer: d.writer,
Invites: invites,
Peeks: peeks,

View file

@ -674,12 +674,17 @@ func (d *Database) GetStateDeltas(
// * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
// * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
// - Get all CURRENTLY joined rooms, and add them to 'joined' block.
txn, err := d.readOnlySnapshot(ctx)
if err != nil {
return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
// HACK: CosmosDB - Allow for DB nil
var txn *sql.Tx
succeeded := true
if d.DB != nil {
txn, err := d.readOnlySnapshot(ctx)
if err != nil {
return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
}
defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
var deltas []types.StateDelta

View file

@ -37,8 +37,9 @@ import (
// Database represents an account database
type Database struct {
sqlutil.PartitionOffsetStatements
writer sqlutil.Writer
database cosmosdbutil.Database
cosmosdbutil.PartitionOffsetStatements
writer cosmosdbutil.Writer
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
@ -56,18 +57,23 @@ type Database struct {
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
d := &Database{
serverName: serverName,
databaseName: "userapi",
connection: conn,
cosmosConfig: config,
cosmosConfig: configCosmos,
// db: db,
writer: sqlutil.NewExclusiveWriter(),
// bcryptCost: bcryptCost,
// openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
d.database = cosmosdbutil.Database{
Connection: conn,
CosmosConfig: configCosmos,
DatabaseName: d.databaseName,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
@ -80,10 +86,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
// return nil, err
// }
// partitions := sqlutil.PartitionOffsetStatements{}
// if err = partitions.Prepare(db, d.writer, "account"); err != nil {
// return nil, err
// }
if err := d.PartitionOffsetStatements.Prepare(&d.database, d.writer, "account"); err != nil {
return nil, err
}
var err error
if err = d.accounts.prepare(d, serverName); err != nil {
return nil, err

View file

@ -160,8 +160,8 @@ func getDevice(s *devicesStatements, ctx context.Context, pk string, docId strin
return &response, err
}
func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, device.ETag)
func setDevice(s *devicesStatements, ctx context.Context, device DeviceCosmosData) (*DeviceCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(device.Pk, device.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
@ -345,7 +345,7 @@ func (s *devicesStatements) updateDeviceName(
response.Device.DisplayName = *displayName
var _, exReplace = setDevice(s, ctx, pk, *response)
var _, exReplace = setDevice(s, ctx, *response)
if exReplace != nil {
return exReplace
}
@ -460,8 +460,9 @@ func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart,
}
response.Device.LastSeenTS = lastSeenTs
response.Device.LastSeenIP = ipAddr
var _, exReplace = setDevice(s, ctx, pk, *response)
var _, exReplace = setDevice(s, ctx, *response)
if exReplace != nil {
return exReplace
}