mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 08:13:09 -06:00
Implement Cosmos DB for the RoomServer Service (#5)
* - Implement Cosmos for the devices_table - Use the ConnectionString in the YAML to include the Tenant - Revert all other non implemented tables back to use SQLLite3 * - Change the Config to use "test.criticicalarc.com" Container - Add generic function GetDocumentOrNil to standardize GetDocument - Add func to return CrossPartition queries for Aggregates - Add func GetNextSequence() as generic seq generator for AutoIncrement - Add cosmosdbutil.ErrNoRows to return (emulate) sql.ErrNoRows - Add a "fake" ExclusiveWriterFake - Add standard "getXX", "setXX" and "queryXX" to all TABLE class files - Add specific Table SEQ for the Events table - Add specific Table SEQ for the Rooms table - Add specific Table SEQ for the StateSnapshot table
This commit is contained in:
parent
b696923333
commit
5d68daef80
|
|
@ -291,7 +291,7 @@ room_server:
|
|||
listen: http://localhost:7770
|
||||
connect: http://localhost:7770
|
||||
database:
|
||||
connection_string: file:roomserver.db
|
||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
@ -354,12 +354,12 @@ user_api:
|
|||
listen: http://localhost:7781
|
||||
connect: http://localhost:7781
|
||||
account_database:
|
||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
|
||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
device_database:
|
||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
|
||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
package cosmosdbapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
)
|
||||
|
||||
func GetDocumentId(tenantName string, collectionName string, id string) string {
|
||||
|
|
@ -11,4 +11,25 @@ func GetDocumentId(tenantName string, collectionName string, id string) string {
|
|||
|
||||
func GetPartitionKey(tenantName string, collectionName string) string {
|
||||
return fmt.Sprintf("%s,%s", collectionName, tenantName)
|
||||
}
|
||||
}
|
||||
|
||||
func GetDocumentOrNil(connection CosmosConnection, config CosmosConfig, ctx context.Context, partitionKey string, cosmosDocId string, dbData interface{}) error {
|
||||
var _, err = GetClient(connection).GetDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.ContainerName,
|
||||
cosmosDocId,
|
||||
GetGetDocumentOptions(partitionKey),
|
||||
&dbData,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "Resource that no longer exists" {
|
||||
dbData = nil
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,14 +6,14 @@ import (
|
|||
|
||||
func GetCreateDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
|
||||
return cosmosapi.CreateDocumentOptions{
|
||||
IsUpsert: false,
|
||||
IsUpsert: false,
|
||||
PartitionKeyValue: pk,
|
||||
}
|
||||
}
|
||||
|
||||
func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
|
||||
return cosmosapi.CreateDocumentOptions{
|
||||
IsUpsert: true,
|
||||
IsUpsert: true,
|
||||
PartitionKeyValue: pk,
|
||||
}
|
||||
}
|
||||
|
|
@ -21,8 +21,16 @@ func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
|
|||
func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions {
|
||||
return cosmosapi.QueryDocumentsOptions{
|
||||
PartitionKeyValue: pk,
|
||||
IsQuery: true,
|
||||
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
|
||||
IsQuery: true,
|
||||
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
|
||||
}
|
||||
}
|
||||
|
||||
func GetQueryAllPartitionsDocumentsOptions() cosmosapi.QueryDocumentsOptions {
|
||||
return cosmosapi.QueryDocumentsOptions{
|
||||
IsQuery: true,
|
||||
EnableCrossPartition: true,
|
||||
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -35,7 +43,7 @@ func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions {
|
|||
func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions {
|
||||
return cosmosapi.ReplaceDocumentOptions{
|
||||
PartitionKeyValue: pk,
|
||||
IfMatch: etag,
|
||||
IfMatch: etag,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
76
internal/cosmosdbutil/document_seq.go
Normal file
76
internal/cosmosdbutil/document_seq.go
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
package cosmosdbutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
)
|
||||
|
||||
type SequenceCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Value int64 `json:"_value"`
|
||||
}
|
||||
|
||||
func GetNextSequence(
|
||||
ctx context.Context,
|
||||
connection cosmosdbapi.CosmosConnection,
|
||||
config cosmosdbapi.CosmosConfig,
|
||||
serviceName string,
|
||||
tableName string,
|
||||
seqId string,
|
||||
initial int64,
|
||||
) (int64, error) {
|
||||
collName := fmt.Sprintf("%s_%s", tableName, seqId)
|
||||
dbCollectionName := cosmosdbapi.GetCollectionName(serviceName, collName)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(config.ContainerName, dbCollectionName, seqId)
|
||||
pk := cosmosDocId
|
||||
|
||||
dbData := SequenceCosmosData{}
|
||||
cosmosdbapi.GetDocumentOrNil(
|
||||
connection,
|
||||
config,
|
||||
ctx,
|
||||
pk,
|
||||
cosmosDocId,
|
||||
&dbData,
|
||||
)
|
||||
|
||||
if dbData.Id == "" {
|
||||
dbData = SequenceCosmosData{}
|
||||
dbData.Id = cosmosDocId
|
||||
dbData.Pk = pk
|
||||
dbData.Cn = dbCollectionName
|
||||
dbData.Value = initial
|
||||
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.ContainerName,
|
||||
dbData,
|
||||
optionsCreate,
|
||||
)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
} else {
|
||||
dbData.Value++
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(dbData.Pk, dbData.ETag)
|
||||
var _, _, err = cosmosdbapi.GetClient(connection).ReplaceDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.ContainerName,
|
||||
cosmosDocId,
|
||||
dbData,
|
||||
optionsReplace,
|
||||
)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
return dbData.Value, nil
|
||||
}
|
||||
12
internal/cosmosdbutil/errors.go
Normal file
12
internal/cosmosdbutil/errors.go
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
package cosmosdbutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrNoRows is returned by Scan when QueryRow doesn't return a
|
||||
// row. Used to simulate the SQL responses as its used for business logic
|
||||
var ErrNoRows = sql.ErrNoRows
|
||||
|
||||
var ErrNotImplemented = errors.New("cosmosdb: not implemented")
|
||||
77
internal/cosmosdbutil/writer_exclusive.go
Normal file
77
internal/cosmosdbutil/writer_exclusive.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package cosmosdbutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// ExclusiveWriter implements sqlutil.Writer.
|
||||
// Allows non-SQL DBs to still use the same infrastructure
|
||||
// as Matrix assumes SQL TXN
|
||||
type ExclusiveWriterFake struct {
|
||||
running atomic.Bool
|
||||
todo chan transactionWriterTaskFake
|
||||
}
|
||||
|
||||
func NewExclusiveWriterFake() sqlutil.Writer {
|
||||
return &ExclusiveWriterFake{
|
||||
todo: make(chan transactionWriterTaskFake),
|
||||
}
|
||||
}
|
||||
|
||||
// transactionWriterTask represents a specific task.
|
||||
type transactionWriterTaskFake struct {
|
||||
db *sql.DB
|
||||
txn *sql.Tx
|
||||
f func(txn *sql.Tx) error
|
||||
wait chan error
|
||||
}
|
||||
|
||||
func (w *ExclusiveWriterFake) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
||||
if w.todo == nil {
|
||||
return errors.New("not initialised")
|
||||
}
|
||||
if !w.running.Load() {
|
||||
go w.run()
|
||||
}
|
||||
task := transactionWriterTaskFake{
|
||||
db: db,
|
||||
txn: txn,
|
||||
f: f,
|
||||
wait: make(chan error, 1),
|
||||
}
|
||||
w.todo <- task
|
||||
return <-task.wait
|
||||
}
|
||||
|
||||
// run processes the tasks for a given transaction writer. Only one
|
||||
// of these goroutines will run at a time. A transaction will be
|
||||
// opened using the database object from the task and then this will
|
||||
// be passed as a parameter to the task function.
|
||||
func (w *ExclusiveWriterFake) run() {
|
||||
if !w.running.CAS(false, true) {
|
||||
return
|
||||
}
|
||||
// if tracingEnabled {
|
||||
// gid := goid()
|
||||
// goidToWriter.Store(gid, w)
|
||||
// defer goidToWriter.Delete(gid)
|
||||
// }
|
||||
|
||||
defer w.running.Store(false)
|
||||
for task := range w.todo {
|
||||
if task.db != nil && task.txn != nil {
|
||||
task.wait <- task.f(task.txn)
|
||||
// } else if task.db != nil && task.txn == nil {
|
||||
// task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
|
||||
// return task.f(txn)
|
||||
// })
|
||||
} else {
|
||||
task.wait <- task.f(nil)
|
||||
}
|
||||
close(task.wait)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/pressly/goose"
|
||||
)
|
||||
|
||||
func LoadFromGoose() {
|
||||
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
|
||||
}
|
||||
|
||||
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
|
||||
m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
|
||||
}
|
||||
|
||||
func UpAddForgottenColumn(tx *sql.Tx) error {
|
||||
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
|
||||
CREATE TABLE IF NOT EXISTS roomserver_membership (
|
||||
room_nid INTEGER NOT NULL,
|
||||
target_nid INTEGER NOT NULL,
|
||||
sender_nid INTEGER NOT NULL DEFAULT 0,
|
||||
membership_nid INTEGER NOT NULL DEFAULT 1,
|
||||
event_nid INTEGER NOT NULL DEFAULT 0,
|
||||
target_local BOOLEAN NOT NULL DEFAULT false,
|
||||
forgotten BOOLEAN NOT NULL DEFAULT false,
|
||||
UNIQUE (room_nid, target_nid)
|
||||
);
|
||||
INSERT
|
||||
INTO roomserver_membership (
|
||||
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
|
||||
) SELECT
|
||||
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
|
||||
FROM roomserver_membership_tmp
|
||||
;
|
||||
DROP TABLE roomserver_membership_tmp;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownAddForgottenColumn(tx *sql.Tx) error {
|
||||
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
|
||||
CREATE TABLE IF NOT EXISTS roomserver_membership (
|
||||
room_nid INTEGER NOT NULL,
|
||||
target_nid INTEGER NOT NULL,
|
||||
sender_nid INTEGER NOT NULL DEFAULT 0,
|
||||
membership_nid INTEGER NOT NULL DEFAULT 1,
|
||||
event_nid INTEGER NOT NULL DEFAULT 0,
|
||||
target_local BOOLEAN NOT NULL DEFAULT false,
|
||||
UNIQUE (room_nid, target_nid)
|
||||
);
|
||||
INSERT
|
||||
INTO roomserver_membership (
|
||||
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
|
||||
) SELECT
|
||||
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
|
||||
FROM roomserver_membership_tmp
|
||||
;
|
||||
DROP TABLE roomserver_membership_tmp;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -18,76 +18,152 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const eventJSONSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_event_json (
|
||||
event_nid INTEGER NOT NULL PRIMARY KEY,
|
||||
event_json TEXT NOT NULL
|
||||
);
|
||||
`
|
||||
// const eventJSONSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_event_json (
|
||||
// event_nid INTEGER NOT NULL PRIMARY KEY,
|
||||
// event_json TEXT NOT NULL
|
||||
// );
|
||||
// `
|
||||
|
||||
const insertEventJSONSQL = `
|
||||
INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
|
||||
`
|
||||
type EventJSONCosmos struct {
|
||||
EventNID int64 `json:"event_nid"`
|
||||
EventJSON []byte `json:"event_json"`
|
||||
}
|
||||
|
||||
type EventJSONCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
EventJSON EventJSONCosmos `json:"mx_roomserver_event_json"`
|
||||
}
|
||||
|
||||
// const insertEventJSONSQL = `
|
||||
// INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
|
||||
// `
|
||||
|
||||
// Bulk event JSON lookup by numeric event ID.
|
||||
// Sort by the numeric event ID.
|
||||
// This means that we can use binary search to lookup by numeric event ID.
|
||||
const bulkSelectEventJSONSQL = `
|
||||
SELECT event_nid, event_json FROM roomserver_event_json
|
||||
WHERE event_nid IN ($1)
|
||||
ORDER BY event_nid ASC
|
||||
`
|
||||
// SELECT event_nid, event_json FROM roomserver_event_json
|
||||
// WHERE event_nid IN ($1)
|
||||
// ORDER BY event_nid ASC
|
||||
const bulkSelectEventJSONSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_json.event_nid) " +
|
||||
"order by c.mx_roomserver_event_json.event_nid asc"
|
||||
|
||||
type eventJSONStatements struct {
|
||||
db *sql.DB
|
||||
insertEventJSONStmt *sql.Stmt
|
||||
bulkSelectEventJSONStmt *sql.Stmt
|
||||
db *Database
|
||||
// insertEventJSONStmt *sql.Stmt
|
||||
bulkSelectEventJSONStmt string
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
|
||||
s := &eventJSONStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventJSONSchema)
|
||||
func queryEventJSON(s *eventJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventJSONCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []EventJSONCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, shared.StatementList{
|
||||
{&s.insertEventJSONStmt, insertEventJSONSQL},
|
||||
{&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL},
|
||||
}.Prepare(db)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func NewCosmosDBEventJSONTable(db *Database) (tables.EventJSON, error) {
|
||||
s := &eventJSONStatements{
|
||||
db: db,
|
||||
}
|
||||
// _, err := db.Exec(eventJSONSchema)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertEventJSONStmt, insertEventJSONSQL},
|
||||
s.bulkSelectEventJSONStmt = bulkSelectEventJSONSQL
|
||||
// }.Prepare(db)
|
||||
s.tableName = "event_json"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *eventJSONStatements) InsertEventJSON(
|
||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
|
||||
|
||||
// _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
|
||||
// INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
docId := fmt.Sprintf("%d", eventNID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
data := EventJSONCosmos{
|
||||
EventNID: int64(eventNID),
|
||||
EventJSON: eventJSON,
|
||||
}
|
||||
|
||||
var dbData = EventJSONCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
EventJSON: data,
|
||||
}
|
||||
|
||||
//Insert OR Replace
|
||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *eventJSONStatements) BulkSelectEventJSON(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) ([]tables.EventJSONPair, error) {
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||
// SELECT event_nid, event_json FROM roomserver_event_json
|
||||
// WHERE event_nid IN ($1)
|
||||
// ORDER BY event_nid ASC
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": eventNIDs,
|
||||
}
|
||||
|
||||
response, err := queryEventJSON(s, ctx, s.bulkSelectEventJSONStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventJSON: rows.close() failed")
|
||||
|
||||
// We know that we will only get as many results as event NIDs
|
||||
// because of the unique constraint on event NIDs.
|
||||
|
|
@ -95,13 +171,11 @@ func (s *eventJSONStatements) BulkSelectEventJSON(
|
|||
// We might get fewer results than NIDs so we adjust the length of the slice before returning it.
|
||||
results := make([]tables.EventJSONPair, len(eventNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
for _, item := range response {
|
||||
result := &results[i]
|
||||
var eventNID int64
|
||||
if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.EventNID = types.EventNID(eventNID)
|
||||
result.EventNID = types.EventNID(item.EventJSON.EventNID)
|
||||
result.EventJSON = item.EventJSON.EventJSON
|
||||
i++
|
||||
}
|
||||
return results[:i], nil
|
||||
}
|
||||
|
|
|
|||
24
roomserver/storage/cosmosdb/event_seq.go
Normal file
24
roomserver/storage/cosmosdb/event_seq.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
)
|
||||
|
||||
func GetNextEventStateKeyNID(s *eventStateKeyStatements, ctx context.Context) (int64, error) {
|
||||
const docId = "eventstatekeynid_seq"
|
||||
//1 insert start at 2
|
||||
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 2)
|
||||
}
|
||||
|
||||
func GetNextEventTypeNID(s *eventTypeStatements, ctx context.Context) (int64, error) {
|
||||
const docId = "eventtypenid_seq"
|
||||
//7 inserts start at 8
|
||||
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 8)
|
||||
}
|
||||
|
||||
func GetNextEventNID(s *eventStatements, ctx context.Context) (int64, error) {
|
||||
const docId = "eventnid_seq"
|
||||
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
|
||||
}
|
||||
|
|
@ -18,96 +18,248 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const eventStateKeysSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_event_state_keys (
|
||||
event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
event_state_key TEXT NOT NULL UNIQUE
|
||||
);
|
||||
INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
|
||||
VALUES (1, '')
|
||||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
// const eventStateKeysSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_event_state_keys (
|
||||
// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
// event_state_key TEXT NOT NULL UNIQUE
|
||||
// );
|
||||
// INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
|
||||
// VALUES (1, '')
|
||||
// ON CONFLICT DO NOTHING;
|
||||
// `
|
||||
|
||||
type EventStateKeysCosmos struct {
|
||||
EventStateKeyNID int64 `json:"event_state_key_nid"`
|
||||
EventStateKey string `json:"event_state_key"`
|
||||
}
|
||||
|
||||
type EventStateKeysCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
EventStateKeys EventStateKeysCosmos `json:"mx_roomserver_event_state_keys"`
|
||||
}
|
||||
|
||||
// Same as insertEventTypeNIDSQL
|
||||
const insertEventStateKeyNIDSQL = `
|
||||
INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
|
||||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
// const insertEventStateKeyNIDSQL = `
|
||||
// INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
|
||||
// ON CONFLICT DO NOTHING;
|
||||
// `
|
||||
|
||||
const selectEventStateKeyNIDSQL = `
|
||||
SELECT event_state_key_nid FROM roomserver_event_state_keys
|
||||
WHERE event_state_key = $1
|
||||
`
|
||||
// SELECT event_state_key_nid FROM roomserver_event_state_keys
|
||||
// WHERE event_state_key = $1
|
||||
const selectEventStateKeyNIDSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and c.mx_roomserver_event_state_keys.event_state_key = @x2"
|
||||
|
||||
// Bulk lookup from string state key to numeric ID for that state key.
|
||||
// Takes an array of strings as the query parameter.
|
||||
const bulkSelectEventStateKeySQL = `
|
||||
SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
|
||||
WHERE event_state_key IN ($1)
|
||||
`
|
||||
// // Bulk lookup from string state key to numeric ID for that state key.
|
||||
// // Takes an array of strings as the query parameter.
|
||||
// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
|
||||
// WHERE event_state_key IN ($1)
|
||||
const bulkSelectEventStateKeySQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key_nid)"
|
||||
|
||||
// Bulk lookup from numeric ID to string state key for that state key.
|
||||
// Takes an array of strings as the query parameter.
|
||||
const bulkSelectEventStateKeyNIDSQL = `
|
||||
SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
|
||||
WHERE event_state_key_nid IN ($1)
|
||||
`
|
||||
// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
|
||||
// WHERE event_state_key_nid IN ($1)
|
||||
const bulkSelectEventStateKeyNIDSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key)"
|
||||
|
||||
type eventStateKeyStatements struct {
|
||||
db *sql.DB
|
||||
insertEventStateKeyNIDStmt *sql.Stmt
|
||||
selectEventStateKeyNIDStmt *sql.Stmt
|
||||
bulkSelectEventStateKeyNIDStmt *sql.Stmt
|
||||
bulkSelectEventStateKeyStmt *sql.Stmt
|
||||
db *Database
|
||||
insertEventStateKeyNIDStmt string
|
||||
selectEventStateKeyNIDStmt string
|
||||
bulkSelectEventStateKeyNIDStmt string
|
||||
bulkSelectEventStateKeyStmt string
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
|
||||
s := &eventStateKeyStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventStateKeysSchema)
|
||||
func queryEventStateKeys(s *eventStateKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventStateKeysCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []EventStateKeysCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, shared.StatementList{
|
||||
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
|
||||
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
|
||||
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
|
||||
{&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
|
||||
}.Prepare(db)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func getEventStateKeys(s *eventStateKeyStatements, ctx context.Context, pk string, docId string) (*EventStateKeysCosmosData, error) {
|
||||
response := EventStateKeysCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
&response)
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func NewCosmosDBEventStateKeysTable(db *Database) (tables.EventStateKeys, error) {
|
||||
s := &eventStateKeyStatements{
|
||||
db: db,
|
||||
}
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
|
||||
s.selectEventStateKeyNIDStmt = selectEventStateKeyNIDSQL
|
||||
s.bulkSelectEventStateKeyNIDStmt = bulkSelectEventStateKeyNIDSQL
|
||||
s.bulkSelectEventStateKeyStmt = bulkSelectEventStateKeySQL
|
||||
// }.Prepare(db)
|
||||
s.tableName = "event_state_keys"
|
||||
//Add in the initial data
|
||||
ensureEventStateKeys(s, context.Background())
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func ensureEventStateKeys(s *eventStateKeyStatements, ctx context.Context) {
|
||||
|
||||
// INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
|
||||
// VALUES (1, '')
|
||||
// ON CONFLICT DO NOTHING;
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// event_state_key TEXT NOT NULL UNIQUE
|
||||
docId := ""
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
data := EventStateKeysCosmos{
|
||||
EventStateKey: "",
|
||||
EventStateKeyNID: 1,
|
||||
}
|
||||
|
||||
// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
dbData := EventStateKeysCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
EventStateKeys: data,
|
||||
}
|
||||
|
||||
insertEventStateKeyCore(s, ctx, dbData)
|
||||
}
|
||||
|
||||
func insertEventStateKeyCore(s *eventStateKeyStatements, ctx context.Context, dbData EventStateKeysCosmosData) error {
|
||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) InsertEventStateKeyNID(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||
) (types.EventStateKeyNID, error) {
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, eventStateKey)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
// INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
|
||||
// ON CONFLICT DO NOTHING;
|
||||
if len(eventStateKey) == 0 {
|
||||
return 0, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
eventStateKeyNID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// event_state_key TEXT NOT NULL UNIQUE
|
||||
docId := eventStateKey
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
existing, _ := getEventStateKeys(s, ctx, pk, cosmosDocId)
|
||||
|
||||
var dbData EventStateKeysCosmosData
|
||||
if existing == nil {
|
||||
//Not exists, we need to create a new one with a SEQ
|
||||
eventStateKeyNIDSeq, seqErr := GetNextEventStateKeyNID(s, ctx)
|
||||
if seqErr != nil {
|
||||
return -1, seqErr
|
||||
}
|
||||
|
||||
data := EventStateKeysCosmos{
|
||||
EventStateKey: eventStateKey,
|
||||
EventStateKeyNID: eventStateKeyNIDSeq,
|
||||
}
|
||||
|
||||
// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
dbData = EventStateKeysCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
EventStateKeys: data,
|
||||
}
|
||||
} else {
|
||||
dbData.EventStateKeys = existing.EventStateKeys
|
||||
}
|
||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||
|
||||
err := insertEventStateKeyCore(s, ctx, dbData)
|
||||
|
||||
return types.EventStateKeyNID(dbData.EventStateKeys.EventStateKeyNID), err
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) SelectEventStateKeyNID(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||
) (types.EventStateKeyNID, error) {
|
||||
var eventStateKeyNID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEventStateKeyNIDStmt)
|
||||
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
|
||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||
|
||||
// SELECT event_state_key_nid FROM roomserver_event_state_keys
|
||||
// WHERE event_state_key = $1
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": eventStateKey,
|
||||
}
|
||||
|
||||
response, err := queryEventStateKeys(s, ctx, s.selectEventStateKeyNIDStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
//See storage.assignStateKeyNID()
|
||||
if len(response) == 0 {
|
||||
return 0, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return types.EventStateKeyNID(response[0].EventStateKeys.EventStateKeyNID), err
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||
|
|
@ -117,21 +269,25 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
|||
for k, v := range eventStateKeys {
|
||||
iEventStateKeys[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||
// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
|
||||
// WHERE event_state_key IN ($1)
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": eventStateKeys,
|
||||
}
|
||||
|
||||
response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyNIDStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed")
|
||||
|
||||
result := make(map[string]types.EventStateKeyNID, len(eventStateKeys))
|
||||
for rows.Next() {
|
||||
var stateKey string
|
||||
var stateKeyNID int64
|
||||
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
|
||||
for _, item := range response {
|
||||
result[item.EventStateKeys.EventStateKey] = types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
|
@ -139,25 +295,24 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
|||
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error) {
|
||||
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
|
||||
for k, v := range eventStateKeyNIDs {
|
||||
iEventStateKeyNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
|
||||
// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
|
||||
// WHERE event_state_key_nid IN ($1)
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": eventStateKeyNIDs,
|
||||
}
|
||||
|
||||
response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed")
|
||||
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
|
||||
for rows.Next() {
|
||||
var stateKey string
|
||||
var stateKeyNID int64
|
||||
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
|
||||
for _, item := range response {
|
||||
result[types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID)] = item.EventStateKeys.EventStateKey
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,30 +18,44 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const eventTypesSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_event_types (
|
||||
event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
event_type TEXT NOT NULL UNIQUE
|
||||
);
|
||||
INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
|
||||
(1, 'm.room.create'),
|
||||
(2, 'm.room.power_levels'),
|
||||
(3, 'm.room.join_rules'),
|
||||
(4, 'm.room.third_party_invite'),
|
||||
(5, 'm.room.member'),
|
||||
(6, 'm.room.redaction'),
|
||||
(7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
|
||||
`
|
||||
// const eventTypesSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_event_types (
|
||||
// event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
// event_type TEXT NOT NULL UNIQUE
|
||||
// );
|
||||
// INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
|
||||
// (1, 'm.room.create'),
|
||||
// (2, 'm.room.power_levels'),
|
||||
// (3, 'm.room.join_rules'),
|
||||
// (4, 'm.room.third_party_invite'),
|
||||
// (5, 'm.room.member'),
|
||||
// (6, 'm.room.redaction'),
|
||||
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
|
||||
// `
|
||||
|
||||
type EventTypeCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
EventType EventTypeCosmos `json:"mx_roomserver_event_type"`
|
||||
}
|
||||
|
||||
type EventTypeCosmos struct {
|
||||
EventTypeNID int64 `json:"event_type_nid"`
|
||||
EventType string `json:"event_type"`
|
||||
}
|
||||
|
||||
// Assign a new numeric event type ID.
|
||||
// The usual case is that the event type is not in the database.
|
||||
|
|
@ -56,105 +70,243 @@ const eventTypesSchema = `
|
|||
// return it. Modifying the rows will cause postgres to assign a new tuple for the
|
||||
// row even though the data doesn't change resulting in unncesssary modifications
|
||||
// to the indexes.
|
||||
const insertEventTypeNIDSQL = `
|
||||
INSERT INTO roomserver_event_types (event_type) VALUES ($1)
|
||||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
// const insertEventTypeNIDSQL = `
|
||||
// INSERT INTO roomserver_event_types (event_type) VALUES ($1)
|
||||
// ON CONFLICT DO NOTHING;
|
||||
// `
|
||||
|
||||
const insertEventTypeNIDResultSQL = `
|
||||
SELECT event_type_nid FROM roomserver_event_types
|
||||
WHERE rowid = last_insert_rowid();
|
||||
`
|
||||
// const insertEventTypeNIDResultSQL = `
|
||||
// SELECT event_type_nid FROM roomserver_event_types
|
||||
// WHERE rowid = last_insert_rowid();
|
||||
// `
|
||||
|
||||
const selectEventTypeNIDSQL = `
|
||||
SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
|
||||
`
|
||||
// const selectEventTypeNIDSQL = `
|
||||
// SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
|
||||
// `
|
||||
|
||||
// Bulk lookup from string event type to numeric ID for that event type.
|
||||
// Takes an array of strings as the query parameter.
|
||||
const bulkSelectEventTypeNIDSQL = `
|
||||
SELECT event_type, event_type_nid FROM roomserver_event_types
|
||||
WHERE event_type IN ($1)
|
||||
`
|
||||
// SELECT event_type, event_type_nid FROM roomserver_event_types
|
||||
// WHERE event_type IN ($1)
|
||||
const bulkSelectEventTypeNIDSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_type.event_type)"
|
||||
|
||||
type eventTypeStatements struct {
|
||||
db *sql.DB
|
||||
insertEventTypeNIDStmt *sql.Stmt
|
||||
insertEventTypeNIDResultStmt *sql.Stmt
|
||||
selectEventTypeNIDStmt *sql.Stmt
|
||||
bulkSelectEventTypeNIDStmt *sql.Stmt
|
||||
db *Database
|
||||
// insertEventTypeNIDStmt *sql.Stmt
|
||||
// insertEventTypeNIDResultStmt *sql.Stmt
|
||||
// selectEventTypeNIDStmt *sql.Stmt
|
||||
bulkSelectEventTypeNIDStmt string
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
|
||||
func NewCosmosDBEventTypesTable(db *Database) (tables.EventTypes, error) {
|
||||
s := &eventTypeStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventTypesSchema)
|
||||
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
|
||||
// {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL},
|
||||
// {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
|
||||
s.bulkSelectEventTypeNIDStmt = bulkSelectEventTypeNIDSQL
|
||||
// }.Prepare(db)
|
||||
s.tableName = "event_types"
|
||||
ensureEventTypes(s, context.Background())
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func queryEventTypes(s *eventTypeStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventTypeCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []EventTypeCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
|
||||
{&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL},
|
||||
{&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
|
||||
{&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL},
|
||||
}.Prepare(db)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *eventTypeStatements) InsertEventTypeNID(
|
||||
ctx context.Context, txn *sql.Tx, eventType string,
|
||||
) (types.EventTypeNID, error) {
|
||||
var eventTypeNID int64
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt)
|
||||
resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt)
|
||||
_, err := insertStmt.ExecContext(ctx, eventType)
|
||||
//We need to create a new one with a SEQ
|
||||
eventTypeNIDSeq, seqErr := GetNextEventTypeNID(s, ctx)
|
||||
if seqErr != nil {
|
||||
return -1, seqErr
|
||||
}
|
||||
|
||||
data := EventTypeCosmos{
|
||||
EventType: eventType,
|
||||
EventTypeNID: eventTypeNIDSeq,
|
||||
}
|
||||
|
||||
dbData, err := insertEventTypeCore(s, ctx, data)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
|
||||
return 0, err
|
||||
}
|
||||
if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil {
|
||||
return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err)
|
||||
|
||||
return types.EventTypeNID(dbData.EventTypeNID), err
|
||||
}
|
||||
|
||||
func insertEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType EventTypeCosmos) (*EventTypeCosmos, error) {
|
||||
// INSERT INTO roomserver_event_types (event_type) VALUES ($1)
|
||||
// ON CONFLICT DO NOTHING;
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
//Unique on eventType
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, eventType.EventType)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var dbData = EventTypeCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
EventType: eventType,
|
||||
}
|
||||
return types.EventTypeNID(eventTypeNID), err
|
||||
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
dbData, errGet := selectEventTypeCore(s, ctx, eventType.EventType)
|
||||
if errGet != nil {
|
||||
return nil, errGet
|
||||
}
|
||||
return dbData, nil
|
||||
}
|
||||
|
||||
return &dbData.EventType, err
|
||||
}
|
||||
|
||||
func ensureEventTypes(s *eventTypeStatements, ctx context.Context) error {
|
||||
// INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
|
||||
// (1, 'm.room.create'),
|
||||
// (2, 'm.room.power_levels'),
|
||||
// (3, 'm.room.join_rules'),
|
||||
// (4, 'm.room.third_party_invite'),
|
||||
// (5, 'm.room.member'),
|
||||
// (6, 'm.room.redaction'),
|
||||
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
|
||||
|
||||
// (1, 'm.room.create'),
|
||||
_, err := insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 1, EventType: "m.room.create"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// (2, 'm.room.power_levels'),
|
||||
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 2, EventType: "m.room.power_levels"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// (3, 'm.room.join_rules'),
|
||||
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 3, EventType: "m.room.join_rules"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// (4, 'm.room.third_party_invite'),
|
||||
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 4, EventType: "m.room.third_party_invite"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// (5, 'm.room.member'),
|
||||
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 5, EventType: "m.room.member"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// (6, 'm.room.redaction'),
|
||||
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 6, EventType: "m.room.redaction"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
|
||||
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 7, EventType: "m.room.history_visibility"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func selectEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType string) (*EventTypeCosmos, error) {
|
||||
var response EventTypeCosmosData
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, eventType)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
cosmosDocId,
|
||||
&response)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response.EventType, nil
|
||||
}
|
||||
|
||||
func (s *eventTypeStatements) SelectEventTypeNID(
|
||||
ctx context.Context, tx *sql.Tx, eventType string,
|
||||
) (types.EventTypeNID, error) {
|
||||
var eventTypeNID int64
|
||||
selectStmt := sqlutil.TxStmt(tx, s.selectEventTypeNIDStmt)
|
||||
err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
|
||||
return types.EventTypeNID(eventTypeNID), err
|
||||
|
||||
// SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
|
||||
|
||||
dbData, err := selectEventTypeCore(s, ctx, eventType)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return types.EventTypeNID(dbData.EventTypeNID), nil
|
||||
}
|
||||
|
||||
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
||||
ctx context.Context, eventTypes []string,
|
||||
) (map[string]types.EventTypeNID, error) {
|
||||
///////////////
|
||||
iEventTypes := make([]interface{}, len(eventTypes))
|
||||
for k, v := range eventTypes {
|
||||
iEventTypes[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventTypes)), 1)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
///////////////
|
||||
|
||||
rows, err := selectPrep.QueryContext(ctx, iEventTypes...)
|
||||
// SELECT event_type, event_type_nid FROM roomserver_event_types
|
||||
// WHERE event_type IN ($1)
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": eventTypes,
|
||||
}
|
||||
|
||||
response, err := queryEventTypes(s, ctx, s.bulkSelectEventTypeNIDStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed")
|
||||
|
||||
result := make(map[string]types.EventTypeNID, len(eventTypes))
|
||||
for rows.Next() {
|
||||
for _, item := range response {
|
||||
var eventType string
|
||||
var eventTypeNID int64
|
||||
if err := rows.Scan(&eventType, &eventTypeNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventType = item.EventType.EventType
|
||||
eventTypeNID = item.EventType.EventTypeNID
|
||||
result[eventType] = types.EventTypeNID(eventTypeNID)
|
||||
}
|
||||
return result, nil
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -18,73 +18,148 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const inviteSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_invites (
|
||||
invite_event_id TEXT PRIMARY KEY,
|
||||
room_nid INTEGER NOT NULL,
|
||||
target_nid INTEGER NOT NULL,
|
||||
sender_nid INTEGER NOT NULL DEFAULT 0,
|
||||
retired BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
invite_event_json TEXT NOT NULL
|
||||
);
|
||||
// const inviteSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_invites (
|
||||
// invite_event_id TEXT PRIMARY KEY,
|
||||
// room_nid INTEGER NOT NULL,
|
||||
// target_nid INTEGER NOT NULL,
|
||||
// sender_nid INTEGER NOT NULL DEFAULT 0,
|
||||
// retired BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
// invite_event_json TEXT NOT NULL
|
||||
// );
|
||||
|
||||
CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid)
|
||||
WHERE NOT retired;
|
||||
`
|
||||
const insertInviteEventSQL = "" +
|
||||
"INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
|
||||
" sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
// CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid)
|
||||
// WHERE NOT retired;
|
||||
// `
|
||||
|
||||
type InviteCosmos struct {
|
||||
InviteEventID string `json:"invite_event_id"`
|
||||
RoomNID int64 `json:"room_nid"`
|
||||
TargetNID int64 `json:"target_nid"`
|
||||
SenderNID int64 `json:"sender_nid"`
|
||||
Retired bool `json:"retired"`
|
||||
InviteEventJSON []byte `json:"invite_event_json"`
|
||||
}
|
||||
|
||||
type InviteCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Invite InviteCosmos `json:"mx_roomserver_invite"`
|
||||
}
|
||||
|
||||
// const insertInviteEventSQL = "" +
|
||||
// "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
|
||||
// " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
|
||||
// " ON CONFLICT DO NOTHING"
|
||||
|
||||
// "SELECT invite_event_id, sender_nid FROM roomserver_invites" +
|
||||
// " WHERE target_nid = $1 AND room_nid = $2" +
|
||||
// " AND NOT retired"
|
||||
const selectInviteActiveForUserInRoomSQL = "" +
|
||||
"SELECT invite_event_id, sender_nid FROM roomserver_invites" +
|
||||
" WHERE target_nid = $1 AND room_nid = $2" +
|
||||
" AND NOT retired"
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and c.mx_roomserver_invite.target_nid = @x2" +
|
||||
" and c.mx_roomserver_invite.room_nid = @x3" +
|
||||
" and c.mx_roomserver_invite.retired = false"
|
||||
|
||||
// Retire every active invite for a user in a room.
|
||||
// Ideally we'd know which invite events were retired by a given update so we
|
||||
// wouldn't need to remove every active invite.
|
||||
// However the matrix protocol doesn't give us a way to reliably identify the
|
||||
// invites that were retired, so we are forced to retire all of them.
|
||||
const updateInviteRetiredSQL = `
|
||||
UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
|
||||
`
|
||||
// const updateInviteRetiredSQL = `
|
||||
// UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
|
||||
// `
|
||||
|
||||
const selectInvitesAboutToRetireSQL = `
|
||||
SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
|
||||
`
|
||||
// SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
|
||||
const selectInvitesAboutToRetireSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and c.mx_roomserver_invite.room_nid = @x2" +
|
||||
" and c.mx_roomserver_invite.target_nid = @x3" +
|
||||
" and c.mx_roomserver_invite.retired = false"
|
||||
|
||||
type inviteStatements struct {
|
||||
db *sql.DB
|
||||
insertInviteEventStmt *sql.Stmt
|
||||
selectInviteActiveForUserInRoomStmt *sql.Stmt
|
||||
updateInviteRetiredStmt *sql.Stmt
|
||||
selectInvitesAboutToRetireStmt *sql.Stmt
|
||||
db *Database
|
||||
// insertInviteEventStmt *sql.Stmt
|
||||
selectInviteActiveForUserInRoomStmt string
|
||||
// updateInviteRetiredStmt *sql.Stmt
|
||||
selectInvitesAboutToRetireStmt string
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
|
||||
s := &inviteStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(inviteSchema)
|
||||
func queryInvite(s *inviteStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []InviteCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertInviteEventStmt, insertInviteEventSQL},
|
||||
{&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL},
|
||||
{&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
|
||||
{&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL},
|
||||
}.Prepare(db)
|
||||
func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*InviteCosmosData, error) {
|
||||
response := InviteCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
&response)
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func setInvite(s *inviteStatements, ctx context.Context, invite InviteCosmosData) (*InviteCosmosData, error) {
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
invite.Id,
|
||||
&invite,
|
||||
optionsReplace)
|
||||
return &invite, ex
|
||||
}
|
||||
|
||||
func NewCosmosDBInvitesTable(db *Database) (tables.Invites, error) {
|
||||
s := &inviteStatements{
|
||||
db: db,
|
||||
}
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertInviteEventStmt, insertInviteEventSQL},
|
||||
s.selectInviteActiveForUserInRoomStmt = selectInviteActiveForUserInRoomSQL
|
||||
// {&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
|
||||
s.selectInvitesAboutToRetireStmt = selectInvitesAboutToRetireSQL
|
||||
// }.Prepare(db)
|
||||
s.tableName = "invites"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *inviteStatements) InsertInviteEvent(
|
||||
|
|
@ -93,42 +168,84 @@ func (s *inviteStatements) InsertInviteEvent(
|
|||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||
inviteEventJSON []byte,
|
||||
) (bool, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
|
||||
result, err := stmt.ExecContext(
|
||||
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||
)
|
||||
|
||||
// "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
|
||||
// " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
|
||||
// " ON CONFLICT DO NOTHING"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
data := InviteCosmos{
|
||||
InviteEventID: inviteEventID,
|
||||
InviteEventJSON: inviteEventJSON,
|
||||
Retired: false,
|
||||
RoomNID: int64(roomNID),
|
||||
SenderNID: int64(senderUserNID),
|
||||
TargetNID: int64(targetUserNID),
|
||||
}
|
||||
|
||||
// invite_event_id TEXT PRIMARY KEY,
|
||||
docId := inviteEventID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var dbData = InviteCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Invite: data,
|
||||
}
|
||||
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
count, err = result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count != 0, err
|
||||
// TODO: Is this important?
|
||||
// count, err = result.RowsAffected()
|
||||
// return count != 0, err
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *inviteStatements) UpdateInviteRetired(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventIDs []string, err error) {
|
||||
|
||||
// "SELECT invite_event_id, sender_nid FROM roomserver_invites" +
|
||||
// " WHERE target_nid = $1 AND room_nid = $2" +
|
||||
// " AND NOT retired"
|
||||
|
||||
// gather all the event IDs we will retire
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": targetUserNID,
|
||||
"@x3": roomNID,
|
||||
}
|
||||
|
||||
response, err := queryInvite(s, ctx, s.selectInvitesAboutToRetireStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var inviteEventID string
|
||||
if err = rows.Scan(&inviteEventID); err != nil {
|
||||
return
|
||||
}
|
||||
eventIDs = append(eventIDs, inviteEventID)
|
||||
|
||||
for _, item := range response {
|
||||
eventIDs = append(eventIDs, item.Invite.InviteEventID)
|
||||
// UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
|
||||
|
||||
// now retire the invites
|
||||
item.Invite.Retired = true
|
||||
_, err = setInvite(s, ctx, item)
|
||||
}
|
||||
// now retire the invites
|
||||
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -137,21 +254,27 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
|||
ctx context.Context,
|
||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||
) ([]types.EventStateKeyNID, []string, error) {
|
||||
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
||||
ctx, targetUserNID, roomNID,
|
||||
)
|
||||
|
||||
// SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomNID,
|
||||
"@x3": targetUserNID,
|
||||
}
|
||||
|
||||
response, err := queryInvite(s, ctx, s.selectInviteActiveForUserInRoomStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
|
||||
|
||||
var result []types.EventStateKeyNID
|
||||
var eventIDs []string
|
||||
for rows.Next() {
|
||||
var eventID string
|
||||
var senderUserNID int64
|
||||
if err := rows.Scan(&eventID, &senderUserNID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, item := range response {
|
||||
var eventID = item.Invite.InviteEventID
|
||||
var senderUserNID = item.Invite.SenderNID
|
||||
result = append(result, types.EventStateKeyNID(senderUserNID))
|
||||
eventIDs = append(eventIDs, eventID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,125 +19,233 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const membershipSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_membership (
|
||||
room_nid INTEGER NOT NULL,
|
||||
target_nid INTEGER NOT NULL,
|
||||
sender_nid INTEGER NOT NULL DEFAULT 0,
|
||||
membership_nid INTEGER NOT NULL DEFAULT 1,
|
||||
event_nid INTEGER NOT NULL DEFAULT 0,
|
||||
target_local BOOLEAN NOT NULL DEFAULT false,
|
||||
forgotten BOOLEAN NOT NULL DEFAULT false,
|
||||
UNIQUE (room_nid, target_nid)
|
||||
);
|
||||
`
|
||||
// const membershipSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_membership (
|
||||
// room_nid INTEGER NOT NULL,
|
||||
// target_nid INTEGER NOT NULL,
|
||||
// sender_nid INTEGER NOT NULL DEFAULT 0,
|
||||
// membership_nid INTEGER NOT NULL DEFAULT 1,
|
||||
// event_nid INTEGER NOT NULL DEFAULT 0,
|
||||
// target_local BOOLEAN NOT NULL DEFAULT false,
|
||||
// forgotten BOOLEAN NOT NULL DEFAULT false,
|
||||
// UNIQUE (room_nid, target_nid)
|
||||
// );
|
||||
// `
|
||||
|
||||
type MembershipCosmos struct {
|
||||
RoomNID int64 `json:"room_nid"`
|
||||
TargetNID int64 `json:"target_nid"`
|
||||
SenderNID int64 `json:"sender_nid"`
|
||||
MembershipNID int64 `json:"membership_nid"`
|
||||
EventNID int64 `json:"event_nid"`
|
||||
TargetLocal bool `json:"target_local"`
|
||||
Forgotten bool `json:"forgotten"`
|
||||
}
|
||||
|
||||
type MembershipCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Membership MembershipCosmos `json:"mx_roomserver_membership"`
|
||||
}
|
||||
|
||||
type MembershipJoinedCountCosmosData struct {
|
||||
TargetNID int64 `json:"target_nid"`
|
||||
RoomCount int `json:"room_count"`
|
||||
}
|
||||
|
||||
// "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
|
||||
// " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
// " GROUP BY target_nid"
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
"select c.mx_roomserver_membership.target_nid, count(c.mx_roomserver_membership.room_id) as room_count from c where c._cn = @x1 " +
|
||||
" and ARRAY_CONTAINS(@x2, c.mx_roomserver_membership.room_id)" +
|
||||
" and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||
" and c.mx_roomserver_membership.forgotten = false" +
|
||||
" group by c.mx_roomserver_membership.target_nid"
|
||||
|
||||
// Insert a row in to membership table so that it can be locked by the
|
||||
// SELECT FOR UPDATE
|
||||
const insertMembershipSQL = "" +
|
||||
"INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
|
||||
" VALUES ($1, $2, $3)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
// const insertMembershipSQL = "" +
|
||||
// "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
|
||||
// " VALUES ($1, $2, $3)" +
|
||||
// " ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectMembershipFromRoomAndTargetSQL = "" +
|
||||
"SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1 AND target_nid = $2"
|
||||
// const selectMembershipFromRoomAndTargetSQL = "" +
|
||||
// "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1 AND target_nid = $2"
|
||||
|
||||
// "SELECT event_nid FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
|
||||
const selectMembershipsFromRoomAndMembershipSQL = "" +
|
||||
"SELECT event_nid FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and c.mx_roomserver_membership.room_nid = @x2" +
|
||||
" and c.mx_roomserver_membership.membership_nid = @x3" +
|
||||
" and c.mx_roomserver_membership.forgotten = false"
|
||||
|
||||
// "SELECT event_nid FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1 AND membership_nid = $2" +
|
||||
// " AND target_local = true and forgotten = false"
|
||||
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
|
||||
"SELECT event_nid FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1 AND membership_nid = $2" +
|
||||
" AND target_local = true and forgotten = false"
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and c.mx_roomserver_membership.room_nid = @x2" +
|
||||
" and c.mx_roomserver_membership.membership_nid = @x3" +
|
||||
" and c.mx_roomserver_membership.target_local = true" +
|
||||
" and c.mx_roomserver_membership.forgotten = false"
|
||||
|
||||
// "SELECT event_nid FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1 and forgotten = false"
|
||||
const selectMembershipsFromRoomSQL = "" +
|
||||
"SELECT event_nid FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1 and forgotten = false"
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and c.mx_roomserver_membership.room_nid = @x2" +
|
||||
" and c.mx_roomserver_membership.forgotten = false"
|
||||
|
||||
// "SELECT event_nid FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1" +
|
||||
// " AND target_local = true and forgotten = false"
|
||||
const selectLocalMembershipsFromRoomSQL = "" +
|
||||
"SELECT event_nid FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1" +
|
||||
" AND target_local = true and forgotten = false"
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and c.mx_roomserver_membership.room_nid = @x2" +
|
||||
" and c.mx_roomserver_membership.target_local = true" +
|
||||
" and c.mx_roomserver_membership.forgotten = false"
|
||||
|
||||
const selectMembershipForUpdateSQL = "" +
|
||||
"SELECT membership_nid FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1 AND target_nid = $2"
|
||||
// const selectMembershipForUpdateSQL = "" +
|
||||
// "SELECT membership_nid FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1 AND target_nid = $2"
|
||||
|
||||
const updateMembershipSQL = "" +
|
||||
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
|
||||
" WHERE room_nid = $5 AND target_nid = $6"
|
||||
// const updateMembershipSQL = "" +
|
||||
// "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
|
||||
// " WHERE room_nid = $5 AND target_nid = $6"
|
||||
|
||||
const updateMembershipForgetRoom = "" +
|
||||
"UPDATE roomserver_membership SET forgotten = $1" +
|
||||
" WHERE room_nid = $2 AND target_nid = $3"
|
||||
// const updateMembershipForgetRoom = "" +
|
||||
// "UPDATE roomserver_membership SET forgotten = $1" +
|
||||
// " WHERE room_nid = $2 AND target_nid = $3"
|
||||
|
||||
// "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
||||
const selectRoomsWithMembershipSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and c.mx_roomserver_membership.membership_nid = @x2" +
|
||||
" and c.mx_roomserver_membership.target_nid = true" +
|
||||
" and c.mx_roomserver_membership.forgotten = false"
|
||||
|
||||
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||
// joined to. Since this information is used to populate the user directory, we will
|
||||
// only return users that the user would ordinarily be able to see anyway.
|
||||
var selectKnownUsersSQL = "" +
|
||||
"SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
|
||||
"roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||
" WHERE room_nid IN (" +
|
||||
" SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||
") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
|
||||
// var selectKnownUsersSQL = "" +
|
||||
// "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
|
||||
// "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||
// " WHERE room_nid IN (" +
|
||||
// " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||
// ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
|
||||
|
||||
var selectKnownUsersSQLRooms = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_membership.room_id)"
|
||||
|
||||
var selectKnownUsersSQLDistinctRoom = "" +
|
||||
"select distinct top @x4 c.mx_roomserver_membership.room_nid as room_nid from c where c._cn = @x1 " +
|
||||
"and c.mx_roomserver_membership.target_nid = @x2 " +
|
||||
"and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " " +
|
||||
"and contains(c.mx_roomserver_membership.event_state_key, @x3) "
|
||||
|
||||
type membershipStatements struct {
|
||||
db *sql.DB
|
||||
insertMembershipStmt *sql.Stmt
|
||||
selectMembershipForUpdateStmt *sql.Stmt
|
||||
selectMembershipFromRoomAndTargetStmt *sql.Stmt
|
||||
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
|
||||
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
|
||||
selectMembershipsFromRoomStmt *sql.Stmt
|
||||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||
selectRoomsWithMembershipStmt *sql.Stmt
|
||||
updateMembershipStmt *sql.Stmt
|
||||
selectKnownUsersStmt *sql.Stmt
|
||||
updateMembershipForgetRoomStmt *sql.Stmt
|
||||
db *Database
|
||||
// insertMembershipStmt *sql.Stmt
|
||||
// selectMembershipForUpdateStmt string
|
||||
// selectMembershipFromRoomAndTargetStmt string
|
||||
selectMembershipsFromRoomAndMembershipStmt string
|
||||
selectLocalMembershipsFromRoomAndMembershipStmt string
|
||||
selectMembershipsFromRoomStmt string
|
||||
selectLocalMembershipsFromRoomStmt string
|
||||
selectRoomsWithMembershipStmt string
|
||||
// updateMembershipStmt *sql.Stmt
|
||||
// selectKnownUsersStmt string
|
||||
// updateMembershipForgetRoomStmt *sql.Stmt
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||
func queryMembership(s *membershipStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MembershipCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []MembershipCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func getMembership(s *membershipStatements, ctx context.Context, pk string, docId string) (*MembershipCosmosData, error) {
|
||||
response := MembershipCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
&response)
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func setMembership(s *membershipStatements, ctx context.Context, pk string, membership MembershipCosmosData) (*MembershipCosmosData, error) {
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, membership.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
membership.Id,
|
||||
&membership,
|
||||
optionsReplace)
|
||||
return &membership, ex
|
||||
}
|
||||
|
||||
func NewCosmosDBMembershipTable(db *Database) (tables.Membership, error) {
|
||||
s := &membershipStatements{
|
||||
db: db,
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertMembershipStmt, insertMembershipSQL},
|
||||
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
|
||||
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
|
||||
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
|
||||
{&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
|
||||
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
||||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
||||
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
|
||||
}.Prepare(db)
|
||||
}
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertMembershipStmt, insertMembershipSQL},
|
||||
// s.selectMembershipForUpdateStmt = selectMembershipForUpdateSQL
|
||||
// s.selectMembershipFromRoomAndTargetStmt = selectMembershipFromRoomAndTargetSQL
|
||||
s.selectMembershipsFromRoomAndMembershipStmt = selectMembershipsFromRoomAndMembershipSQL
|
||||
s.selectLocalMembershipsFromRoomAndMembershipStmt = selectLocalMembershipsFromRoomAndMembershipSQL
|
||||
s.selectMembershipsFromRoomStmt = selectMembershipsFromRoomSQL
|
||||
s.selectLocalMembershipsFromRoomStmt = selectLocalMembershipsFromRoomSQL
|
||||
// {&s.updateMembershipStmt, updateMembershipSQL},
|
||||
s.selectRoomsWithMembershipStmt = selectRoomsWithMembershipSQL
|
||||
// {&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
||||
// {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
|
||||
// }.Prepare(db)
|
||||
|
||||
func (s *membershipStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(membershipSchema)
|
||||
return err
|
||||
s.tableName = "memberships"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) InsertMembership(
|
||||
|
|
@ -145,8 +253,45 @@ func (s *membershipStatements) InsertMembership(
|
|||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
localTarget bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
|
||||
|
||||
// "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
|
||||
// " VALUES ($1, $2, $3)" +
|
||||
// " ON CONFLICT DO NOTHING"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
// UNIQUE (room_nid, target_nid)
|
||||
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
data := MembershipCosmos{
|
||||
EventNID: 0,
|
||||
Forgotten: false,
|
||||
MembershipNID: 1,
|
||||
RoomNID: int64(roomNID),
|
||||
SenderNID: 0,
|
||||
TargetLocal: false,
|
||||
TargetNID: int64(targetUserNID),
|
||||
}
|
||||
|
||||
var dbData = MembershipCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Membership: data,
|
||||
}
|
||||
|
||||
// " ON CONFLICT DO NOTHING"
|
||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -154,10 +299,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (membership tables.MembershipState, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt)
|
||||
err = stmt.QueryRowContext(
|
||||
ctx, roomNID, targetUserNID,
|
||||
).Scan(&membership)
|
||||
|
||||
// "SELECT membership_nid FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1 AND target_nid = $2"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response, err := getMembership(s, ctx, pk, cosmosDocId)
|
||||
if response != nil {
|
||||
membership = tables.MembershipState(response.Membership.MembershipNID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -165,9 +318,20 @@ func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
|||
ctx context.Context,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
|
||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
||||
ctx, roomNID, targetUserNID,
|
||||
).Scan(&membership, &eventNID, &forgotten)
|
||||
|
||||
// "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1 AND target_nid = $2"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response, err := getMembership(s, ctx, pk, cosmosDocId)
|
||||
if response != nil {
|
||||
eventNID = types.EventNID(response.Membership.EventNID)
|
||||
forgotten = response.Membership.Forgotten
|
||||
membership = tables.MembershipState(response.Membership.MembershipNID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -175,24 +339,31 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
|||
ctx context.Context,
|
||||
roomNID types.RoomNID, localOnly bool,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
var selectStmt *sql.Stmt
|
||||
var selectStmt string
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomNID,
|
||||
}
|
||||
if localOnly {
|
||||
// "SELECT event_nid FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1" +
|
||||
// " AND target_local = true and forgotten = false"
|
||||
selectStmt = s.selectLocalMembershipsFromRoomStmt
|
||||
|
||||
} else {
|
||||
// "SELECT event_nid FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1 and forgotten = false"
|
||||
selectStmt = s.selectMembershipsFromRoomStmt
|
||||
}
|
||||
rows, err := selectStmt.QueryContext(ctx, roomNID)
|
||||
response, err := queryMembership(s, ctx, selectStmt, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoom: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var eNID types.EventNID
|
||||
if err = rows.Scan(&eNID); err != nil {
|
||||
return
|
||||
}
|
||||
eventNIDs = append(eventNIDs, eNID)
|
||||
for _, item := range response {
|
||||
eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -201,24 +372,31 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
|||
ctx context.Context,
|
||||
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
var stmt *sql.Stmt
|
||||
var stmt string
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomNID,
|
||||
"@x3": membership,
|
||||
}
|
||||
if localOnly {
|
||||
// "SELECT event_nid FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1 AND membership_nid = $2" +
|
||||
// " AND target_local = true and forgotten = false"
|
||||
stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt
|
||||
} else {
|
||||
// "SELECT event_nid FROM roomserver_membership" +
|
||||
// " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
|
||||
stmt = s.selectMembershipsFromRoomAndMembershipStmt
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, membership)
|
||||
response, err := queryMembership(s, ctx, stmt, params)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoomAndMembership: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var eNID types.EventNID
|
||||
if err = rows.Scan(&eNID); err != nil {
|
||||
return
|
||||
}
|
||||
eventNIDs = append(eventNIDs, eNID)
|
||||
for _, item := range response {
|
||||
eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -228,28 +406,48 @@ func (s *membershipStatements) UpdateMembership(
|
|||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
||||
eventNID types.EventNID, forgotten bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID,
|
||||
)
|
||||
|
||||
// "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
|
||||
// " WHERE room_nid = $5 AND target_nid = $6"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
dbData, err := getMembership(s, ctx, pk, cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbData.Membership.SenderNID = int64(senderUserNID)
|
||||
dbData.Membership.MembershipNID = int64(membership)
|
||||
dbData.Membership.EventNID = int64(eventNID)
|
||||
dbData.Membership.Forgotten = forgotten
|
||||
|
||||
_, err = setMembership(s, ctx, pk, *dbData)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||
|
||||
// "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": membershipState,
|
||||
"@x3": userID,
|
||||
}
|
||||
response, err := queryMembership(s, ctx, s.selectRoomsWithMembershipStmt, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
|
||||
var roomNIDs []types.RoomNID
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
if err := rows.Scan(&roomNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomNIDs = append(roomNIDs, roomNID)
|
||||
for _, item := range response {
|
||||
roomNIDs = append(roomNIDs, types.RoomNID(item.Membership.RoomNID))
|
||||
}
|
||||
return roomNIDs, nil
|
||||
}
|
||||
|
|
@ -259,39 +457,136 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
|
|||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
}
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
|
||||
|
||||
// "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
|
||||
// " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
// " GROUP BY target_nid"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomNIDs,
|
||||
}
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []MembershipJoinedCountCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectJoinedUsersSetForRoomsSQL, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
||||
|
||||
result := make(map[types.EventStateKeyNID]int)
|
||||
for rows.Next() {
|
||||
var userID types.EventStateKeyNID
|
||||
var count int
|
||||
if err := rows.Scan(&userID, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, item := range response {
|
||||
userID := types.EventStateKeyNID(item.TargetNID)
|
||||
count := item.RoomCount
|
||||
result[userID] = count
|
||||
}
|
||||
return result, rows.Err()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||
|
||||
// " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||
// ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": userID,
|
||||
"@x3": searchString,
|
||||
"@x4": limit,
|
||||
}
|
||||
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var responseDistinctRoom []MembershipCosmos
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectKnownUsersSQLDistinctRoom, params) //
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&responseDistinctRoom,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rooms := []int64{}
|
||||
for _, item := range responseDistinctRoom {
|
||||
rooms = append(rooms, item.RoomNID)
|
||||
}
|
||||
|
||||
// "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
|
||||
// "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||
// " WHERE room_nid IN (" +
|
||||
|
||||
params = map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": rooms,
|
||||
}
|
||||
|
||||
var responseRooms []MembershipCosmos
|
||||
query = cosmosdbapi.GetQuery(selectKnownUsersSQLRooms, params)
|
||||
_, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&responseRooms,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetNIDs := []int64{}
|
||||
for _, item := range responseRooms {
|
||||
targetNIDs = append(targetNIDs, item.TargetNID)
|
||||
}
|
||||
|
||||
// HACK: Joined table
|
||||
var dbCollectionNameEventStateKeys = cosmosdbapi.GetCollectionName(s.db.databaseName, "event_state_keys")
|
||||
params = map[string]interface{}{
|
||||
"@x1": dbCollectionNameEventStateKeys,
|
||||
"@x2": targetNIDs,
|
||||
}
|
||||
|
||||
bulkSelectEventStateKeyStmt := "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key_nid)"
|
||||
|
||||
var responseEventStateKeys []EventStateKeysCosmos
|
||||
query = cosmosdbapi.GetQuery(bulkSelectEventStateKeyStmt, params)
|
||||
_, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&responseEventStateKeys,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// SELECT DISTINCT event_state_key
|
||||
|
||||
result := []string{}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, item := range responseEventStateKeys {
|
||||
userID := item.EventStateKey
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result, rows.Err()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) UpdateForgetMembership(
|
||||
|
|
@ -299,8 +594,22 @@ func (s *membershipStatements) UpdateForgetMembership(
|
|||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
forget bool,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
|
||||
ctx, forget, roomNID, targetUserNID,
|
||||
)
|
||||
|
||||
// "UPDATE roomserver_membership SET forgotten = $1" +
|
||||
// " WHERE room_nid = $2 AND target_nid = $3"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
dbData, err := getMembership(s, ctx, pk, cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbData.Membership.Forgotten = forget
|
||||
|
||||
_, err = setMembership(s, ctx, pk, *dbData)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,9 +20,12 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
|
@ -32,59 +35,90 @@ import (
|
|||
// In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it
|
||||
// seems to care that it's empty and therefore hits a NOT NULL constraint on insert.
|
||||
// We should really work out what the right thing to do here is.
|
||||
const previousEventSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_previous_events (
|
||||
previous_event_id TEXT NOT NULL,
|
||||
previous_reference_sha256 BLOB,
|
||||
event_nids TEXT NOT NULL,
|
||||
UNIQUE (previous_event_id, previous_reference_sha256)
|
||||
);
|
||||
`
|
||||
// const previousEventSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_previous_events (
|
||||
// previous_event_id TEXT NOT NULL,
|
||||
// previous_reference_sha256 BLOB,
|
||||
// event_nids TEXT NOT NULL,
|
||||
// UNIQUE (previous_event_id, previous_reference_sha256)
|
||||
// );
|
||||
// `
|
||||
|
||||
type PreviousEventCosmos struct {
|
||||
PreviousEventID string `json:"previous_event_id"`
|
||||
PreviousReferenceSha256 []byte `json:"previous_reference_sha256"`
|
||||
EventNIDs string `json:"event_nids"`
|
||||
}
|
||||
|
||||
type PreviousEventCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
PreviousEvent PreviousEventCosmos `json:"mx_roomserver_previous_event"`
|
||||
}
|
||||
|
||||
// Insert an entry into the previous_events table.
|
||||
// If there is already an entry indicating that an event references that previous event then
|
||||
// add the event NID to the list to indicate that this event references that previous event as well.
|
||||
// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
|
||||
// The lock is necessary to avoid data races when checking whether an event is already referenced by another event.
|
||||
const insertPreviousEventSQL = `
|
||||
INSERT OR REPLACE INTO roomserver_previous_events
|
||||
(previous_event_id, previous_reference_sha256, event_nids)
|
||||
VALUES ($1, $2, $3)
|
||||
`
|
||||
// const insertPreviousEventSQL = `
|
||||
// INSERT OR REPLACE INTO roomserver_previous_events
|
||||
// (previous_event_id, previous_reference_sha256, event_nids)
|
||||
// VALUES ($1, $2, $3)
|
||||
// `
|
||||
|
||||
const selectPreviousEventNIDsSQL = `
|
||||
SELECT event_nids FROM roomserver_previous_events
|
||||
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
`
|
||||
// const selectPreviousEventNIDsSQL = `
|
||||
// SELECT event_nids FROM roomserver_previous_events
|
||||
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
// `
|
||||
|
||||
// Check if the event is referenced by another event in the table.
|
||||
// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
|
||||
const selectPreviousEventExistsSQL = `
|
||||
SELECT 1 FROM roomserver_previous_events
|
||||
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
`
|
||||
// const selectPreviousEventExistsSQL = `
|
||||
// SELECT 1 FROM roomserver_previous_events
|
||||
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
// `
|
||||
|
||||
type previousEventStatements struct {
|
||||
db *sql.DB
|
||||
insertPreviousEventStmt *sql.Stmt
|
||||
selectPreviousEventNIDsStmt *sql.Stmt
|
||||
selectPreviousEventExistsStmt *sql.Stmt
|
||||
db *Database
|
||||
// insertPreviousEventStmt *sql.Stmt
|
||||
// selectPreviousEventNIDsStmt *sql.Stmt
|
||||
// selectPreviousEventExistsStmt *sql.Stmt
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
|
||||
func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*PreviousEventCosmosData, error) {
|
||||
response := PreviousEventCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
&response)
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func NewCosmosDBPrevEventsTable(db *Database) (tables.PreviousEvents, error) {
|
||||
s := &previousEventStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(previousEventSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
|
||||
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
|
||||
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
|
||||
}.Prepare(db)
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertPreviousEventStmt, insertPreviousEventSQL},
|
||||
// {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
|
||||
// {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
|
||||
// }.Prepare(db)
|
||||
s.tableName = "previous_events"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *previousEventStatements) InsertPreviousEvent(
|
||||
|
|
@ -94,28 +128,71 @@ func (s *previousEventStatements) InsertPreviousEvent(
|
|||
previousEventReferenceSHA256 []byte,
|
||||
eventNID types.EventNID,
|
||||
) error {
|
||||
var eventNIDs string
|
||||
eventNIDAsString := fmt.Sprintf("%d", eventNID)
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
|
||||
err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
// UNIQUE (previous_event_id, previous_reference_sha256)
|
||||
// TODO: Check value
|
||||
// docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256)
|
||||
docId := previousEventID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
// SELECT 1 FROM roomserver_previous_events
|
||||
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
existing, err := getPreviousEvent(s, ctx, pk, cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
if err != cosmosdbutil.ErrNoRows {
|
||||
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var dbData PreviousEventCosmosData
|
||||
// Doesnt exist, create a new one
|
||||
if existing == nil {
|
||||
data := PreviousEventCosmos{
|
||||
EventNIDs: "",
|
||||
PreviousEventID: previousEventID,
|
||||
PreviousReferenceSha256: previousEventReferenceSHA256,
|
||||
}
|
||||
|
||||
dbData = PreviousEventCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
PreviousEvent: data,
|
||||
}
|
||||
} else {
|
||||
dbData = *existing
|
||||
}
|
||||
|
||||
var nids []string
|
||||
if eventNIDs != "" {
|
||||
nids = strings.Split(eventNIDs, ",")
|
||||
if dbData.PreviousEvent.EventNIDs != "" {
|
||||
nids = strings.Split(dbData.PreviousEvent.EventNIDs, ",")
|
||||
for _, nid := range nids {
|
||||
if nid == eventNIDAsString {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
eventNIDs = strings.Join(append(nids, eventNIDAsString), ",")
|
||||
dbData.PreviousEvent.EventNIDs = strings.Join(append(nids, eventNIDAsString), ",")
|
||||
} else {
|
||||
eventNIDs = eventNIDAsString
|
||||
dbData.PreviousEvent.EventNIDs = eventNIDAsString
|
||||
}
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
||||
_, err = insertStmt.ExecContext(
|
||||
ctx, previousEventID, previousEventReferenceSHA256, eventNIDs,
|
||||
|
||||
// INSERT OR REPLACE INTO roomserver_previous_events
|
||||
// (previous_event_id, previous_reference_sha256, event_nids)
|
||||
// VALUES ($1, $2, $3)
|
||||
|
||||
var optionsReplace = cosmosdbapi.GetUpsertDocumentOptions(pk)
|
||||
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
optionsReplace,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
|
@ -125,7 +202,24 @@ func (s *previousEventStatements) InsertPreviousEvent(
|
|||
func (s *previousEventStatements) SelectPreviousEventExists(
|
||||
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
|
||||
) error {
|
||||
var ok int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
|
||||
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
// UNIQUE (previous_event_id, previous_reference_sha256)
|
||||
// TODO: Check value
|
||||
// docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256)
|
||||
docId := eventID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, string(docId))
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
// SELECT 1 FROM roomserver_previous_events
|
||||
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
dbData, err := getPreviousEvent(s, ctx, pk, cosmosDocId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if dbData == nil {
|
||||
return cosmosdbutil.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,89 +17,199 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
|
||||
const publishedSchema = `
|
||||
-- Stores which rooms are published in the room directory
|
||||
CREATE TABLE IF NOT EXISTS roomserver_published (
|
||||
-- The room ID of the room
|
||||
room_id TEXT NOT NULL PRIMARY KEY,
|
||||
-- Whether it is published or not
|
||||
published BOOLEAN NOT NULL DEFAULT false
|
||||
);
|
||||
`
|
||||
// const publishedSchema = `
|
||||
// -- Stores which rooms are published in the room directory
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_published (
|
||||
// -- The room ID of the room
|
||||
// room_id TEXT NOT NULL PRIMARY KEY,
|
||||
// -- Whether it is published or not
|
||||
// published BOOLEAN NOT NULL DEFAULT false
|
||||
// );
|
||||
// `
|
||||
|
||||
const upsertPublishedSQL = "" +
|
||||
"INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
|
||||
|
||||
const selectAllPublishedSQL = "" +
|
||||
"SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
|
||||
|
||||
const selectPublishedSQL = "" +
|
||||
"SELECT published FROM roomserver_published WHERE room_id = $1"
|
||||
|
||||
type publishedStatements struct {
|
||||
db *sql.DB
|
||||
upsertPublishedStmt *sql.Stmt
|
||||
selectAllPublishedStmt *sql.Stmt
|
||||
selectPublishedStmt *sql.Stmt
|
||||
type PublishCosmos struct {
|
||||
RoomID string `json:"room_id"`
|
||||
Published bool `json:"published"`
|
||||
}
|
||||
|
||||
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
|
||||
s := &publishedStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(publishedSchema)
|
||||
type PublishCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Publish PublishCosmos `json:"mx_roomserver_publish"`
|
||||
}
|
||||
|
||||
// const upsertPublishedSQL = "" +
|
||||
// "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
|
||||
|
||||
// "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
|
||||
const selectAllPublishedSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and c.mx_roomserver_publish.published = @x2" +
|
||||
" order by c.mx_roomserver_publish.room_id asc"
|
||||
|
||||
// const selectPublishedSQL = "" +
|
||||
// "SELECT published FROM roomserver_published WHERE room_id = $1"
|
||||
|
||||
type publishedStatements struct {
|
||||
db *Database
|
||||
// upsertPublishedStmt *sql.Stmt
|
||||
selectAllPublishedStmt string
|
||||
// selectPublishedStmt *sql.Stmt
|
||||
tableName string
|
||||
}
|
||||
|
||||
func queryPublish(s *publishedStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PublishCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []PublishCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, shared.StatementList{
|
||||
{&s.upsertPublishedStmt, upsertPublishedSQL},
|
||||
{&s.selectAllPublishedStmt, selectAllPublishedSQL},
|
||||
{&s.selectPublishedStmt, selectPublishedSQL},
|
||||
}.Prepare(db)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func getPublish(s *publishedStatements, ctx context.Context, pk string, docId string) (*PublishCosmosData, error) {
|
||||
response := PublishCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
&response)
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func setPublish(s *publishedStatements, ctx context.Context, pk string, publish PublishCosmosData) (*PublishCosmosData, error) {
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, publish.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
publish.Id,
|
||||
&publish,
|
||||
optionsReplace)
|
||||
return &publish, ex
|
||||
}
|
||||
|
||||
func NewCosmosDBPublishedTable(db *Database) (tables.Published, error) {
|
||||
s := &publishedStatements{
|
||||
db: db,
|
||||
}
|
||||
// _, err := db.Exec(publishedSchema)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// return s, shared.StatementList{
|
||||
// {&s.upsertPublishedStmt, upsertPublishedSQL},
|
||||
s.selectAllPublishedStmt = selectAllPublishedSQL
|
||||
// {&s.selectPublishedStmt, selectPublishedSQL},
|
||||
// }.Prepare(db)
|
||||
s.tableName = "published"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *publishedStatements) UpsertRoomPublished(
|
||||
ctx context.Context, txn *sql.Tx, roomID string, published bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomID, published)
|
||||
|
||||
// "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// room_id TEXT NOT NULL PRIMARY KEY,
|
||||
docId := roomID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
data := PublishCosmos{
|
||||
RoomID: roomID,
|
||||
Published: false,
|
||||
}
|
||||
|
||||
var dbData = PublishCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Publish: data,
|
||||
}
|
||||
|
||||
// "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
|
||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
) (published bool, err error) {
|
||||
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
|
||||
// "SELECT published FROM roomserver_published WHERE room_id = $1"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// room_id TEXT NOT NULL PRIMARY KEY,
|
||||
docId := roomID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
response, err := getPublish(s, ctx, pk, cosmosDocId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return
|
||||
return response.Publish.Published, nil
|
||||
}
|
||||
|
||||
func (s *publishedStatements) SelectAllPublishedRooms(
|
||||
ctx context.Context, published bool,
|
||||
) ([]string, error) {
|
||||
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
|
||||
|
||||
// "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": published,
|
||||
}
|
||||
|
||||
response, err := queryPublish(s, ctx, s.selectAllPublishedStmt, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAllPublishedStmt: rows.close() failed")
|
||||
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
for _, item := range response {
|
||||
roomIDs = append(roomIDs, item.Publish.RoomID)
|
||||
}
|
||||
return roomIDs, rows.Err()
|
||||
return roomIDs, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,84 +17,207 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
|
||||
const redactionsSchema = `
|
||||
-- Stores information about the redacted state of events.
|
||||
-- We need to track redactions rather than blindly updating the event JSON table on receipt of a redaction
|
||||
-- because we might receive the redaction BEFORE we receive the event which it redacts (think backfill).
|
||||
CREATE TABLE IF NOT EXISTS roomserver_redactions (
|
||||
redaction_event_id TEXT PRIMARY KEY,
|
||||
redacts_event_id TEXT NOT NULL,
|
||||
-- Initially FALSE, set to TRUE when the redaction has been validated according to rooms v3+ spec
|
||||
-- https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
|
||||
validated BOOLEAN NOT NULL
|
||||
);
|
||||
`
|
||||
// const redactionsSchema = `
|
||||
// -- Stores information about the redacted state of events.
|
||||
// -- We need to track redactions rather than blindly updating the event JSON table on receipt of a redaction
|
||||
// -- because we might receive the redaction BEFORE we receive the event which it redacts (think backfill).
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_redactions (
|
||||
// redaction_event_id TEXT PRIMARY KEY,
|
||||
// redacts_event_id TEXT NOT NULL,
|
||||
// -- Initially FALSE, set to TRUE when the redaction has been validated according to rooms v3+ spec
|
||||
// -- https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
|
||||
// validated BOOLEAN NOT NULL
|
||||
// );
|
||||
// `
|
||||
|
||||
const insertRedactionSQL = "" +
|
||||
"INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
|
||||
" VALUES ($1, $2, $3)"
|
||||
|
||||
const selectRedactionInfoByRedactionEventIDSQL = "" +
|
||||
"SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
|
||||
" WHERE redaction_event_id = $1"
|
||||
|
||||
const selectRedactionInfoByEventBeingRedactedSQL = "" +
|
||||
"SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
|
||||
" WHERE redacts_event_id = $1"
|
||||
|
||||
const markRedactionValidatedSQL = "" +
|
||||
" UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
|
||||
|
||||
type redactionStatements struct {
|
||||
db *sql.DB
|
||||
insertRedactionStmt *sql.Stmt
|
||||
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
|
||||
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
|
||||
markRedactionValidatedStmt *sql.Stmt
|
||||
type RedactionCosmos struct {
|
||||
RedactionEventID string `json:"redaction_event_id"`
|
||||
RedactsEventID string `json:"redacts_event_id"`
|
||||
Validated bool `json:"validated"`
|
||||
}
|
||||
|
||||
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
|
||||
s := &redactionStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(redactionsSchema)
|
||||
type RedactionCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Redaction RedactionCosmos `json:"mx_roomserver_redaction"`
|
||||
}
|
||||
|
||||
// const insertRedactionSQL = "" +
|
||||
// "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
|
||||
// " VALUES ($1, $2, $3)"
|
||||
|
||||
// const selectRedactionInfoByRedactionEventIDSQL = "" +
|
||||
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
|
||||
// " WHERE redaction_event_id = $1"
|
||||
|
||||
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
|
||||
// " WHERE redacts_event_id = $1"
|
||||
const selectRedactionInfoByEventBeingRedactedSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and c.mx_roomserver_redaction.redacts_event_id = @x2"
|
||||
|
||||
// const markRedactionValidatedSQL = "" +
|
||||
// " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
|
||||
|
||||
type redactionStatements struct {
|
||||
db *Database
|
||||
// insertRedactionStmt *sql.Stmt
|
||||
// selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
|
||||
selectRedactionInfoByEventBeingRedactedStmt string
|
||||
// markRedactionValidatedStmt *sql.Stmt
|
||||
tableName string
|
||||
}
|
||||
|
||||
func queryRedaction(s *redactionStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RedactionCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []RedactionCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertRedactionStmt, insertRedactionSQL},
|
||||
{&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL},
|
||||
{&s.selectRedactionInfoByEventBeingRedactedStmt, selectRedactionInfoByEventBeingRedactedSQL},
|
||||
{&s.markRedactionValidatedStmt, markRedactionValidatedSQL},
|
||||
}.Prepare(db)
|
||||
func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId string) (*RedactionCosmosData, error) {
|
||||
response := RedactionCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
&response)
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func setRedaction(s *redactionStatements, ctx context.Context, pk string, redaction RedactionCosmosData) (*RedactionCosmosData, error) {
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, redaction.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
redaction.Id,
|
||||
&redaction,
|
||||
optionsReplace)
|
||||
return &redaction, ex
|
||||
}
|
||||
|
||||
func NewCosmosDBRedactionsTable(db *Database) (tables.Redactions, error) {
|
||||
s := &redactionStatements{
|
||||
db: db,
|
||||
}
|
||||
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertRedactionStmt, insertRedactionSQL},
|
||||
// {&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL},
|
||||
s.selectRedactionInfoByEventBeingRedactedStmt = selectRedactionInfoByEventBeingRedactedSQL
|
||||
// {&s.markRedactionValidatedStmt, markRedactionValidatedSQL},
|
||||
// }.Prepare(db)
|
||||
s.tableName = "redactions"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *redactionStatements) InsertRedaction(
|
||||
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
|
||||
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
|
||||
return err
|
||||
|
||||
// "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
|
||||
// " VALUES ($1, $2, $3)"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// redaction_event_id TEXT PRIMARY KEY,
|
||||
docId := info.RedactionEventID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
data := RedactionCosmos{
|
||||
RedactionEventID: info.RedactionEventID,
|
||||
RedactsEventID: info.RedactsEventID,
|
||||
Validated: info.Validated,
|
||||
}
|
||||
|
||||
var dbData = RedactionCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Redaction: data,
|
||||
}
|
||||
|
||||
// "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
// TODO: Just forDebug - Remove exception
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//Ignore Error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
|
||||
ctx context.Context, txn *sql.Tx, redactionEventID string,
|
||||
) (info *tables.RedactionInfo, err error) {
|
||||
info = &tables.RedactionInfo{}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByRedactionEventIDStmt)
|
||||
err = stmt.QueryRowContext(ctx, redactionEventID).Scan(
|
||||
&info.RedactionEventID, &info.RedactsEventID, &info.Validated,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
|
||||
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
|
||||
// " WHERE redaction_event_id = $1"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// redaction_event_id TEXT PRIMARY KEY,
|
||||
docId := redactionEventID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
response, err := getRedaction(s, ctx, pk, cosmosDocId)
|
||||
if err != nil {
|
||||
info = nil
|
||||
err = err
|
||||
return
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
info = nil
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
info = &tables.RedactionInfo{
|
||||
RedactionEventID: response.Redaction.RedactionEventID,
|
||||
RedactsEventID: response.Redaction.RedactsEventID,
|
||||
Validated: response.Redaction.Validated,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -102,14 +225,31 @@ func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
|
|||
func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
|
||||
ctx context.Context, txn *sql.Tx, eventID string,
|
||||
) (info *tables.RedactionInfo, err error) {
|
||||
info = &tables.RedactionInfo{}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByEventBeingRedactedStmt)
|
||||
err = stmt.QueryRowContext(ctx, eventID).Scan(
|
||||
&info.RedactionEventID, &info.RedactsEventID, &info.Validated,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
|
||||
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
|
||||
// " WHERE redacts_event_id = $1"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": eventID,
|
||||
}
|
||||
response, err := queryRedaction(s, ctx, s.selectRedactionInfoByEventBeingRedactedStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
info = nil
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
// TODO: Check this is ok to return the 1st one
|
||||
*info = tables.RedactionInfo{
|
||||
RedactionEventID: response[0].Redaction.RedactionEventID,
|
||||
RedactsEventID: response[0].Redaction.RedactsEventID,
|
||||
Validated: response[0].Redaction.Validated,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -117,7 +257,22 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
|
|||
func (s *redactionStatements) MarkRedactionValidated(
|
||||
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
|
||||
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
|
||||
|
||||
// " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// redaction_event_id TEXT PRIMARY KEY,
|
||||
docId := redactionEventID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
response, err := getRedaction(s, ctx, pk, cosmosDocId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response.Redaction.Validated = validated
|
||||
|
||||
_, err = setRedaction(s, ctx, pk, *response)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,84 +18,185 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
|
||||
const roomAliasesSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_room_aliases (
|
||||
alias TEXT NOT NULL PRIMARY KEY,
|
||||
room_id TEXT NOT NULL,
|
||||
creator_id TEXT NOT NULL
|
||||
);
|
||||
// const roomAliasesSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_room_aliases (
|
||||
// alias TEXT NOT NULL PRIMARY KEY,
|
||||
// room_id TEXT NOT NULL,
|
||||
// creator_id TEXT NOT NULL
|
||||
// );
|
||||
|
||||
CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
|
||||
`
|
||||
// CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
|
||||
// `
|
||||
|
||||
const insertRoomAliasSQL = `
|
||||
INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
|
||||
`
|
||||
|
||||
const selectRoomIDFromAliasSQL = `
|
||||
SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
|
||||
`
|
||||
|
||||
const selectAliasesFromRoomIDSQL = `
|
||||
SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
|
||||
`
|
||||
|
||||
const selectCreatorIDFromAliasSQL = `
|
||||
SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
|
||||
`
|
||||
|
||||
const deleteRoomAliasSQL = `
|
||||
DELETE FROM roomserver_room_aliases WHERE alias = $1
|
||||
`
|
||||
|
||||
type roomAliasesStatements struct {
|
||||
db *sql.DB
|
||||
insertRoomAliasStmt *sql.Stmt
|
||||
selectRoomIDFromAliasStmt *sql.Stmt
|
||||
selectAliasesFromRoomIDStmt *sql.Stmt
|
||||
selectCreatorIDFromAliasStmt *sql.Stmt
|
||||
deleteRoomAliasStmt *sql.Stmt
|
||||
type RoomAliasCosmos struct {
|
||||
Alias string `json:"alias"`
|
||||
RoomID string `json:"room_id"`
|
||||
CreatorID string `json:"creator_id"`
|
||||
}
|
||||
|
||||
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||
s := &roomAliasesStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(roomAliasesSchema)
|
||||
type RoomAliasCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
RoomAlias RoomAliasCosmos `json:"mx_roomserver_room_alias"`
|
||||
}
|
||||
|
||||
// const insertRoomAliasSQL = `
|
||||
// INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
|
||||
// `
|
||||
|
||||
// const selectRoomIDFromAliasSQL = `
|
||||
// SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
|
||||
// `
|
||||
|
||||
// SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
|
||||
const selectAliasesFromRoomIDSQL = `
|
||||
select * from c where c._cn = @x1 and c.mx_roomserver_room_alias.room_id = @x2
|
||||
`
|
||||
|
||||
// const selectCreatorIDFromAliasSQL = `
|
||||
// SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
|
||||
// `
|
||||
|
||||
// const deleteRoomAliasSQL = `
|
||||
// DELETE FROM roomserver_room_aliases WHERE alias = $1
|
||||
// `
|
||||
|
||||
type roomAliasesStatements struct {
|
||||
db *Database
|
||||
// insertRoomAliasStmt *sql.Stmt
|
||||
// selectRoomIDFromAliasStmt string
|
||||
selectAliasesFromRoomIDStmt string
|
||||
// selectCreatorIDFromAliasStmt string
|
||||
// deleteRoomAliasStmt *sql.Stmt
|
||||
tableName string
|
||||
}
|
||||
|
||||
func queryRoomAlias(s *roomAliasesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomAliasCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []RoomAliasCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, shared.StatementList{
|
||||
{&s.insertRoomAliasStmt, insertRoomAliasSQL},
|
||||
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
|
||||
{&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL},
|
||||
{&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL},
|
||||
{&s.deleteRoomAliasStmt, deleteRoomAliasSQL},
|
||||
}.Prepare(db)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func getRoomAlias(s *roomAliasesStatements, ctx context.Context, pk string, docId string) (*RoomAliasCosmosData, error) {
|
||||
response := RoomAliasCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
&response)
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func NewCosmosDBRoomAliasesTable(db *Database) (tables.RoomAliases, error) {
|
||||
s := &roomAliasesStatements{
|
||||
db: db,
|
||||
}
|
||||
// _, err := db.Exec(roomAliasesSchema)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertRoomAliasStmt, insertRoomAliasSQL},
|
||||
// {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
|
||||
s.selectAliasesFromRoomIDStmt = selectAliasesFromRoomIDSQL
|
||||
// {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL},
|
||||
// {&s.deleteRoomAliasStmt, deleteRoomAliasSQL},
|
||||
// }.Prepare(db)
|
||||
s.tableName = "room_aliases"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) InsertRoomAlias(
|
||||
ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt)
|
||||
_, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID)
|
||||
|
||||
// INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
|
||||
data := RoomAliasCosmos{
|
||||
Alias: alias,
|
||||
CreatorID: creatorUserID,
|
||||
RoomID: roomID,
|
||||
}
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
// alias TEXT NOT NULL PRIMARY KEY,
|
||||
docId := alias
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var dbData = RoomAliasCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
RoomAlias: data,
|
||||
}
|
||||
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
) (roomID string, err error) {
|
||||
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||
if err == sql.ErrNoRows {
|
||||
|
||||
// SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
// alias TEXT NOT NULL PRIMARY KEY,
|
||||
docId := alias
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response, err := getRoomAlias(s, ctx, pk, cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
return "", nil
|
||||
}
|
||||
roomID = response.RoomAlias.RoomID
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -103,20 +204,23 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
|||
ctx context.Context, roomID string,
|
||||
) (aliases []string, err error) {
|
||||
aliases = []string{}
|
||||
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
|
||||
// SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomID,
|
||||
}
|
||||
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
|
||||
response, err := queryRoomAlias(s, ctx, s.selectAliasesFromRoomIDStmt, params)
|
||||
|
||||
for rows.Next() {
|
||||
var alias string
|
||||
if err = rows.Scan(&alias); err != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aliases = append(aliases, alias)
|
||||
for _, item := range response {
|
||||
aliases = append(aliases, item.RoomAlias.Alias)
|
||||
}
|
||||
|
||||
return
|
||||
|
|
@ -125,17 +229,48 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
|||
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
) (creatorID string, err error) {
|
||||
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
||||
if err == sql.ErrNoRows {
|
||||
|
||||
// SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
// alias TEXT NOT NULL PRIMARY KEY,
|
||||
docId := alias
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response, err := getRoomAlias(s, ctx, pk, cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
return "", nil
|
||||
}
|
||||
creatorID = response.RoomAlias.CreatorID
|
||||
return
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) DeleteRoomAlias(
|
||||
ctx context.Context, txn *sql.Tx, alias string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt)
|
||||
_, err := stmt.ExecContext(ctx, alias)
|
||||
|
||||
// DELETE FROM roomserver_room_aliases WHERE alias = $1
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
docId := alias
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
cosmosDocId,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
12
roomserver/storage/cosmosdb/room_seq.go
Normal file
12
roomserver/storage/cosmosdb/room_seq.go
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
)
|
||||
|
||||
func GetNextRoomNID(s *roomStatements, ctx context.Context) (int64, error) {
|
||||
const docId = "roomnid_seq"
|
||||
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
|
||||
}
|
||||
|
|
@ -18,128 +18,227 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const roomsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_rooms (
|
||||
room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
room_id TEXT NOT NULL UNIQUE,
|
||||
latest_event_nids TEXT NOT NULL DEFAULT '[]',
|
||||
last_event_sent_nid INTEGER NOT NULL DEFAULT 0,
|
||||
state_snapshot_nid INTEGER NOT NULL DEFAULT 0,
|
||||
room_version TEXT NOT NULL
|
||||
);
|
||||
`
|
||||
// const roomsSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_rooms (
|
||||
// room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
// room_id TEXT NOT NULL UNIQUE,
|
||||
// latest_event_nids TEXT NOT NULL DEFAULT '[]',
|
||||
// last_event_sent_nid INTEGER NOT NULL DEFAULT 0,
|
||||
// state_snapshot_nid INTEGER NOT NULL DEFAULT 0,
|
||||
// room_version TEXT NOT NULL
|
||||
// );
|
||||
// `
|
||||
|
||||
// Same as insertEventTypeNIDSQL
|
||||
const insertRoomNIDSQL = `
|
||||
INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
|
||||
const selectRoomNIDSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
|
||||
|
||||
const selectLatestEventNIDsSQL = "" +
|
||||
"SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
|
||||
|
||||
const selectLatestEventNIDsForUpdateSQL = "" +
|
||||
"SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
|
||||
|
||||
const updateLatestEventNIDsSQL = "" +
|
||||
"UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
|
||||
|
||||
const selectRoomVersionsForRoomNIDsSQL = "" +
|
||||
"SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||
|
||||
const selectRoomInfoSQL = "" +
|
||||
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||
|
||||
const selectRoomIDsSQL = "" +
|
||||
"SELECT room_id FROM roomserver_rooms"
|
||||
|
||||
const bulkSelectRoomIDsSQL = "" +
|
||||
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||
|
||||
const bulkSelectRoomNIDsSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
|
||||
|
||||
type roomStatements struct {
|
||||
db *sql.DB
|
||||
insertRoomNIDStmt *sql.Stmt
|
||||
selectRoomNIDStmt *sql.Stmt
|
||||
selectLatestEventNIDsStmt *sql.Stmt
|
||||
selectLatestEventNIDsForUpdateStmt *sql.Stmt
|
||||
updateLatestEventNIDsStmt *sql.Stmt
|
||||
//selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||
selectRoomInfoStmt *sql.Stmt
|
||||
selectRoomIDsStmt *sql.Stmt
|
||||
type RoomCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Room RoomCosmos `json:"mx_roomserver_room"`
|
||||
}
|
||||
|
||||
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
type RoomCosmos struct {
|
||||
RoomNID int64 `json:"room_nid"`
|
||||
RoomID string `json:"room_id"`
|
||||
LatestEventNIDs []int64 `json:"latest_event_nids"`
|
||||
LastEventSentNID int64 `json:"last_event_sent_nid"`
|
||||
StateSnapshotNID int64 `json:"state_snapshot_nid"`
|
||||
RoomVersion string `json:"room_version"`
|
||||
}
|
||||
|
||||
// Same as insertEventTypeNIDSQL
|
||||
// const insertRoomNIDSQL = `
|
||||
// INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
|
||||
// ON CONFLICT DO NOTHING;
|
||||
// `
|
||||
|
||||
// "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
|
||||
// const selectRoomNIDSQL = "" +
|
||||
// "select * from c where c._cn = @x1 and c.mx_roomserver_room.room_nid = @x1"
|
||||
|
||||
// "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
|
||||
const selectLatestEventNIDsSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and c.mx_roomserver_room.room_nid = @x2"
|
||||
|
||||
// "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
|
||||
const selectLatestEventNIDsForUpdateSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and c.mx_roomserver_room.room_nid = @x2"
|
||||
|
||||
// const updateLatestEventNIDsSQL = "" +
|
||||
// "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
|
||||
|
||||
// "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||
const selectRoomVersionsForRoomNIDsSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)"
|
||||
|
||||
// "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||
// const selectRoomInfoSQL = "" +
|
||||
// "select * from c where c._cn = @x1 and c.mx_roomserver_room.room_id = @x2"
|
||||
|
||||
// "SELECT room_id FROM roomserver_rooms"
|
||||
const selectRoomIDsSQL = "" +
|
||||
"select * from c where c._cn = @x1"
|
||||
|
||||
// "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||
const bulkSelectRoomIDsSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
" and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)"
|
||||
|
||||
// "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
|
||||
const bulkSelectRoomNIDsSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_room.room_nid)"
|
||||
|
||||
type roomStatements struct {
|
||||
db *Database
|
||||
// insertRoomNIDStmt *sql.Stmt
|
||||
// selectRoomNIDStmt string
|
||||
selectLatestEventNIDsStmt string
|
||||
selectLatestEventNIDsForUpdateStmt string
|
||||
updateLatestEventNIDsStmt string
|
||||
selectRoomVersionForRoomNIDStmt string
|
||||
// selectRoomInfoStmt *sql.Stmt
|
||||
selectRoomIDsStmt string
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewCosmosDBRoomsTable(db *Database) (tables.Rooms, error) {
|
||||
s := &roomStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(roomsSchema)
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertRoomNIDStmt, insertRoomNIDSQL},
|
||||
// {&s.selectRoomNIDStmt, selectRoomNIDSQL},
|
||||
s.selectLatestEventNIDsStmt = selectLatestEventNIDsSQL
|
||||
s.selectLatestEventNIDsForUpdateStmt = selectLatestEventNIDsForUpdateSQL
|
||||
// {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
|
||||
// {&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||
s.selectRoomIDsStmt = selectRoomIDsSQL
|
||||
// }.Prepare(db)
|
||||
s.tableName = "rooms"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func mapToRoomEventNIDArray(eventNIDs []int64) []types.EventNID {
|
||||
result := []types.EventNID{}
|
||||
for i := 0; i < len(eventNIDs); i++ {
|
||||
result = append(result, types.EventNID(eventNIDs[i]))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func queryRoom(s *roomStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []RoomCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, shared.StatementList{
|
||||
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
|
||||
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
|
||||
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
|
||||
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
|
||||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
|
||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
||||
}.Prepare(db)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (*RoomCosmosData, error) {
|
||||
response := RoomCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
&response)
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func setRoom(s *roomStatements, ctx context.Context, pk string, room RoomCosmosData) (*RoomCosmosData, error) {
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, room.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
room.Id,
|
||||
&room,
|
||||
optionsReplace)
|
||||
return &room, ex
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||
|
||||
// "SELECT room_id FROM roomserver_rooms"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
}
|
||||
|
||||
response, err := queryRoom(s, ctx, s.selectRoomIDsStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
for _, item := range response {
|
||||
roomIDs = append(roomIDs, item.Room.RoomID)
|
||||
}
|
||||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||
var info types.RoomInfo
|
||||
var latestNIDsJSON string
|
||||
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
|
||||
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
|
||||
)
|
||||
info := types.RoomInfo{}
|
||||
|
||||
// "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// room_id TEXT NOT NULL UNIQUE,
|
||||
docId := roomID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
room, err := getRoom(s, ctx, pk, cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if err == cosmosdbutil.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var latestNIDs []int64
|
||||
if err = json.Unmarshal([]byte(latestNIDsJSON), &latestNIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.IsStub = len(latestNIDs) == 0
|
||||
|
||||
info.RoomVersion = gomatrixserverlib.RoomVersion(room.Room.RoomVersion)
|
||||
info.RoomNID = types.RoomNID(room.Room.RoomNID)
|
||||
info.StateSnapshotNID = types.StateSnapshotNID(room.Room.StateSnapshotNID)
|
||||
info.IsStub = len(room.Room.LatestEventNIDs) == 0
|
||||
return &info, err
|
||||
}
|
||||
|
||||
|
|
@ -147,60 +246,135 @@ func (s *roomStatements) InsertRoomNID(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (roomNID types.RoomNID, err error) {
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
|
||||
_, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
// INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
|
||||
// ON CONFLICT DO NOTHING;
|
||||
// room_id TEXT NOT NULL UNIQUE,
|
||||
docId := roomID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
dbData, errGet := getRoom(s, ctx, pk, cosmosDocId)
|
||||
|
||||
if errGet == cosmosdbutil.ErrNoRows {
|
||||
// room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
roomNIDSeq, seqErr := GetNextRoomNID(s, ctx)
|
||||
if seqErr != nil {
|
||||
return 0, seqErr
|
||||
}
|
||||
|
||||
data := RoomCosmos{
|
||||
RoomNID: int64(roomNIDSeq),
|
||||
RoomID: roomID,
|
||||
RoomVersion: string(roomVersion),
|
||||
}
|
||||
|
||||
dbData = &RoomCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Room: data,
|
||||
}
|
||||
}
|
||||
roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
|
||||
|
||||
// ON CONFLICT DO NOTHING; - Do Upsert
|
||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("s.SelectRoomNID: %w", err)
|
||||
}
|
||||
|
||||
roomNID = types.RoomNID(dbData.Room.RoomNID)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomNID(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (types.RoomNID, error) {
|
||||
var roomNID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomNIDStmt)
|
||||
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
|
||||
return types.RoomNID(roomNID), err
|
||||
|
||||
// "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// room_id TEXT NOT NULL UNIQUE,
|
||||
docId := roomID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
room, err := getRoom(s, ctx, pk, cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if room == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return types.RoomNID(room.Room.RoomNID), err
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectLatestEventNIDs(
|
||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
|
||||
) ([]types.EventNID, types.StateSnapshotNID, error) {
|
||||
var eventNIDs []types.EventNID
|
||||
var nidsJSON string
|
||||
var stateSnapshotNID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
|
||||
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &stateSnapshotNID)
|
||||
|
||||
// "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomNID,
|
||||
}
|
||||
|
||||
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil {
|
||||
return nil, 0, err
|
||||
|
||||
// TODO: Check the error handling
|
||||
if len(response) == 0 {
|
||||
return nil, 0, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
|
||||
|
||||
//Assume 1 per RoomNID
|
||||
room := response[0]
|
||||
return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectLatestEventsNIDsForUpdate(
|
||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
|
||||
) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
|
||||
var eventNIDs []types.EventNID
|
||||
var nidsJSON string
|
||||
var lastEventSentNID int64
|
||||
var stateSnapshotNID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt)
|
||||
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &lastEventSentNID, &stateSnapshotNID)
|
||||
|
||||
// "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomNID,
|
||||
}
|
||||
|
||||
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil {
|
||||
return nil, 0, 0, err
|
||||
|
||||
// TODO: Check the error handling
|
||||
if len(response) == 0 {
|
||||
return nil, 0, 0, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil
|
||||
|
||||
//Assume 1 per RoomNID
|
||||
room := response[0]
|
||||
return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.EventNID(room.Room.LastEventSentNID), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) UpdateLatestEventNIDs(
|
||||
|
|
@ -211,86 +385,113 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
|||
lastEventSentNID types.EventNID,
|
||||
stateSnapshotNID types.StateSnapshotNID,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
eventNIDsAsArray(eventNIDs),
|
||||
int64(lastEventSentNID),
|
||||
int64(stateSnapshotNID),
|
||||
roomNID,
|
||||
)
|
||||
|
||||
// "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomNID,
|
||||
}
|
||||
|
||||
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Check the error handling
|
||||
if len(response) == 0 {
|
||||
return cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
//Assume 1 per RoomNID
|
||||
room := response[0]
|
||||
|
||||
room.Room.LatestEventNIDs = mapFromEventNIDArray(eventNIDs)
|
||||
room.Room.LastEventSentNID = int64(lastEventSentNID)
|
||||
room.Room.StateSnapshotNID = int64(stateSnapshotNID)
|
||||
|
||||
_, err = setRoom(s, ctx, room.Pk, room)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||
ctx context.Context, roomNIDs []types.RoomNID,
|
||||
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
|
||||
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||
if roomNIDs == nil || len(roomNIDs) == 0 {
|
||||
return make(map[types.RoomNID]gomatrixserverlib.RoomVersion), nil
|
||||
}
|
||||
|
||||
// "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomNIDs,
|
||||
}
|
||||
|
||||
response, err := queryRoom(s, ctx, selectRoomVersionsForRoomNIDsSQL, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
}
|
||||
rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
|
||||
|
||||
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
var roomVersion gomatrixserverlib.RoomVersion
|
||||
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[roomNID] = roomVersion
|
||||
for _, item := range response {
|
||||
result[types.RoomNID(item.Room.RoomNID)] = gomatrixserverlib.RoomVersion(item.Room.RoomVersion)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
if roomNIDs == nil || len(roomNIDs) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||
|
||||
// "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomNIDs,
|
||||
}
|
||||
|
||||
response, err := queryRoom(s, ctx, bulkSelectRoomIDsSQL, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
for _, item := range response {
|
||||
roomIDs = append(roomIDs, item.Room.RoomID)
|
||||
}
|
||||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
||||
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||
for i, v := range roomIDs {
|
||||
iRoomIDs[i] = v
|
||||
if roomIDs == nil || len(roomIDs) == 0 {
|
||||
return []types.RoomNID{}, nil
|
||||
}
|
||||
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||
|
||||
// "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": roomIDs,
|
||||
}
|
||||
|
||||
response, err := queryRoom(s, ctx, bulkSelectRoomNIDsSQL, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
||||
|
||||
var roomNIDs []types.RoomNID
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
if err = rows.Scan(&roomNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomNIDs = append(roomNIDs, roomNID)
|
||||
for _, item := range response {
|
||||
roomNIDs = append(roomNIDs, types.RoomNID(item.Room.RoomNID))
|
||||
}
|
||||
return roomNIDs, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,33 +20,54 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
const stateDataSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_state_block (
|
||||
state_block_nid INTEGER NOT NULL,
|
||||
event_type_nid INTEGER NOT NULL,
|
||||
event_state_key_nid INTEGER NOT NULL,
|
||||
event_nid INTEGER NOT NULL,
|
||||
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
|
||||
);
|
||||
`
|
||||
// const stateDataSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_state_block (
|
||||
// state_block_nid INTEGER NOT NULL,
|
||||
// event_type_nid INTEGER NOT NULL,
|
||||
// event_state_key_nid INTEGER NOT NULL,
|
||||
// event_nid INTEGER NOT NULL,
|
||||
// UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
|
||||
// );
|
||||
// `
|
||||
|
||||
const insertStateDataSQL = "" +
|
||||
"INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
|
||||
" VALUES ($1, $2, $3, $4)"
|
||||
type StateBlockCosmos struct {
|
||||
StateBlockNID int64 `json:"state_block_nid"`
|
||||
EventTypeNID int64 `json:"event_type_nid"`
|
||||
EventStateKeyNID int64 `json:"event_state_key_nid"`
|
||||
EventNID int64 `json:"event_nid"`
|
||||
}
|
||||
|
||||
const selectNextStateBlockNIDSQL = `
|
||||
SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
|
||||
`
|
||||
type StateBlockCosmosMaxNID struct {
|
||||
Max int64 `json:"maxstateblocknid"`
|
||||
}
|
||||
|
||||
type StateBlockCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
StateBlock StateBlockCosmos `json:"mx_roomserver_state_block"`
|
||||
}
|
||||
|
||||
// const insertStateDataSQL = "" +
|
||||
// "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
|
||||
// " VALUES ($1, $2, $3, $4)"
|
||||
|
||||
// SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
|
||||
const selectNextStateBlockNIDSQL = "" +
|
||||
"select sub.maxinner != null ? sub.maxinner + 1 : 1 as maxstateblocknid " +
|
||||
"from " +
|
||||
"(select MAX(c.mx_roomserver_state_block.state_block_nid) maxinner from c where c._cn = @x1) as sub"
|
||||
|
||||
// Bulk state lookup by numeric state block ID.
|
||||
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
|
||||
|
|
@ -54,10 +75,17 @@ SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
|
|||
// together in the list and those entries will sorted by event_type_nid
|
||||
// and event_state_key_nid. This property makes it easier to merge two
|
||||
// state data blocks together.
|
||||
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
||||
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||
const bulkSelectStateBlockEntriesSQL = "" +
|
||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " +
|
||||
"order by c.mx_roomserver_state_block.state_block_nid " +
|
||||
// Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from
|
||||
// ", c.mx_roomserver_state_block.event_type_nid " +
|
||||
// ", c.mx_roomserver_state_block.event_state_key_nid " +
|
||||
" asc"
|
||||
|
||||
// Bulk state lookup by numeric state block ID.
|
||||
// Filters the rows in each block to the requested types and state keys.
|
||||
|
|
@ -66,35 +94,126 @@ const bulkSelectStateBlockEntriesSQL = "" +
|
|||
// of types and a list state_keys. So we have to filter the result in the
|
||||
// application to restrict it to the list of event types and state keys we
|
||||
// actually wanted.
|
||||
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
||||
// " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
|
||||
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||
const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
||||
" AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
|
||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " +
|
||||
"and ARRAY_CONTAINS(@x3, c.mx_roomserver_state_block.event_type_nid) " +
|
||||
"and ARRAY_CONTAINS(@x4, c.mx_roomserver_state_block.event_state_key_nid) " +
|
||||
"order by c.mx_roomserver_state_block.state_block_nid " +
|
||||
// Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from
|
||||
// ", c.mx_roomserver_state_block.event_type_nid " +
|
||||
// ", c.mx_roomserver_state_block.event_state_key_nid " +
|
||||
"asc"
|
||||
|
||||
type stateBlockStatements struct {
|
||||
db *sql.DB
|
||||
insertStateDataStmt *sql.Stmt
|
||||
selectNextStateBlockNIDStmt *sql.Stmt
|
||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
|
||||
db *Database
|
||||
// insertStateDataStmt *sql.Stmt
|
||||
selectNextStateBlockNIDStmt string
|
||||
bulkSelectStateBlockEntriesStmt string
|
||||
bulkSelectFilteredStateBlockEntriesStmt string
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||
s := &stateBlockStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(stateDataSchema)
|
||||
func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StateBlockCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []StateBlockCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertStateDataStmt, insertStateDataSQL},
|
||||
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
|
||||
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
|
||||
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
|
||||
}.Prepare(db)
|
||||
func NewCosmosDBStateBlockTable(db *Database) (tables.StateBlock, error) {
|
||||
s := &stateBlockStatements{
|
||||
db: db,
|
||||
}
|
||||
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertStateDataStmt, insertStateDataSQL},
|
||||
s.selectNextStateBlockNIDStmt = selectNextStateBlockNIDSQL
|
||||
s.bulkSelectStateBlockEntriesStmt = bulkSelectStateBlockEntriesSQL
|
||||
s.bulkSelectFilteredStateBlockEntriesStmt = bulkSelectFilteredStateBlockEntriesSQL
|
||||
// }.Prepare(db)
|
||||
s.tableName = "state_block"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func inertStateBlockCore(s *stateBlockStatements, ctx context.Context, stateBlockNID types.StateBlockNID, entry types.StateEntry) error {
|
||||
|
||||
// "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
|
||||
// " VALUES ($1, $2, $3, $4)"
|
||||
data := StateBlockCosmos{
|
||||
EventNID: int64(entry.EventNID),
|
||||
EventStateKeyNID: int64(entry.EventStateKeyNID),
|
||||
EventTypeNID: int64(entry.EventTypeNID),
|
||||
StateBlockNID: int64(stateBlockNID),
|
||||
}
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
// UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
|
||||
docId := fmt.Sprintf("%d_%d_%d", data.StateBlockNID, data.EventTypeNID, data.EventStateKeyNID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var dbData = StateBlockCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
StateBlock: data,
|
||||
}
|
||||
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func getNextStateBlockNID(s *stateBlockStatements, ctx context.Context) (int64, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var stateBlockNext []StateBlockCosmosMaxNID
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
}
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||
var query = cosmosdbapi.GetQuery(s.selectNextStateBlockNIDStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&stateBlockNext,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return stateBlockNext[0].Max, nil
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkInsertStateData(
|
||||
|
|
@ -104,75 +223,64 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
|||
if len(entries) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var stateBlockNID types.StateBlockNID
|
||||
err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
nextID, errNextID := getNextStateBlockNID(s, ctx)
|
||||
if errNextID != nil {
|
||||
return 0, errNextID
|
||||
}
|
||||
|
||||
stateBlockNID := types.StateBlockNID(nextID)
|
||||
|
||||
for _, entry := range entries {
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
|
||||
ctx,
|
||||
int64(stateBlockNID),
|
||||
int64(entry.EventTypeNID),
|
||||
int64(entry.EventStateKeyNID),
|
||||
int64(entry.EventNID),
|
||||
)
|
||||
err := inertStateBlockCore(s, ctx, stateBlockNID, entry)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return stateBlockNID, err
|
||||
return stateBlockNID, nil
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
||||
) ([]types.StateEntryList, error) {
|
||||
nids := make([]interface{}, len(stateBlockNIDs))
|
||||
for k, v := range stateBlockNIDs {
|
||||
nids[k] = v
|
||||
|
||||
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
||||
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var response []StateBlockCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": stateBlockNIDs,
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
|
||||
response, err := queryStateBlock(s, ctx, s.bulkSelectStateBlockEntriesStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
|
||||
|
||||
results := make([]types.StateEntryList, len(stateBlockNIDs))
|
||||
// current is a pointer to the StateEntryList to append the state entries to.
|
||||
var current *types.StateEntryList
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
var (
|
||||
stateBlockNID int64
|
||||
eventTypeNID int64
|
||||
eventStateKeyNID int64
|
||||
eventNID int64
|
||||
entry types.StateEntry
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
||||
entry.EventNID = types.EventNID(eventNID)
|
||||
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
||||
for _, item := range response {
|
||||
entry := types.StateEntry{}
|
||||
entry.EventTypeNID = types.EventTypeNID(item.StateBlock.EventTypeNID)
|
||||
entry.EventStateKeyNID = types.EventStateKeyNID(item.StateBlock.EventStateKeyNID)
|
||||
entry.EventNID = types.EventNID(item.StateBlock.EventNID)
|
||||
|
||||
if current == nil || types.StateBlockNID(item.StateBlock.StateBlockNID) != current.StateBlockNID {
|
||||
// The state entry row is for a different state data block to the current one.
|
||||
// So we start appending to the next entry in the list.
|
||||
current = &results[i]
|
||||
current.StateBlockNID = types.StateBlockNID(stateBlockNID)
|
||||
current.StateBlockNID = types.StateBlockNID(item.StateBlock.StateBlockNID)
|
||||
i++
|
||||
}
|
||||
current.StateEntries = append(current.StateEntries, entry)
|
||||
}
|
||||
if i != len(nids) {
|
||||
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids))
|
||||
if i != len(stateBlockNIDs) {
|
||||
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
|
@ -187,34 +295,33 @@ func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
|
|||
sort.Sort(tuples)
|
||||
|
||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
|
||||
sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
|
||||
sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
|
||||
// sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
|
||||
// sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
|
||||
// sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
|
||||
|
||||
var params []interface{}
|
||||
for _, val := range stateBlockNIDs {
|
||||
params = append(params, int64(val))
|
||||
}
|
||||
for _, val := range eventTypeNIDArray {
|
||||
params = append(params, val)
|
||||
}
|
||||
for _, val := range eventStateKeyNIDArray {
|
||||
params = append(params, val)
|
||||
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
||||
// " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
|
||||
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var response []StateBlockCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": stateBlockNIDs,
|
||||
"@x3": eventTypeNIDArray,
|
||||
"@x4": eventStateKeyNIDArray,
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(
|
||||
ctx,
|
||||
sqlStatement,
|
||||
params...,
|
||||
)
|
||||
response, err := queryStateBlock(s, ctx, s.bulkSelectFilteredStateBlockEntriesStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
|
||||
|
||||
var results []types.StateEntryList
|
||||
var current types.StateEntryList
|
||||
for rows.Next() {
|
||||
for _, item := range response {
|
||||
var (
|
||||
stateBlockNID int64
|
||||
eventTypeNID int64
|
||||
|
|
@ -222,11 +329,10 @@ func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
|
|||
eventNID int64
|
||||
entry types.StateEntry
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateBlockNID = item.StateBlock.StateBlockNID
|
||||
eventTypeNID = item.StateBlock.EventTypeNID
|
||||
eventStateKeyNID = item.StateBlock.EventStateKeyNID
|
||||
eventNID = item.StateBlock.EventNID
|
||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
||||
entry.EventNID = types.EventNID(eventNID)
|
||||
|
|
|
|||
12
roomserver/storage/cosmosdb/state_snapshot_seq.go
Normal file
12
roomserver/storage/cosmosdb/state_snapshot_seq.go
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
)
|
||||
|
||||
func GetNextStateSnapshotNID(s *stateSnapshotStatements, ctx context.Context) (int64, error) {
|
||||
const docId = "statesnapshotnid_seq"
|
||||
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
|
||||
}
|
||||
|
|
@ -18,106 +18,169 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const stateSnapshotSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
|
||||
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
room_nid INTEGER NOT NULL,
|
||||
state_block_nids TEXT NOT NULL DEFAULT '[]'
|
||||
);
|
||||
`
|
||||
// const stateSnapshotSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
|
||||
// state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
// room_nid INTEGER NOT NULL,
|
||||
// state_block_nids TEXT NOT NULL DEFAULT '[]'
|
||||
// );
|
||||
// `
|
||||
|
||||
const insertStateSQL = `
|
||||
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
|
||||
VALUES ($1, $2);`
|
||||
type StateSnapshotCosmos struct {
|
||||
StateSnapshotNID int64 `json:"state_snapshot_nid"`
|
||||
RoomNID int64 `json:"room_nid"`
|
||||
StateBlockNIDs []int64 `json:"state_block_nids"`
|
||||
}
|
||||
|
||||
type StateSnapshotCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
StateSnapshot StateSnapshotCosmos `json:"mx_roomserver_state_snapshot"`
|
||||
}
|
||||
|
||||
// const insertStateSQL = `
|
||||
// INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
|
||||
// VALUES ($1, $2);`
|
||||
|
||||
// Bulk state data NID lookup.
|
||||
// Sorting by state_snapshot_nid means we can use binary search over the result
|
||||
// to lookup the state data NIDs for a state snapshot NID.
|
||||
// "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
|
||||
// " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
|
||||
const bulkSelectStateBlockNIDsSQL = "" +
|
||||
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
|
||||
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_snapshot.state_snapshot_nid) " +
|
||||
"order by c.mx_roomserver_state_snapshot.state_snapshot_nid asc"
|
||||
|
||||
type stateSnapshotStatements struct {
|
||||
db *sql.DB
|
||||
insertStateStmt *sql.Stmt
|
||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||
db *Database
|
||||
// insertStateStmt *sql.Stmt
|
||||
bulkSelectStateBlockNIDsStmt string
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
func mapFromStateBlockNIDArray(stateBlockNIDs []types.StateBlockNID) []int64 {
|
||||
result := []int64{}
|
||||
for i := 0; i < len(stateBlockNIDs); i++ {
|
||||
result = append(result, int64(stateBlockNIDs[i]))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func mapToStateBlockNIDArray(stateBlockNIDs []int64) []types.StateBlockNID {
|
||||
result := []types.StateBlockNID{}
|
||||
for i := 0; i < len(stateBlockNIDs); i++ {
|
||||
result = append(result, types.StateBlockNID(stateBlockNIDs[i]))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func NewCosmosDBStateSnapshotTable(db *Database) (tables.StateSnapshot, error) {
|
||||
s := &stateSnapshotStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(stateSnapshotSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertStateStmt, insertStateSQL},
|
||||
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
|
||||
}.Prepare(db)
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertStateStmt, insertStateSQL},
|
||||
s.bulkSelectStateBlockNIDsStmt = bulkSelectStateBlockNIDsSQL
|
||||
// }.Prepare(db)
|
||||
s.tableName = "state_snapshots"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *stateSnapshotStatements) InsertState(
|
||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
|
||||
) (stateNID types.StateSnapshotNID, err error) {
|
||||
stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs)
|
||||
if err != nil {
|
||||
return
|
||||
|
||||
// INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
|
||||
// VALUES ($1, $2);`
|
||||
stateSnapshotNIDSeq, seqErr := GetNextStateSnapshotNID(s, ctx)
|
||||
if seqErr != nil {
|
||||
return 0, seqErr
|
||||
}
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
||||
|
||||
data := StateSnapshotCosmos{
|
||||
RoomNID: int64(roomNID),
|
||||
StateBlockNIDs: mapFromStateBlockNIDArray(stateBlockNIDs),
|
||||
StateSnapshotNID: int64(stateSnapshotNIDSeq),
|
||||
}
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
// state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
docId := fmt.Sprintf("%d", stateSnapshotNIDSeq)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var dbData = StateSnapshotCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
StateSnapshot: data,
|
||||
}
|
||||
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
lastRowID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stateNID = types.StateSnapshotNID(lastRowID)
|
||||
|
||||
stateNID = types.StateSnapshotNID(stateSnapshotNIDSeq)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||
) ([]types.StateBlockNIDList, error) {
|
||||
nids := make([]interface{}, len(stateNIDs))
|
||||
for k, v := range stateNIDs {
|
||||
nids[k] = v
|
||||
|
||||
// "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
|
||||
// " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []StateSnapshotCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": stateNIDs,
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.bulkSelectStateBlockNIDsStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
|
||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
for _, item := range response {
|
||||
result := &results[i]
|
||||
var stateBlockNIDsJSON string
|
||||
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.StateSnapshotNID = types.StateSnapshotNID(item.StateSnapshot.StateSnapshotNID)
|
||||
result.StateBlockNIDs = mapToStateBlockNIDArray(item.StateSnapshot.StateBlockNIDs)
|
||||
i++
|
||||
}
|
||||
if i != len(stateNIDs) {
|
||||
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
|
||||
|
|
|
|||
|
|
@ -17,14 +17,15 @@ package cosmosdb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -33,16 +34,26 @@ import (
|
|||
// A Database is used to store room events and stream offsets.
|
||||
type Database struct {
|
||||
shared.Database
|
||||
connection cosmosdbapi.CosmosConnection
|
||||
databaseName string
|
||||
cosmosConfig cosmosdbapi.CosmosConfig
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
// Open a sqlite database.
|
||||
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
|
||||
var d Database
|
||||
var db *sql.DB
|
||||
var err error
|
||||
if db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
|
||||
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
|
||||
d := &Database{
|
||||
databaseName: "roomserver",
|
||||
connection: conn,
|
||||
cosmosConfig: config,
|
||||
}
|
||||
// var db *sql.DB
|
||||
// var err error
|
||||
// if db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
//db.Exec("PRAGMA journal_mode=WAL;")
|
||||
//db.Exec("PRAGMA read_uncommitted = true;")
|
||||
|
|
@ -51,89 +62,91 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
|
|||
// cause the roomserver to be unresponsive to new events because something will
|
||||
// acquire the global mutex and never unlock it because it is waiting for a connection
|
||||
// which it will never obtain.
|
||||
db.SetMaxOpenConns(20)
|
||||
// db.SetMaxOpenConns(20)
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
ms := membershipStatements{}
|
||||
if err := ms.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadAddForgottenColumn(m)
|
||||
if err := m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.prepare(db, cache); err != nil {
|
||||
// ms := membershipStatements{}
|
||||
// if err := ms.execSchema(db); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// m := sqlutil.NewMigrations()
|
||||
// deltas.LoadAddForgottenColumn(m)
|
||||
// if err := m.RunDeltas(db, dbProperties); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
if err := d.prepare(cache); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &d, nil
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// nolint: gocyclo
|
||||
func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
|
||||
func (d *Database) prepare(cache caching.RoomServerCaches) error {
|
||||
var err error
|
||||
eventStateKeys, err := NewSqliteEventStateKeysTable(db)
|
||||
d.databaseName = "roomserver"
|
||||
eventStateKeys, err := NewCosmosDBEventStateKeysTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
eventTypes, err := NewSqliteEventTypesTable(db)
|
||||
eventTypes, err := NewCosmosDBEventTypesTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
eventJSON, err := NewSqliteEventJSONTable(db)
|
||||
eventJSON, err := NewCosmosDBEventJSONTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
events, err := NewSqliteEventsTable(db)
|
||||
events, err := NewCosmosDBEventsTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rooms, err := NewSqliteRoomsTable(db)
|
||||
rooms, err := NewCosmosDBRoomsTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transactions, err := NewSqliteTransactionsTable(db)
|
||||
transactions, err := NewCosmosDBTransactionsTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateBlock, err := NewSqliteStateBlockTable(db)
|
||||
stateBlock, err := NewCosmosDBStateBlockTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateSnapshot, err := NewSqliteStateSnapshotTable(db)
|
||||
stateSnapshot, err := NewCosmosDBStateSnapshotTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prevEvents, err := NewSqlitePrevEventsTable(db)
|
||||
prevEvents, err := NewCosmosDBPrevEventsTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roomAliases, err := NewSqliteRoomAliasesTable(db)
|
||||
roomAliases, err := NewCosmosDBRoomAliasesTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
invites, err := NewSqliteInvitesTable(db)
|
||||
invites, err := NewCosmosDBInvitesTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
membership, err := NewSqliteMembershipTable(db)
|
||||
membership, err := NewCosmosDBMembershipTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
published, err := NewSqlitePublishedTable(db)
|
||||
published, err := NewCosmosDBPublishedTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
redactions, err := NewSqliteRedactionsTable(db)
|
||||
redactions, err := NewCosmosDBRedactionsTable(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: db,
|
||||
Cache: cache,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
DB: nil,
|
||||
Cache: cache,
|
||||
//Use the Fake SQL Writer here
|
||||
Writer: cosmosdbutil.NewExclusiveWriterFake(),
|
||||
EventsTable: events,
|
||||
EventTypesTable: eventTypes,
|
||||
EventStateKeysTable: eventStateKeys,
|
||||
|
|
|
|||
|
|
@ -18,50 +18,84 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
|
||||
const transactionsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_transactions (
|
||||
transaction_id TEXT NOT NULL,
|
||||
session_id INTEGER NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
PRIMARY KEY (transaction_id, session_id, user_id)
|
||||
);
|
||||
`
|
||||
const insertTransactionSQL = `
|
||||
INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`
|
||||
// const transactionsSchema = `
|
||||
// CREATE TABLE IF NOT EXISTS roomserver_transactions (
|
||||
// transaction_id TEXT NOT NULL,
|
||||
// session_id INTEGER NOT NULL,
|
||||
// user_id TEXT NOT NULL,
|
||||
// event_id TEXT NOT NULL,
|
||||
// PRIMARY KEY (transaction_id, session_id, user_id)
|
||||
// );
|
||||
// `
|
||||
|
||||
const selectTransactionEventIDSQL = `
|
||||
SELECT event_id FROM roomserver_transactions
|
||||
WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
|
||||
`
|
||||
|
||||
type transactionStatements struct {
|
||||
db *sql.DB
|
||||
insertTransactionStmt *sql.Stmt
|
||||
selectTransactionEventIDStmt *sql.Stmt
|
||||
type TransactionCosmos struct {
|
||||
TransactionID string `json:"transaction_id"`
|
||||
SessionID int64 `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
EventID string `json:"event_id"`
|
||||
}
|
||||
|
||||
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
|
||||
type TransactionCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Transaction TransactionCosmos `json:"mx_roomserver_transaction"`
|
||||
}
|
||||
|
||||
// const insertTransactionSQL = `
|
||||
// INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
|
||||
// VALUES ($1, $2, $3, $4)
|
||||
// `
|
||||
|
||||
// const selectTransactionEventIDSQL = `
|
||||
// SELECT event_id FROM roomserver_transactions
|
||||
// WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
|
||||
// `
|
||||
|
||||
type transactionStatements struct {
|
||||
db *Database
|
||||
// insertTransactionStmt *sql.Stmt
|
||||
selectTransactionEventIDStmt *sql.Stmt
|
||||
tableName string
|
||||
}
|
||||
|
||||
func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*TransactionCosmosData, error) {
|
||||
response := TransactionCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
&response)
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func NewCosmosDBTransactionsTable(db *Database) (tables.Transactions, error) {
|
||||
s := &transactionStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(transactionsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertTransactionStmt, insertTransactionSQL},
|
||||
{&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
|
||||
}.Prepare(db)
|
||||
// return s, shared.StatementList{
|
||||
// {&s.insertTransactionStmt, insertTransactionSQL},
|
||||
// {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
|
||||
// }.Prepare(db)
|
||||
s.tableName = "transactions"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *transactionStatements) InsertTransaction(
|
||||
|
|
@ -71,10 +105,39 @@ func (s *transactionStatements) InsertTransaction(
|
|||
userID string,
|
||||
eventID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, transactionID, sessionID, userID, eventID,
|
||||
)
|
||||
|
||||
// INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
|
||||
// VALUES ($1, $2, $3, $4)
|
||||
data := TransactionCosmos{
|
||||
EventID: eventID,
|
||||
SessionID: sessionID,
|
||||
TransactionID: transactionID,
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
|
||||
// PRIMARY KEY (transaction_id, session_id, user_id)
|
||||
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var dbData = TransactionCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Transaction: data,
|
||||
}
|
||||
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
&dbData,
|
||||
options)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -84,8 +147,21 @@ func (s *transactionStatements) SelectTransactionEventID(
|
|||
sessionID int64,
|
||||
userID string,
|
||||
) (eventID string, err error) {
|
||||
err = s.selectTransactionEventIDStmt.QueryRowContext(
|
||||
ctx, transactionID, sessionID, userID,
|
||||
).Scan(&eventID)
|
||||
return
|
||||
|
||||
// SELECT event_id FROM roomserver_transactions
|
||||
// WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// PRIMARY KEY (transaction_id, session_id, user_id)
|
||||
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
response, err := getTransaction(s, ctx, pk, cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return response.Transaction.EventID, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -71,6 +71,27 @@ func (s *accountDataStatements) prepare(db *Database) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func queryAccountData(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []AccountDataCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) insertAccountData(
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
|
|
@ -92,10 +113,14 @@ func (s *accountDataStatements) insertAccountData(
|
|||
id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type)
|
||||
}
|
||||
|
||||
docId := id
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var dbData = AccountDataCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, id),
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
AccountData: result,
|
||||
}
|
||||
|
|
@ -120,24 +145,15 @@ func (s *accountDataStatements) selectAccountData(
|
|||
) {
|
||||
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []AccountDataCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, nil, ex
|
||||
response, err := queryAccountData(s, ctx, s.selectAccountDataStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
global := map[string]json.RawMessage{}
|
||||
|
|
@ -166,26 +182,17 @@ func (s *accountDataStatements) selectAccountDataByType(
|
|||
|
||||
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []AccountDataCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
"@x3": roomID,
|
||||
"@x4": dataType,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, ex
|
||||
response, err := queryAccountData(s, ctx, s.selectAccountDataByTypeStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
|
@ -87,17 +88,42 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv
|
|||
return
|
||||
}
|
||||
|
||||
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) {
|
||||
response := AccountCosmosData{}
|
||||
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||
func queryAccount(s *accountsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []AccountCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) {
|
||||
response := AccountCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
optionsGet,
|
||||
&response)
|
||||
return &response, ex
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func setAccount(s *accountsStatements, ctx context.Context, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
|
||||
|
|
@ -155,10 +181,14 @@ func (s *accountsStatements) insertAccount(
|
|||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
|
||||
docId := result.Localpart
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var dbData = AccountCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Account: data,
|
||||
}
|
||||
|
|
@ -184,10 +214,11 @@ func (s *accountsStatements) updatePassword(
|
|||
|
||||
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
docId := localpart
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var response, exGet = getAccount(s, ctx, pk, docId)
|
||||
var response, exGet = getAccount(s, ctx, pk, cosmosDocId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
|
@ -207,10 +238,12 @@ func (s *accountsStatements) deactivateAccount(
|
|||
|
||||
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var response, exGet = getAccount(s, ctx, pk, docId)
|
||||
docId := localpart
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var response, exGet = getAccount(s, ctx, pk, cosmosDocId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
|
@ -230,24 +263,15 @@ func (s *accountsStatements) selectPasswordHash(
|
|||
|
||||
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []AccountCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return "", ex
|
||||
response, err := queryAccount(s, ctx, s.selectPasswordHashStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
|
|
@ -268,24 +292,15 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
|||
|
||||
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []AccountCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, ex
|
||||
response, err := queryAccount(s, ctx, s.selectAccountByLocalpartStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
|
|
|
|||
|
|
@ -62,6 +62,27 @@ func mapToToken(api api.OpenIDToken) OpenIDTokenCosmos {
|
|||
}
|
||||
}
|
||||
|
||||
func queryOpenIdToken(s *tokenStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OpenIdTokenCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []OpenIdTokenCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
s.selectTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_openidtoken.token = @x2"
|
||||
|
|
@ -87,10 +108,14 @@ func (s *tokenStatements) insertToken(
|
|||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||
|
||||
docId := result.Token
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var dbData = OpenIdTokenCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Token),
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
OpenIdToken: mapToToken(*result),
|
||||
}
|
||||
|
|
@ -120,24 +145,14 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
|||
|
||||
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []OpenIdTokenCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": token,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
response, err := queryOpenIdToken(s, ctx, s.selectTokenStmt, params)
|
||||
|
||||
if ex != nil {
|
||||
return nil, ex
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
)
|
||||
|
|
@ -87,17 +88,42 @@ func (s *profilesStatements) prepare(db *Database) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) {
|
||||
response := ProfileCosmosData{}
|
||||
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||
func queryProfile(s *profilesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ProfileCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []ProfileCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) {
|
||||
response := ProfileCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
optionsGet,
|
||||
&response)
|
||||
return &response, ex
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func setProfile(s *profilesStatements, ctx context.Context, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
|
||||
|
|
@ -123,10 +149,14 @@ func (s *profilesStatements) insertProfile(
|
|||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
|
||||
docId := localpart
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var dbData = ProfileCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Profile: mapToProfile(*result),
|
||||
}
|
||||
|
|
@ -148,24 +178,15 @@ func (s *profilesStatements) selectProfileByLocalpart(
|
|||
|
||||
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []ProfileCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, ex
|
||||
response, err := queryProfile(s, ctx, s.selectProfileByLocalpartStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
|
|
@ -186,10 +207,11 @@ func (s *profilesStatements) setAvatarURL(
|
|||
|
||||
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||
docId := localpart
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var response, exGet = getProfile(s, ctx, pk, docId)
|
||||
var response, exGet = getProfile(s, ctx, pk, cosmosDocId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
|
@ -209,9 +231,10 @@ func (s *profilesStatements) setDisplayName(
|
|||
|
||||
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||
var response, exGet = getProfile(s, ctx, pk, docId)
|
||||
docId := localpart
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response, exGet = getProfile(s, ctx, pk, cosmosDocId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
|
@ -232,25 +255,16 @@ func (s *profilesStatements) selectProfilesBySearch(
|
|||
|
||||
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []ProfileCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": searchString,
|
||||
"@x3": limit,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, ex
|
||||
response, err := queryProfile(s, ctx, s.selectProfilesBySearchStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := 0; i < len(response); i++ {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ import (
|
|||
"errors"
|
||||
"strconv"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
|
|
@ -27,7 +29,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -37,6 +38,7 @@ import (
|
|||
// Database represents an account database
|
||||
type Database struct {
|
||||
sqlutil.PartitionOffsetStatements
|
||||
writer sqlutil.Writer
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
accountDatas accountDataStatements
|
||||
|
|
@ -62,7 +64,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
|||
connection: conn,
|
||||
cosmosConfig: config,
|
||||
// db: db,
|
||||
// writer: sqlutil.NewExclusiveWriter(),
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
// bcryptCost: bcryptCost,
|
||||
// openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,31 +69,43 @@ func (s *threepidStatements) prepare(db *Database) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func queryThreePID(s *threepidStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ThreePIDCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []ThreePIDCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *threepidStatements) selectLocalpartForThreePID(
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
|
||||
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []ThreePIDCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": threepid,
|
||||
"@x3": medium,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return "", ex
|
||||
response, err := queryThreePID(s, ctx, s.selectLocalpartForThreePIDStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
|
|
@ -109,24 +121,14 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
|
|||
|
||||
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []ThreePIDCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
response, err := queryThreePID(s, ctx, s.selectThreePIDsForLocalpartStmt, params)
|
||||
|
||||
if ex != nil {
|
||||
return threepids, ex
|
||||
if err != nil {
|
||||
return threepids, err
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
|
|
@ -158,10 +160,11 @@ func (s *threepidStatements) insertThreePID(
|
|||
|
||||
docId := fmt.Sprintf("%s_%s", threepid, medium)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var dbData = ThreePIDCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
ThreePID: result,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,10 +16,11 @@ package cosmosdb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
|
@ -82,15 +83,15 @@ type DeviceCosmosSessionCount struct {
|
|||
}
|
||||
|
||||
type devicesStatements struct {
|
||||
db *Database
|
||||
db *Database
|
||||
selectDevicesCountStmt string
|
||||
selectDeviceByTokenStmt string
|
||||
// selectDeviceByIDStmt *sql.Stmt
|
||||
selectDevicesByIDStmt string
|
||||
selectDevicesByLocalpartStmt string
|
||||
selectDevicesByLocalpartExceptIDStmt string
|
||||
serverName gomatrixserverlib.ServerName
|
||||
tableName string
|
||||
serverName gomatrixserverlib.ServerName
|
||||
tableName string
|
||||
}
|
||||
|
||||
func mapFromDevice(db DeviceCosmos) api.Device {
|
||||
|
|
@ -121,17 +122,42 @@ func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos {
|
|||
}
|
||||
}
|
||||
|
||||
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
|
||||
response := DeviceCosmosData{}
|
||||
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||
func queryDevice(s *devicesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
|
||||
response := DeviceCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
optionsGet,
|
||||
&response)
|
||||
return &response, ex
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) {
|
||||
|
|
@ -209,10 +235,11 @@ func (s *devicesStatements) insertDevice(
|
|||
// HACK: check for duplicate PK as we are using the UNIQUE key for the DocId
|
||||
docId := fmt.Sprintf("%s_%s", localpart, id)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
|
||||
var dbData = DeviceCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Device: data,
|
||||
}
|
||||
|
|
@ -260,7 +287,6 @@ func (s *devicesStatements) deleteDevices(
|
|||
) error {
|
||||
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -268,15 +294,8 @@ func (s *devicesStatements) deleteDevices(
|
|||
"@x3": devices,
|
||||
}
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -291,8 +310,6 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
|||
) error {
|
||||
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
exceptDevices := []string{
|
||||
exceptDeviceID,
|
||||
}
|
||||
|
|
@ -302,15 +319,8 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
|||
"@x3": exceptDevices,
|
||||
}
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -325,9 +335,9 @@ func (s *devicesStatements) updateDeviceName(
|
|||
) error {
|
||||
// "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
|
|
@ -347,27 +357,19 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||
) (*api.Device, error) {
|
||||
// "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": accessToken,
|
||||
}
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDeviceByTokenStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
response, err := queryDevice(s, ctx, s.selectDeviceByTokenStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(response) == 0 {
|
||||
return nil, errors.New(fmt.Sprintf("No Devices found with AccessToken %s", accessToken))
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
|
|
@ -384,9 +386,9 @@ func (s *devicesStatements) selectDeviceByID(
|
|||
) (*api.Device, error) {
|
||||
// "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
||||
if exGet != nil {
|
||||
return nil, exGet
|
||||
|
|
@ -401,23 +403,14 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||
devices := []api.Device{}
|
||||
// "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
"@x3": exceptDeviceID,
|
||||
}
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartExceptIDStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartExceptIDStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -435,22 +428,14 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
// "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
||||
var devices []api.Device
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": deviceIDs,
|
||||
}
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDevicesByIDStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
response, err := queryDevice(s, ctx, s.selectDevicesByIDStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -466,9 +451,9 @@ func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart,
|
|||
|
||||
// "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
||||
|
|
@ -37,7 +36,6 @@ var deviceIDByteLength = 6
|
|||
|
||||
// Database represents a device database.
|
||||
type Database struct {
|
||||
writer sqlutil.Writer
|
||||
devices devicesStatements
|
||||
connection cosmosdbapi.CosmosConnection
|
||||
databaseName string
|
||||
|
|
|
|||
Loading…
Reference in a new issue