- Add GO CosmosDB library github.com/vippsas/go-cosmosdb

- Update YAML file to use file: everywhere except for Accounts
- Use the CosmosDB conn string in the YAML
- Add cosmosdbapi package to wrap the external package
- Add Tenant.go to store the tenancy settings - to be removed when tenancy is implemented
- Update the 5 tables to use the internal CosmosDBAPI package instead of SQL
- Remove sql from storage.go and other files
This commit is contained in:
Alex Flatow 2021-05-11 09:11:33 +10:00
parent 234a89db5d
commit a5ddb710d8
17 changed files with 1053 additions and 498 deletions

15
.vscode/launch.json vendored
View file

@ -28,6 +28,19 @@
"${workspaceFolder}\\bin\\dendrite.yaml",
"clientapi",
]
}
},
{
"name": "Launch Package Monolith - CosmosDB",
"type": "go",
"request": "launch",
"mode": "debug",
"program": "${workspaceFolder}\\cmd\\dendrite-monolith-server",
"args": [
"-config",
"${workspaceFolder}\\dendrite-config-cosmosdb.yaml",
//Uncomment below to expose internal api's
// "--api",
// "true"
]}
]
}

View file

@ -90,7 +90,7 @@ global:
# Naffka database options. Not required when using Kafka.
naffka_database:
connection_string: cosmosdb:naffka.db
connection_string: file:naffka.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -122,7 +122,7 @@ app_service_api:
listen: http://localhost:7777
connect: http://localhost:7777
database:
connection_string: cosmosdb:appservice.db
connection_string: file:appservice.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -202,7 +202,7 @@ federation_sender:
listen: http://localhost:7775
connect: http://localhost:7775
database:
connection_string: cosmosdb:federationsender.db
connection_string: file:federationsender.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -228,7 +228,7 @@ key_server:
listen: http://localhost:7779
connect: http://localhost:7779
database:
connection_string: cosmosdb:keyserver.db
connection_string: file:keyserver.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -241,7 +241,7 @@ media_api:
external_api:
listen: http://[::]:8074
database:
connection_string: cosmosdb:mediaapi.db
connection_string: file:mediaapi.db
max_open_conns: 5
max_idle_conns: 2
conn_max_lifetime: -1
@ -280,7 +280,7 @@ mscs:
# - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946)
mscs: []
database:
connection_string: cosmosdb:mscs.db
connection_string: file:mscs.db
max_open_conns: 5
max_idle_conns: 2
conn_max_lifetime: -1
@ -291,7 +291,7 @@ room_server:
listen: http://localhost:7770
connect: http://localhost:7770
database:
connection_string: cosmosdb:roomserver.db
connection_string: file:roomserver.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -302,7 +302,7 @@ signing_key_server:
listen: http://localhost:7780
connect: http://localhost:7780
database:
connection_string: cosmosdb:signingkeyserver.db
connection_string: file:signingkeyserver.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -331,7 +331,7 @@ sync_api:
external_api:
listen: http://[::]:8073
database:
connection_string: cosmosdb:syncapi.db
connection_string: file:syncapi.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@ -354,12 +354,12 @@ user_api:
listen: http://localhost:7781
connect: http://localhost:7781
account_database:
connection_string: cosmosdb:userapi_accounts.db
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
device_database:
connection_string: cosmosdb:userapi_devices.db
connection_string: file:userapi_devices.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1

1
go.mod
View file

@ -37,6 +37,7 @@ require (
github.com/tidwall/sjson v1.1.5
github.com/uber/jaeger-client-go v2.25.0+incompatible
github.com/uber/jaeger-lib v2.4.0+incompatible
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d // indirect
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20210218094457-e77ca8019daa
go.uber.org/atomic v1.7.0
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83

4
go.sum
View file

@ -36,6 +36,7 @@ github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@ -176,6 +177,7 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
github.com/gofrs/uuid v3.1.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
@ -987,6 +989,8 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d h1:MZRYOouO0snrQyBAf4Wljc3qqaispjzMOhFRQgWfKMo=
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d/go.mod h1:ldPlejlc7ZyiP0QQWGwL9CoZLvEjhD9yzpz0ct7+sXo=
github.com/vishvananda/netlink v1.0.0/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
github.com/vishvananda/netns v0.0.0-20190625233234-7109fa855b0f/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k=

View file

@ -0,0 +1,24 @@
package cosmosdbapi
import (
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
)
type CosmosConnection struct {
Url string
Key string
}
func GetCosmosConnection(accountEndpoint string, accountKey string) CosmosConnection {
return CosmosConnection{
Url: accountEndpoint,
Key: accountKey,
}
}
func GetClient(conn CosmosConnection) *cosmosapi.Client {
cfg := cosmosapi.Config{
MasterKey: conn.Key,
}
return cosmosapi.New(conn.Url, cfg, nil, nil)
}

View file

@ -0,0 +1,10 @@
package cosmosdbapi
import (
"fmt"
)
func GetCollectionName(databaseName string, tableName string) string {
return fmt.Sprintf("matrix_%s_%s", databaseName, tableName)
}

View file

@ -0,0 +1,14 @@
package cosmosdbapi
import (
"fmt"
)
func GetDocumentId(tenantName string, collectionName string, id string) string {
return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, id)
}
func GetPartitionKey(tenantName string, collectionName string) string {
return fmt.Sprintf("%s,%s", collectionName, tenantName)
}

View file

@ -0,0 +1,46 @@
package cosmosdbapi
import (
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
)
func GetCreateDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
return cosmosapi.CreateDocumentOptions{
IsUpsert: false,
PartitionKeyValue: pk,
}
}
func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
return cosmosapi.CreateDocumentOptions{
IsUpsert: true,
PartitionKeyValue: pk,
}
}
func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions {
return cosmosapi.QueryDocumentsOptions{
PartitionKeyValue: pk,
IsQuery: true,
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
}
}
func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions {
return cosmosapi.GetDocumentOptions{
PartitionKeyValue: pk,
}
}
func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions {
return cosmosapi.ReplaceDocumentOptions{
PartitionKeyValue: pk,
IfMatch: etag,
}
}
func GetDeleteDocumentOptions(pk string) cosmosapi.DeleteDocumentOptions {
return cosmosapi.DeleteDocumentOptions{
PartitionKeyValue: pk,
}
}

View file

@ -0,0 +1,20 @@
package cosmosdbapi
import (
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
)
func GetQuery(qry string, params map[string]interface{}) cosmosapi.Query {
qryParams := []cosmosapi.QueryParam{}
for key, value := range params {
qryParam := cosmosapi.QueryParam {
Name: key,
Value: value,
}
qryParams = append(qryParams, qryParam)
}
return cosmosapi.Query {
Query: qry,
Params: qryParams,
}
}

View file

@ -0,0 +1,14 @@
package cosmosdbapi
type Tenant struct {
DatabaseName string
TenantName string
}
//TODO: Move into Config or the JWT
func DefaultConfig() Tenant {
return Tenant{
DatabaseName: "safezone_local",
TenantName: "criticalarc.com",
}
}

View file

@ -8,5 +8,15 @@ import (
func GetConnectionString(d *config.DataSource) config.DataSource {
var connString string
connString = string(*d)
return config.DataSource(strings.Replace(connString, "cosmosdb:", "file:", 1))
return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1))
}
func GetConnectionProperties(connectionString string) map[string]string {
connectionItemsRaw := strings.Split(connectionString, ";")
connectionItems := map[string]string{}
for _, item := range connectionItemsRaw {
itemSplit := strings.SplitN(item, "=", 2)
connectionItems[itemSplit[0]] = itemSplit[1]
}
return connectionItems
}

View file

@ -16,68 +16,94 @@ package cosmosdb
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
)
const accountDataSchema = `
-- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
room_id TEXT,
-- The account data type
type TEXT NOT NULL,
-- The account data content
content TEXT NOT NULL,
// const accountDataSchema = `
// -- Stores data about accounts data.
// CREATE TABLE IF NOT EXISTS account_data (
// -- The Matrix user ID localpart for this account
// localpart TEXT NOT NULL,
// -- The room ID for this data (empty string if not specific to a room)
// room_id TEXT,
// -- The account data type
// type TEXT NOT NULL,
// -- The account data content
// content TEXT NOT NULL,
PRIMARY KEY(localpart, room_id, type)
);
`
// PRIMARY KEY(localpart, room_id, type)
// );
// `
const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
`
const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct {
db *sql.DB
insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
type AccountCosmosAccountData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Object AccountData `json:"_object"`
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
type AccountData struct {
LocalPart string `json:"local_part"`
RoomId string `json:"room_id"`
Type string `json:"type"`
Content []byte `json:"content"`
}
type accountDataStatements struct {
db *Database
tableName string
}
func (s *accountDataStatements) prepare(db *Database) (err error) {
s.db = db
_, err = db.Exec(accountDataSchema)
if err != nil {
return
}
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
return
}
if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
return
}
if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil {
return
}
s.tableName = "account_data"
return
}
func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
// INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
// ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
var result = AccountData{
LocalPart: localpart,
RoomId: roomID,
Type: dataType,
Content: content,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
id := ""
if roomID == "" {
id = fmt.Sprintf("%s_%s", result.LocalPart, result.Type)
} else {
id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type)
}
var dbData = AccountCosmosAccountData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Timestamp: time.Now().Unix(),
Object: result,
}
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
dbData,
options)
return err
}
@ -88,30 +114,43 @@ func (s *accountDataStatements) selectAccountData(
/* rooms */ map[string]map[string]json.RawMessage,
error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
return nil, nil, err
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []AccountCosmosAccountData{}
var selectAccountDataCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectAccountDataCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, nil, ex
}
global := map[string]json.RawMessage{}
rooms := map[string]map[string]json.RawMessage{}
for rows.Next() {
var roomID string
var dataType string
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return nil, nil, err
}
for i := 0; i < len(response); i++ {
var row = response[i]
var roomID = row.Object.RoomId
if roomID != "" {
if _, ok := rooms[roomID]; !ok {
if _, ok := rooms[row.Object.RoomId]; !ok {
rooms[roomID] = map[string]json.RawMessage{}
}
rooms[roomID][dataType] = content
rooms[roomID][row.Object.Type] = row.Object.Content
} else {
global[dataType] = content
global[row.Object.Type] = row.Object.Content
}
}
@ -122,13 +161,39 @@ func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []AccountCosmosAccountData{}
var selectAccountDataCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2 and c._object.room_id = @x3 and c._object.type = @x4"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": roomID,
"@x4": dataType,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectAccountDataCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, ex
}
if len(response) == 0 {
return data, nil
}
bytes = response[0].Object.Content
data = json.RawMessage(bytes)
return
}

View file

@ -16,159 +16,264 @@ package cosmosdb
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account.
password_hash TEXT,
-- Identifies which application service this account belongs to, if any.
appservice_id TEXT,
-- If the account is currently active
is_deactivated BOOLEAN DEFAULT 0
-- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
);
`
// const accountsSchema = `
// -- Stores data about accounts.
// CREATE TABLE IF NOT EXISTS account_accounts (
// -- The Matrix user ID localpart for this account
// localpart TEXT NOT NULL PRIMARY KEY,
// -- When this account was first created, as a unix timestamp (ms resolution).
// created_ts BIGINT NOT NULL,
// -- The password hash for this account. Can be NULL if this is a passwordless account.
// password_hash TEXT,
// -- Identifies which application service this account belongs to, if any.
// appservice_id TEXT,
// -- If the account is currently active
// is_deactivated BOOLEAN DEFAULT 0
// -- TODO:
// -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
// );
// `
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
type AccountExtended struct {
IsDeactivated bool `json:"is_deactivated"`
PasswordHash string `json:"password_hash"`
Created int64 `json:"created_ts"`
}
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
type AccountCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Object api.Account `json:"_object"`
ObjectExtended AccountExtended `json:"_object_extended"`
}
const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
type AccountCosmosUserCount struct {
UserCount int64 `json:"usercount"`
}
type accountsStatements struct {
db *sql.DB
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
deactivateAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
db *Database
tableName string
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(accountsSchema)
return err
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
s.db = db
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
return
}
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
return
}
s.tableName = "account_accounts"
s.serverName = server
return
}
func getAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*AccountCosmosData, error) {
response := AccountCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
ctx,
config.DatabaseName,
config.TenantName,
docId,
optionsGet,
&response)
return &response, ex
}
func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
response := AccountCosmosData{}
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
config.DatabaseName,
config.TenantName,
account.Id,
&account,
optionsReplace)
return &response, ex
}
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
ctx context.Context, localpart, hash, appserviceID string,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
// stmt := s.insertAccountStmt
var err error
if appserviceID == "" {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
}
if err != nil {
return nil, err
}
return &api.Account{
var result = api.Account{
Localpart: localpart,
UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName,
AppServiceID: appserviceID,
}, nil
}
var extended = AccountExtended{
IsDeactivated: false,
PasswordHash: hash,
Created: createdTimeMS,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var dbData = AccountCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Timestamp: time.Now().Unix(),
Object: result,
ObjectExtended: extended,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
dbData,
options)
if err != nil {
return nil, err
}
return &result, nil
}
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var response, exGet = getAccount(s, ctx, config, pk, docId)
if exGet != nil {
return exGet
}
response.ObjectExtended.PasswordHash = passwordHash
var _, exReplace = setAccount(s, ctx, config, pk, *response)
if exReplace != nil {
return exReplace
}
return
}
func (s *accountsStatements) deactivateAccount(
ctx context.Context, localpart string,
) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var response, exGet = getAccount(s, ctx, config, pk, docId)
if exGet != nil {
return exGet
}
response.ObjectExtended.IsDeactivated = true
var _, exReplace = setAccount(s, ctx, config, pk, *response)
if exReplace != nil {
return exReplace
}
return
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []AccountCosmosData{}
var selectPasswordHashCosmos = "select * from c where c._cn = @x1 and c._object.Localpart = @x2 and c._object_extended.is_deactivated = false"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectPasswordHashCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return "", ex
}
if len(response) == 0 {
return "", errors.New(fmt.Sprintf("Localpart %s not found", localpart))
}
if len(response) != 1 {
return "", errors.New(fmt.Sprintf("Localpart %s has multiple entries", localpart))
}
return response[0].ObjectExtended.PasswordHash, nil
}
func (s *accountsStatements) selectAccountByLocalpart(
ctx context.Context, localpart string,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
}
return nil, err
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []AccountCosmosData{}
var selectPasswordHashCosmos = "select * from c where c._cn = @x1 and c._object.Localpart = @x2"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
if appserviceIDPtr.Valid {
acc.AppServiceID = appserviceIDPtr.String
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectPasswordHashCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, ex
}
if len(response) == 0 {
return nil, nil
}
acc = response[0].Object
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName
@ -176,12 +281,31 @@ func (s *accountsStatements) selectAccountByLocalpart(
}
func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx,
ctx context.Context,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
// "SELECT COUNT(localpart) FROM account_accounts"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var response []AccountCosmosUserCount
var selectCountCosmos = "select count(c._ts) as usercount from c where c._cn = @x1"
params := map[string]interface{}{
"@x1": dbCollectionName,
}
err = stmt.QueryRowContext(ctx).Scan(&id)
return
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectCountCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return -1, ex
}
return int64(response[0].UserCount), nil
}

View file

@ -1,13 +1,12 @@
package cosmosdb
import (
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const openIDTokenSchema = `
@ -21,32 +20,24 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
token_expires_at_ms BIGINT NOT NULL
);
`
const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct {
db *sql.DB
insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
type OpenIdTokenCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Object *api.OpenIDToken `json:"_object"`
}
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
type tokenStatements struct {
db *Database
tableName string
serverName gomatrixserverlib.ServerName
}
func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(openIDTokenSchema)
if err != nil {
return err
}
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
return
}
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
return
}
s.tableName = "open_id_tokens"
s.serverName = server
return
}
@ -55,12 +46,40 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
// Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
token, localpart string,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
// "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
var result = &api.OpenIDToken{
UserID: localpart,
Token: token,
ExpiresAtMS: expiresAtMS,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
var dbData = OpenIdTokenCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Timestamp: time.Now().Unix(),
Object: result,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
dbData,
options)
if ex != nil {
return ex
}
return
}
@ -71,16 +90,39 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresAtMS,
)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")
}
return nil, err
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []OpenIdTokenCosmosData{}
var selectOpenIdTokenCosmos = "select * from c where c._cn = @x1 and c._object.Token = @x2"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": token,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectOpenIdTokenCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, ex
}
if(len(response) == 0) {
return nil, nil
}
var openIdToken = response[0].Object
openIDTokenAttrs = api.OpenIDTokenAttributes{
UserID: openIdToken.UserID,
ExpiresAtMS: openIdToken.ExpiresAtMS,
}
return &openIDTokenAttrs, nil
}

View file

@ -16,107 +16,186 @@ package cosmosdb
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account
display_name TEXT,
-- The URL of the avatar for this account
avatar_url TEXT
);
`
// const profilesSchema = `
// -- Stores data about accounts profiles.
// CREATE TABLE IF NOT EXISTS account_profiles (
// -- The Matrix user ID localpart for this account
// localpart TEXT NOT NULL PRIMARY KEY,
// -- The display name for this account
// display_name TEXT,
// -- The URL of the avatar for this account
// avatar_url TEXT
// );
// `
const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
db *sql.DB
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
setDisplayNameStmt *sql.Stmt
selectProfilesBySearchStmt *sql.Stmt
type ProfileCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Object authtypes.Profile `json:"_object"`
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
type profilesStatements struct {
db *Database
tableName string
}
func (s *profilesStatements) prepare(db *Database) (err error) {
s.db = db
_, err = db.Exec(profilesSchema)
if err != nil {
return
}
if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil {
return
}
if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil {
return
}
if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil {
return
}
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
return
}
if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil {
return
}
s.tableName = "account_profiles"
return
}
func getProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*ProfileCosmosData, error) {
response := ProfileCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
ctx,
config.DatabaseName,
config.TenantName,
docId,
optionsGet,
&response)
return &response, ex
}
func setProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
config.DatabaseName,
config.TenantName,
profile.Id,
&profile,
optionsReplace)
return &profile, ex
}
func (s *profilesStatements) insertProfile(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, localpart string,
) error {
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
// "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
var result = &authtypes.Profile{
Localpart: localpart,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var dbData = ProfileCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Timestamp: time.Now().Unix(),
Object: *result,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
dbData,
options)
return err
}
func (s *profilesStatements) selectProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []ProfileCosmosData{}
var selectProfileByLocalpartCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
return &profile, nil
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectProfileByLocalpartCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, ex
}
if len(response) == 0 {
return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(response)))
}
if len(response) != 1 {
return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(response)))
}
return &response[0].Object, nil
}
func (s *profilesStatements) setAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
ctx context.Context, localpart string, avatarURL string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
_, err = stmt.ExecContext(ctx, avatarURL, localpart)
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var response, exGet = getProfile(s, ctx, config, pk, docId)
if exGet != nil {
return exGet
}
response.Object.AvatarURL = avatarURL
var _, exReplace = setProfile(s, ctx, config, pk, *response)
if exReplace != nil {
return exReplace
}
return
}
func (s *profilesStatements) setDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
ctx context.Context, localpart string, displayName string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
_, err = stmt.ExecContext(ctx, displayName, localpart)
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var response, exGet = getProfile(s, ctx, config, pk, docId)
if exGet != nil {
return exGet
}
response.Object.DisplayName = displayName
var _, exReplace = setProfile(s, ctx, config, pk, *response)
if exReplace != nil {
return exReplace
}
return
}
@ -124,20 +203,36 @@ func (s *profilesStatements) selectProfilesBySearch(
ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile
// The fmt.Sprintf directive below is building a parameter for the
// "LIKE" condition in the SQL query. %% escapes the % char, so the
// statement in the end will look like "LIKE %searchString%".
rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []ProfileCosmosData{}
var selectProfileByLocalpartCosmos = "select top @x3 * from c where c._cn = @x1 and contains(c._object.local_part, @x2)"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": searchString,
"@x3": limit,
}
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() {
var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err
}
profiles = append(profiles, profile)
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectProfileByLocalpartCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, ex
}
for i := 0; i < len(response); i++ {
var responseData = response[i]
profiles = append(profiles, responseData.Object)
}
return profiles, nil
}

View file

@ -15,29 +15,27 @@
package cosmosdb
import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
"sync"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
// "sync"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
)
// Database represents an account database
type Database struct {
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
@ -48,55 +46,57 @@ type Database struct {
bcryptCost int
openIDTokenLifetimeMS int64
accountsMu sync.Mutex
profilesMu sync.Mutex
accountDatasMu sync.Mutex
threepidsMu sync.Mutex
databaseName string
connection cosmosdbapi.CosmosConnection
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
connString := cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
connMap := cosmosdbutil.GetConnectionProperties(string(connString))
accountEndpoint := connMap["AccountEndpoint"]
accountKey := connMap["AccountKey"]
conn := cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewExclusiveWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
serverName: serverName,
databaseName: "userapi",
connection: conn,
// db: db,
// writer: sqlutil.NewExclusiveWriter(),
// bcryptCost: bcryptCost,
// openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.accounts.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
// if err = d.accounts.execSchema(db); err != nil {
// return nil, err
// }
// m := sqlutil.NewMigrations()
// deltas.LoadIsActive(m)
// if err = m.RunDeltas(db, dbProperties); err != nil {
// return nil, err
// }
partitions := sqlutil.PartitionOffsetStatements{}
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
// partitions := sqlutil.PartitionOffsetStatements{}
// if err = partitions.Prepare(db, d.writer, "account"); err != nil {
// return nil, err
// }
var err error
if err = d.accounts.prepare(d, serverName); err != nil {
return nil, err
}
if err = d.accounts.prepare(db, serverName); err != nil {
if err = d.profiles.prepare(d); err != nil {
return nil, err
}
if err = d.profiles.prepare(db); err != nil {
if err = d.accountDatas.prepare(d); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db); err != nil {
if err = d.threepids.prepare(d); err != nil {
return nil, err
}
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
if err = d.openIDTokens.prepare(d, serverName); err != nil {
return nil, err
}
@ -131,11 +131,11 @@ func (d *Database) GetProfileByLocalpart(
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
d.profilesMu.Lock()
defer d.profilesMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL)
})
// d.profilesMu.Lock()
// defer d.profilesMu.Unlock()
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// })
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
@ -143,11 +143,12 @@ func (d *Database) SetAvatarURL(
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
d.profilesMu.Lock()
defer d.profilesMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
})
// d.profilesMu.Lock()
// defer d.profilesMu.Unlock()
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
// })
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// SetPassword sets the account password to the given hash.
@ -170,22 +171,23 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
// when the first txn upgrades to a write txn. We also need to lock the account creation else we can
// race with CreateAccount
// We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
d.profilesMu.Lock()
d.accountDatasMu.Lock()
d.accountsMu.Lock()
defer d.profilesMu.Unlock()
defer d.accountDatasMu.Unlock()
defer d.accountsMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
})
// d.profilesMu.Lock()
// d.accountDatasMu.Lock()
// d.accountsMu.Lock()
// defer d.profilesMu.Unlock()
// defer d.accountDatasMu.Unlock()
// defer d.accountsMu.Unlock()
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// })
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx)
if err != nil {
return nil, err
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, localpart, "", "")
return acc, err
}
@ -196,23 +198,25 @@ func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *api.Account, err error) {
// Create one account at a time else we can get 'database is locked'.
d.profilesMu.Lock()
d.accountDatasMu.Lock()
d.accountsMu.Lock()
defer d.profilesMu.Unlock()
defer d.accountDatasMu.Unlock()
defer d.accountsMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err
})
return
// d.profilesMu.Lock()
// d.accountDatasMu.Lock()
// d.accountsMu.Lock()
// defer d.profilesMu.Unlock()
// defer d.accountDatasMu.Unlock()
// defer d.accountsMu.Unlock()
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
// return err
// })
acc, err = d.createAccount(ctx, localpart, plaintextPassword, appserviceID)
return acc, err
}
// WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (*api.Account, error) {
var err error
var account *api.Account
@ -224,13 +228,13 @@ func (d *Database) createAccount(
return nil, err
}
}
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
if account, err = d.accounts.insertAccount(ctx, localpart, hash, appserviceID); err != nil {
return nil, sqlutil.ErrUserExists
}
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if err = d.profiles.insertProfile(ctx, localpart); err != nil {
return nil, err
}
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
if err = d.accountDatas.insertAccountData(ctx, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
@ -252,11 +256,11 @@ func (d *Database) createAccount(
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
d.accountDatasMu.Lock()
defer d.accountDatasMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
// d.accountDatasMu.Lock()
// defer d.accountDatasMu.Unlock()
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// })
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
}
// GetAccountData returns account data related to a given localpart
@ -286,7 +290,7 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil)
return d.accounts.selectNewNumericLocalpart(ctx)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
@ -305,22 +309,23 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use")
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
d.threepidsMu.Lock()
defer d.threepidsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
// d.threepidsMu.Lock()
// defer d.threepidsMu.Unlock()
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// })
if len(user) > 0 {
return Err3PIDInUse
}
user, err := d.threepids.selectLocalpartForThreePID(
ctx, threepid, medium,
)
if err != nil {
return err
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, threepid, medium, localpart)
}
// RemoveThreePIDAssociation removes the association involving a given third-party
@ -330,11 +335,11 @@ func (d *Database) SaveThreePIDAssociation(
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
d.threepidsMu.Lock()
defer d.threepidsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.threepids.deleteThreePID(ctx, txn, threepid, medium)
})
// d.threepidsMu.Lock()
// defer d.threepidsMu.Unlock()
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// })
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
@ -345,7 +350,7 @@ func (d *Database) RemoveThreePIDAssociation(
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
return d.threepids.selectLocalpartForThreePID(ctx, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
@ -362,11 +367,11 @@ func (d *Database) GetThreePIDsForLocalpart(
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
response, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
// if err == sql.ErrNoRows {
// return true, nil
// }
return response == nil, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
@ -395,9 +400,9 @@ func (d *Database) CreateOpenIDToken(
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
// err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// })
var err = d.openIDTokens.insertToken(ctx, token, localpart, expiresAtMS)
return expiresAtMS, err
}

View file

@ -16,118 +16,186 @@ package cosmosdb
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
const threepidSchema = `
-- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid (
-- The third party identifier
threepid TEXT NOT NULL,
-- The 3PID medium
medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
// const threepidSchema = `
// -- Stores data about third party identifiers
// CREATE TABLE IF NOT EXISTS account_threepid (
// -- The third party identifier
// threepid TEXT NOT NULL,
// -- The 3PID medium
// medium TEXT NOT NULL DEFAULT 'email',
// -- The localpart of the Matrix user ID associated to this 3PID
// localpart TEXT NOT NULL,
PRIMARY KEY(threepid, medium)
);
// PRIMARY KEY(threepid, medium)
// );
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
const insertThreePIDSQL = "" +
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
type threepidStatements struct {
db *sql.DB
selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
deleteThreePIDStmt *sql.Stmt
type ThreePIDObject struct {
Localpart string `json:"local_part"`
ThreePID string `json:"three_pid"`
Medium string `json:"medium"`
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(threepidSchema)
if err != nil {
return
}
if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil {
return
}
if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil {
return
}
if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil {
return
}
if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil {
return
}
type ThreePIDCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Object ThreePIDObject `json:"_object"`
}
type threepidStatements struct {
db *Database
tableName string
}
func (s *threepidStatements) prepare(db *Database) (err error) {
s.db = db
s.tableName = "account_threepid"
return
}
func (s *threepidStatements) selectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
if err == sql.ErrNoRows {
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []ThreePIDCosmosData{}
var selectLocalPartThreePIDCosmos = "select * from c where c._cn = @x1 and c._object.three_pid = @x2 and c._object.medium = @x3"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": threepid,
"@x3": medium,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectLocalPartThreePIDCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return "", ex
}
if len(response) == 0 {
return "", nil
}
return
return response[0].Object.Localpart, nil
}
func (s *threepidStatements) selectThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
if err != nil {
return
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []ThreePIDCosmosData{}
var selectThreePIDLocalPartCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectThreePIDLocalPartCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return threepids, ex
}
if len(response) == 0 {
return threepids, nil
}
defer internal.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed")
threepids = []authtypes.ThreePID{}
for rows.Next() {
var threepid string
var medium string
if err = rows.Scan(&threepid, &medium); err != nil {
return
}
for _, item := range response {
threepids = append(threepids, authtypes.ThreePID{
Address: threepid,
Medium: medium,
Address: item.Object.ThreePID,
Medium: item.Object.Medium,
})
}
return threepids, rows.Err()
return threepids, nil
}
func (s *threepidStatements) insertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
ctx context.Context, threepid, medium, localpart string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
return err
// "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
var result = ThreePIDObject{
Localpart: localpart,
Medium: medium,
ThreePID: threepid,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
id := fmt.Sprintf("%s_%s", threepid, medium)
var dbData = ThreePIDCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Timestamp: time.Now().Unix(),
Object: result,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
dbData,
options)
if err != nil {
return err
}
return
}
func (s *threepidStatements) deleteThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium)
return err
ctx context.Context, threepid string, medium string) (err error) {
// "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
id := fmt.Sprintf("%s_%s", threepid, medium)
pk := cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
_, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
config.DatabaseName,
config.TenantName,
id,
options)
if err != nil {
return err
}
return
}