mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 08:13:09 -06:00
- Add GO CosmosDB library github.com/vippsas/go-cosmosdb
- Update YAML file to use file: everywhere except for Accounts - Use the CosmosDB conn string in the YAML - Add cosmosdbapi package to wrap the external package - Add Tenant.go to store the tenancy settings - to be removed when tenancy is implemented - Update the 5 tables to use the internal CosmosDBAPI package instead of SQL - Remove sql from storage.go and other files
This commit is contained in:
parent
234a89db5d
commit
a5ddb710d8
15
.vscode/launch.json
vendored
15
.vscode/launch.json
vendored
|
|
@ -28,6 +28,19 @@
|
|||
"${workspaceFolder}\\bin\\dendrite.yaml",
|
||||
"clientapi",
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Launch Package Monolith - CosmosDB",
|
||||
"type": "go",
|
||||
"request": "launch",
|
||||
"mode": "debug",
|
||||
"program": "${workspaceFolder}\\cmd\\dendrite-monolith-server",
|
||||
"args": [
|
||||
"-config",
|
||||
"${workspaceFolder}\\dendrite-config-cosmosdb.yaml",
|
||||
//Uncomment below to expose internal api's
|
||||
// "--api",
|
||||
// "true"
|
||||
]}
|
||||
]
|
||||
}
|
||||
|
|
@ -90,7 +90,7 @@ global:
|
|||
|
||||
# Naffka database options. Not required when using Kafka.
|
||||
naffka_database:
|
||||
connection_string: cosmosdb:naffka.db
|
||||
connection_string: file:naffka.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
@ -122,7 +122,7 @@ app_service_api:
|
|||
listen: http://localhost:7777
|
||||
connect: http://localhost:7777
|
||||
database:
|
||||
connection_string: cosmosdb:appservice.db
|
||||
connection_string: file:appservice.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
@ -202,7 +202,7 @@ federation_sender:
|
|||
listen: http://localhost:7775
|
||||
connect: http://localhost:7775
|
||||
database:
|
||||
connection_string: cosmosdb:federationsender.db
|
||||
connection_string: file:federationsender.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
@ -228,7 +228,7 @@ key_server:
|
|||
listen: http://localhost:7779
|
||||
connect: http://localhost:7779
|
||||
database:
|
||||
connection_string: cosmosdb:keyserver.db
|
||||
connection_string: file:keyserver.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
@ -241,7 +241,7 @@ media_api:
|
|||
external_api:
|
||||
listen: http://[::]:8074
|
||||
database:
|
||||
connection_string: cosmosdb:mediaapi.db
|
||||
connection_string: file:mediaapi.db
|
||||
max_open_conns: 5
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
@ -280,7 +280,7 @@ mscs:
|
|||
# - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946)
|
||||
mscs: []
|
||||
database:
|
||||
connection_string: cosmosdb:mscs.db
|
||||
connection_string: file:mscs.db
|
||||
max_open_conns: 5
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
@ -291,7 +291,7 @@ room_server:
|
|||
listen: http://localhost:7770
|
||||
connect: http://localhost:7770
|
||||
database:
|
||||
connection_string: cosmosdb:roomserver.db
|
||||
connection_string: file:roomserver.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
@ -302,7 +302,7 @@ signing_key_server:
|
|||
listen: http://localhost:7780
|
||||
connect: http://localhost:7780
|
||||
database:
|
||||
connection_string: cosmosdb:signingkeyserver.db
|
||||
connection_string: file:signingkeyserver.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
@ -331,7 +331,7 @@ sync_api:
|
|||
external_api:
|
||||
listen: http://[::]:8073
|
||||
database:
|
||||
connection_string: cosmosdb:syncapi.db
|
||||
connection_string: file:syncapi.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
@ -354,12 +354,12 @@ user_api:
|
|||
listen: http://localhost:7781
|
||||
connect: http://localhost:7781
|
||||
account_database:
|
||||
connection_string: cosmosdb:userapi_accounts.db
|
||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
device_database:
|
||||
connection_string: cosmosdb:userapi_devices.db
|
||||
connection_string: file:userapi_devices.db
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
|
|||
1
go.mod
1
go.mod
|
|
@ -37,6 +37,7 @@ require (
|
|||
github.com/tidwall/sjson v1.1.5
|
||||
github.com/uber/jaeger-client-go v2.25.0+incompatible
|
||||
github.com/uber/jaeger-lib v2.4.0+incompatible
|
||||
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d // indirect
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20210218094457-e77ca8019daa
|
||||
go.uber.org/atomic v1.7.0
|
||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -36,6 +36,7 @@ github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/
|
|||
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
|
||||
github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c=
|
||||
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
||||
github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ=
|
||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||
|
|
@ -176,6 +177,7 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me
|
|||
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||
github.com/gofrs/uuid v3.1.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
|
||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
|
|
@ -987,6 +989,8 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU
|
|||
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
|
||||
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
|
||||
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
||||
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d h1:MZRYOouO0snrQyBAf4Wljc3qqaispjzMOhFRQgWfKMo=
|
||||
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d/go.mod h1:ldPlejlc7ZyiP0QQWGwL9CoZLvEjhD9yzpz0ct7+sXo=
|
||||
github.com/vishvananda/netlink v1.0.0/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
||||
github.com/vishvananda/netns v0.0.0-20190625233234-7109fa855b0f/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
|
||||
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k=
|
||||
|
|
|
|||
24
internal/cosmosdbapi/client.go
Normal file
24
internal/cosmosdbapi/client.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
package cosmosdbapi
|
||||
|
||||
import (
|
||||
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
|
||||
)
|
||||
|
||||
type CosmosConnection struct {
|
||||
Url string
|
||||
Key string
|
||||
}
|
||||
|
||||
func GetCosmosConnection(accountEndpoint string, accountKey string) CosmosConnection {
|
||||
return CosmosConnection{
|
||||
Url: accountEndpoint,
|
||||
Key: accountKey,
|
||||
}
|
||||
}
|
||||
|
||||
func GetClient(conn CosmosConnection) *cosmosapi.Client {
|
||||
cfg := cosmosapi.Config{
|
||||
MasterKey: conn.Key,
|
||||
}
|
||||
return cosmosapi.New(conn.Url, cfg, nil, nil)
|
||||
}
|
||||
10
internal/cosmosdbapi/collection.go
Normal file
10
internal/cosmosdbapi/collection.go
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
package cosmosdbapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
)
|
||||
|
||||
func GetCollectionName(databaseName string, tableName string) string {
|
||||
return fmt.Sprintf("matrix_%s_%s", databaseName, tableName)
|
||||
}
|
||||
14
internal/cosmosdbapi/document.go
Normal file
14
internal/cosmosdbapi/document.go
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
package cosmosdbapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
)
|
||||
|
||||
func GetDocumentId(tenantName string, collectionName string, id string) string {
|
||||
return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, id)
|
||||
}
|
||||
|
||||
func GetPartitionKey(tenantName string, collectionName string) string {
|
||||
return fmt.Sprintf("%s,%s", collectionName, tenantName)
|
||||
}
|
||||
46
internal/cosmosdbapi/documentoperations.go
Normal file
46
internal/cosmosdbapi/documentoperations.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
package cosmosdbapi
|
||||
|
||||
import (
|
||||
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
|
||||
)
|
||||
|
||||
func GetCreateDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
|
||||
return cosmosapi.CreateDocumentOptions{
|
||||
IsUpsert: false,
|
||||
PartitionKeyValue: pk,
|
||||
}
|
||||
}
|
||||
|
||||
func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
|
||||
return cosmosapi.CreateDocumentOptions{
|
||||
IsUpsert: true,
|
||||
PartitionKeyValue: pk,
|
||||
}
|
||||
}
|
||||
|
||||
func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions {
|
||||
return cosmosapi.QueryDocumentsOptions{
|
||||
PartitionKeyValue: pk,
|
||||
IsQuery: true,
|
||||
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
|
||||
}
|
||||
}
|
||||
|
||||
func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions {
|
||||
return cosmosapi.GetDocumentOptions{
|
||||
PartitionKeyValue: pk,
|
||||
}
|
||||
}
|
||||
|
||||
func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions {
|
||||
return cosmosapi.ReplaceDocumentOptions{
|
||||
PartitionKeyValue: pk,
|
||||
IfMatch: etag,
|
||||
}
|
||||
}
|
||||
|
||||
func GetDeleteDocumentOptions(pk string) cosmosapi.DeleteDocumentOptions {
|
||||
return cosmosapi.DeleteDocumentOptions{
|
||||
PartitionKeyValue: pk,
|
||||
}
|
||||
}
|
||||
20
internal/cosmosdbapi/query.go
Normal file
20
internal/cosmosdbapi/query.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package cosmosdbapi
|
||||
|
||||
import (
|
||||
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
|
||||
)
|
||||
|
||||
func GetQuery(qry string, params map[string]interface{}) cosmosapi.Query {
|
||||
qryParams := []cosmosapi.QueryParam{}
|
||||
for key, value := range params {
|
||||
qryParam := cosmosapi.QueryParam {
|
||||
Name: key,
|
||||
Value: value,
|
||||
}
|
||||
qryParams = append(qryParams, qryParam)
|
||||
}
|
||||
return cosmosapi.Query {
|
||||
Query: qry,
|
||||
Params: qryParams,
|
||||
}
|
||||
}
|
||||
14
internal/cosmosdbapi/tenant.go
Normal file
14
internal/cosmosdbapi/tenant.go
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
package cosmosdbapi
|
||||
|
||||
type Tenant struct {
|
||||
DatabaseName string
|
||||
TenantName string
|
||||
}
|
||||
|
||||
//TODO: Move into Config or the JWT
|
||||
func DefaultConfig() Tenant {
|
||||
return Tenant{
|
||||
DatabaseName: "safezone_local",
|
||||
TenantName: "criticalarc.com",
|
||||
}
|
||||
}
|
||||
|
|
@ -8,5 +8,15 @@ import (
|
|||
func GetConnectionString(d *config.DataSource) config.DataSource {
|
||||
var connString string
|
||||
connString = string(*d)
|
||||
return config.DataSource(strings.Replace(connString, "cosmosdb:", "file:", 1))
|
||||
return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1))
|
||||
}
|
||||
|
||||
func GetConnectionProperties(connectionString string) map[string]string {
|
||||
connectionItemsRaw := strings.Split(connectionString, ";")
|
||||
connectionItems := map[string]string{}
|
||||
for _, item := range connectionItemsRaw {
|
||||
itemSplit := strings.SplitN(item, "=", 2)
|
||||
connectionItems[itemSplit[0]] = itemSplit[1]
|
||||
}
|
||||
return connectionItems
|
||||
}
|
||||
|
|
@ -16,68 +16,94 @@ package cosmosdb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
-- Stores data about accounts data.
|
||||
CREATE TABLE IF NOT EXISTS account_data (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL,
|
||||
-- The room ID for this data (empty string if not specific to a room)
|
||||
room_id TEXT,
|
||||
-- The account data type
|
||||
type TEXT NOT NULL,
|
||||
-- The account data content
|
||||
content TEXT NOT NULL,
|
||||
// const accountDataSchema = `
|
||||
// -- Stores data about accounts data.
|
||||
// CREATE TABLE IF NOT EXISTS account_data (
|
||||
// -- The Matrix user ID localpart for this account
|
||||
// localpart TEXT NOT NULL,
|
||||
// -- The room ID for this data (empty string if not specific to a room)
|
||||
// room_id TEXT,
|
||||
// -- The account data type
|
||||
// type TEXT NOT NULL,
|
||||
// -- The account data content
|
||||
// content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(localpart, room_id, type)
|
||||
);
|
||||
`
|
||||
// PRIMARY KEY(localpart, room_id, type)
|
||||
// );
|
||||
// `
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
selectAccountDataStmt *sql.Stmt
|
||||
selectAccountDataByTypeStmt *sql.Stmt
|
||||
type AccountCosmosAccountData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Object AccountData `json:"_object"`
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||
type AccountData struct {
|
||||
LocalPart string `json:"local_part"`
|
||||
RoomId string `json:"room_id"`
|
||||
Type string `json:"type"`
|
||||
Content []byte `json:"content"`
|
||||
}
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *Database
|
||||
tableName string
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) prepare(db *Database) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(accountDataSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil {
|
||||
return
|
||||
}
|
||||
s.tableName = "account_data"
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) insertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
|
||||
// INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
// ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||
var result = AccountData{
|
||||
LocalPart: localpart,
|
||||
RoomId: roomID,
|
||||
Type: dataType,
|
||||
Content: content,
|
||||
}
|
||||
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||
id := ""
|
||||
if roomID == "" {
|
||||
id = fmt.Sprintf("%s_%s", result.LocalPart, result.Type)
|
||||
} else {
|
||||
id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type)
|
||||
}
|
||||
|
||||
var dbData = AccountCosmosAccountData{
|
||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Object: result,
|
||||
}
|
||||
|
||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -88,30 +114,43 @@ func (s *accountDataStatements) selectAccountData(
|
|||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
response := []AccountCosmosAccountData{}
|
||||
var selectAccountDataCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2"
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectAccountDataCosmos, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, nil, ex
|
||||
}
|
||||
|
||||
global := map[string]json.RawMessage{}
|
||||
rooms := map[string]map[string]json.RawMessage{}
|
||||
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
var dataType string
|
||||
var content []byte
|
||||
|
||||
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for i := 0; i < len(response); i++ {
|
||||
var row = response[i]
|
||||
var roomID = row.Object.RoomId
|
||||
if roomID != "" {
|
||||
if _, ok := rooms[roomID]; !ok {
|
||||
if _, ok := rooms[row.Object.RoomId]; !ok {
|
||||
rooms[roomID] = map[string]json.RawMessage{}
|
||||
}
|
||||
rooms[roomID][dataType] = content
|
||||
rooms[roomID][row.Object.Type] = row.Object.Content
|
||||
} else {
|
||||
global[dataType] = content
|
||||
global[row.Object.Type] = row.Object.Content
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -122,13 +161,39 @@ func (s *accountDataStatements) selectAccountDataByType(
|
|||
ctx context.Context, localpart, roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return
|
||||
|
||||
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
response := []AccountCosmosAccountData{}
|
||||
var selectAccountDataCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2 and c._object.room_id = @x3 and c._object.type = @x4"
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
"@x3": roomID,
|
||||
"@x4": dataType,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectAccountDataCosmos, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, ex
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
bytes = response[0].Object.Content
|
||||
|
||||
data = json.RawMessage(bytes)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,159 +16,264 @@ package cosmosdb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const accountsSchema = `
|
||||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS account_accounts (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||
password_hash TEXT,
|
||||
-- Identifies which application service this account belongs to, if any.
|
||||
appservice_id TEXT,
|
||||
-- If the account is currently active
|
||||
is_deactivated BOOLEAN DEFAULT 0
|
||||
-- TODO:
|
||||
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
`
|
||||
// const accountsSchema = `
|
||||
// -- Stores data about accounts.
|
||||
// CREATE TABLE IF NOT EXISTS account_accounts (
|
||||
// -- The Matrix user ID localpart for this account
|
||||
// localpart TEXT NOT NULL PRIMARY KEY,
|
||||
// -- When this account was first created, as a unix timestamp (ms resolution).
|
||||
// created_ts BIGINT NOT NULL,
|
||||
// -- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||
// password_hash TEXT,
|
||||
// -- Identifies which application service this account belongs to, if any.
|
||||
// appservice_id TEXT,
|
||||
// -- If the account is currently active
|
||||
// is_deactivated BOOLEAN DEFAULT 0
|
||||
// -- TODO:
|
||||
// -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
||||
// );
|
||||
// `
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
||||
type AccountExtended struct {
|
||||
IsDeactivated bool `json:"is_deactivated"`
|
||||
PasswordHash string `json:"password_hash"`
|
||||
Created int64 `json:"created_ts"`
|
||||
}
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
type AccountCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Object api.Account `json:"_object"`
|
||||
ObjectExtended AccountExtended `json:"_object_extended"`
|
||||
}
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COUNT(localpart) FROM account_accounts"
|
||||
type AccountCosmosUserCount struct {
|
||||
UserCount int64 `json:"usercount"`
|
||||
}
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
insertAccountStmt *sql.Stmt
|
||||
updatePasswordStmt *sql.Stmt
|
||||
deactivateAccountStmt *sql.Stmt
|
||||
selectAccountByLocalpartStmt *sql.Stmt
|
||||
selectPasswordHashStmt *sql.Stmt
|
||||
selectNewNumericLocalpartStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
db *Database
|
||||
tableName string
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *accountsStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(accountsSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
s.tableName = "account_accounts"
|
||||
s.serverName = server
|
||||
return
|
||||
}
|
||||
|
||||
func getAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*AccountCosmosData, error) {
|
||||
response := AccountCosmosData{}
|
||||
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
docId,
|
||||
optionsGet,
|
||||
&response)
|
||||
return &response, ex
|
||||
}
|
||||
|
||||
func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
|
||||
response := AccountCosmosData{}
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
account.Id,
|
||||
&account,
|
||||
optionsReplace)
|
||||
return &response, ex
|
||||
}
|
||||
|
||||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) insertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
||||
ctx context.Context, localpart, hash, appserviceID string,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := s.insertAccountStmt
|
||||
// stmt := s.insertAccountStmt
|
||||
|
||||
var err error
|
||||
if appserviceID == "" {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||
} else {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Account{
|
||||
var result = api.Account{
|
||||
Localpart: localpart,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
ServerName: s.serverName,
|
||||
AppServiceID: appserviceID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var extended = AccountExtended{
|
||||
IsDeactivated: false,
|
||||
PasswordHash: hash,
|
||||
Created: createdTimeMS,
|
||||
}
|
||||
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
|
||||
var dbData = AccountCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Object: result,
|
||||
ObjectExtended: extended,
|
||||
}
|
||||
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) updatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
|
||||
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
|
||||
var response, exGet = getAccount(s, ctx, config, pk, docId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
||||
response.ObjectExtended.PasswordHash = passwordHash
|
||||
|
||||
var _, exReplace = setAccount(s, ctx, config, pk, *response)
|
||||
if exReplace != nil {
|
||||
return exReplace
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) deactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
|
||||
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
|
||||
var response, exGet = getAccount(s, ctx, config, pk, docId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
||||
response.ObjectExtended.IsDeactivated = true
|
||||
|
||||
var _, exReplace = setAccount(s, ctx, config, pk, *response)
|
||||
if exReplace != nil {
|
||||
return exReplace
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
return
|
||||
|
||||
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
response := []AccountCosmosData{}
|
||||
var selectPasswordHashCosmos = "select * from c where c._cn = @x1 and c._object.Localpart = @x2 and c._object_extended.is_deactivated = false"
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectPasswordHashCosmos, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return "", ex
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
return "", errors.New(fmt.Sprintf("Localpart %s not found", localpart))
|
||||
}
|
||||
|
||||
if len(response) != 1 {
|
||||
return "", errors.New(fmt.Sprintf("Localpart %s has multiple entries", localpart))
|
||||
}
|
||||
|
||||
return response[0].ObjectExtended.PasswordHash, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
}
|
||||
return nil, err
|
||||
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
response := []AccountCosmosData{}
|
||||
var selectPasswordHashCosmos = "select * from c where c._cn = @x1 and c._object.Localpart = @x2"
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
}
|
||||
if appserviceIDPtr.Valid {
|
||||
acc.AppServiceID = appserviceIDPtr.String
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectPasswordHashCosmos, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, ex
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
acc = response[0].Object
|
||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
acc.ServerName = s.serverName
|
||||
|
||||
|
|
@ -176,12 +281,31 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
|||
}
|
||||
|
||||
func (s *accountsStatements) selectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
ctx context.Context,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
|
||||
// "SELECT COUNT(localpart) FROM account_accounts"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var response []AccountCosmosUserCount
|
||||
var selectCountCosmos = "select count(c._ts) as usercount from c where c._cn = @x1"
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
return
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectCountCosmos, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return -1, ex
|
||||
}
|
||||
|
||||
return int64(response[0].UserCount), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"time"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const openIDTokenSchema = `
|
||||
|
|
@ -21,32 +20,24 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
|
|||
token_expires_at_ms BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertTokenSQL = "" +
|
||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
|
||||
type tokenStatements struct {
|
||||
db *sql.DB
|
||||
insertTokenStmt *sql.Stmt
|
||||
selectTokenStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
type OpenIdTokenCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Object *api.OpenIDToken `json:"_object"`
|
||||
}
|
||||
|
||||
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
type tokenStatements struct {
|
||||
db *Database
|
||||
tableName string
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(openIDTokenSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
|
||||
return
|
||||
}
|
||||
s.tableName = "open_id_tokens"
|
||||
s.serverName = server
|
||||
return
|
||||
}
|
||||
|
|
@ -55,12 +46,40 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
|
|||
// Returns new token, otherwise returns error if the token already exists.
|
||||
func (s *tokenStatements) insertToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
|
||||
// "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
var result = &api.OpenIDToken{
|
||||
UserID: localpart,
|
||||
Token: token,
|
||||
ExpiresAtMS: expiresAtMS,
|
||||
}
|
||||
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||
|
||||
var dbData = OpenIdTokenCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token),
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Object: result,
|
||||
}
|
||||
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return ex
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -71,16 +90,39 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
|||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||
}
|
||||
return nil, err
|
||||
|
||||
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
response := []OpenIdTokenCosmosData{}
|
||||
var selectOpenIdTokenCosmos = "select * from c where c._cn = @x1 and c._object.Token = @x2"
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": token,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectOpenIdTokenCosmos, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, ex
|
||||
}
|
||||
|
||||
if(len(response) == 0) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var openIdToken = response[0].Object
|
||||
openIDTokenAttrs = api.OpenIDTokenAttributes{
|
||||
UserID: openIdToken.UserID,
|
||||
ExpiresAtMS: openIdToken.ExpiresAtMS,
|
||||
}
|
||||
return &openIDTokenAttrs, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,107 +16,186 @@ package cosmosdb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const profilesSchema = `
|
||||
-- Stores data about accounts profiles.
|
||||
CREATE TABLE IF NOT EXISTS account_profiles (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
-- The display name for this account
|
||||
display_name TEXT,
|
||||
-- The URL of the avatar for this account
|
||||
avatar_url TEXT
|
||||
);
|
||||
`
|
||||
// const profilesSchema = `
|
||||
// -- Stores data about accounts profiles.
|
||||
// CREATE TABLE IF NOT EXISTS account_profiles (
|
||||
// -- The Matrix user ID localpart for this account
|
||||
// localpart TEXT NOT NULL PRIMARY KEY,
|
||||
// -- The display name for this account
|
||||
// display_name TEXT,
|
||||
// -- The URL of the avatar for this account
|
||||
// avatar_url TEXT
|
||||
// );
|
||||
// `
|
||||
|
||||
const insertProfileSQL = "" +
|
||||
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectProfileByLocalpartSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
||||
|
||||
const setAvatarURLSQL = "" +
|
||||
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||
|
||||
const setDisplayNameSQL = "" +
|
||||
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||
|
||||
const selectProfilesBySearchSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
|
||||
type profilesStatements struct {
|
||||
db *sql.DB
|
||||
insertProfileStmt *sql.Stmt
|
||||
selectProfileByLocalpartStmt *sql.Stmt
|
||||
setAvatarURLStmt *sql.Stmt
|
||||
setDisplayNameStmt *sql.Stmt
|
||||
selectProfilesBySearchStmt *sql.Stmt
|
||||
type ProfileCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Object authtypes.Profile `json:"_object"`
|
||||
}
|
||||
|
||||
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
||||
type profilesStatements struct {
|
||||
db *Database
|
||||
tableName string
|
||||
}
|
||||
|
||||
func (s *profilesStatements) prepare(db *Database) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(profilesSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil {
|
||||
return
|
||||
}
|
||||
s.tableName = "account_profiles"
|
||||
return
|
||||
}
|
||||
|
||||
func getProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*ProfileCosmosData, error) {
|
||||
response := ProfileCosmosData{}
|
||||
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
docId,
|
||||
optionsGet,
|
||||
&response)
|
||||
return &response, ex
|
||||
}
|
||||
|
||||
func setProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
profile.Id,
|
||||
&profile,
|
||||
optionsReplace)
|
||||
return &profile, ex
|
||||
}
|
||||
|
||||
func (s *profilesStatements) insertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, localpart string,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
|
||||
// "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
var result = &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
}
|
||||
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
|
||||
var dbData = ProfileCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Object: *result,
|
||||
}
|
||||
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) selectProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*authtypes.Profile, error) {
|
||||
var profile authtypes.Profile
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
response := []ProfileCosmosData{}
|
||||
var selectProfileByLocalpartCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2"
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
}
|
||||
return &profile, nil
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectProfileByLocalpartCosmos, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, ex
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(response)))
|
||||
}
|
||||
|
||||
if len(response) != 1 {
|
||||
return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(response)))
|
||||
}
|
||||
|
||||
return &response[0].Object, nil
|
||||
}
|
||||
|
||||
func (s *profilesStatements) setAvatarURL(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||
ctx context.Context, localpart string, avatarURL string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||
_, err = stmt.ExecContext(ctx, avatarURL, localpart)
|
||||
|
||||
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||
|
||||
var response, exGet = getProfile(s, ctx, config, pk, docId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
||||
response.Object.AvatarURL = avatarURL
|
||||
|
||||
var _, exReplace = setProfile(s, ctx, config, pk, *response)
|
||||
if exReplace != nil {
|
||||
return exReplace
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) setDisplayName(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||
ctx context.Context, localpart string, displayName string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||
_, err = stmt.ExecContext(ctx, displayName, localpart)
|
||||
|
||||
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||
var response, exGet = getProfile(s, ctx, config, pk, docId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
||||
response.Object.DisplayName = displayName
|
||||
|
||||
var _, exReplace = setProfile(s, ctx, config, pk, *response)
|
||||
if exReplace != nil {
|
||||
return exReplace
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -124,20 +203,36 @@ func (s *profilesStatements) selectProfilesBySearch(
|
|||
ctx context.Context, searchString string, limit int,
|
||||
) ([]authtypes.Profile, error) {
|
||||
var profiles []authtypes.Profile
|
||||
// The fmt.Sprintf directive below is building a parameter for the
|
||||
// "LIKE" condition in the SQL query. %% escapes the % char, so the
|
||||
// statement in the end will look like "LIKE %searchString%".
|
||||
rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
response := []ProfileCosmosData{}
|
||||
var selectProfileByLocalpartCosmos = "select top @x3 * from c where c._cn = @x1 and contains(c._object.local_part, @x2)"
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": searchString,
|
||||
"@x3": limit,
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var profile authtypes.Profile
|
||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profiles = append(profiles, profile)
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectProfileByLocalpartCosmos, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return nil, ex
|
||||
}
|
||||
|
||||
for i := 0; i < len(response); i++ {
|
||||
var responseData = response[i]
|
||||
profiles = append(profiles, responseData.Object)
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,29 +15,27 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
// "sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Database represents an account database
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
|
||||
sqlutil.PartitionOffsetStatements
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
|
|
@ -48,55 +46,57 @@ type Database struct {
|
|||
bcryptCost int
|
||||
openIDTokenLifetimeMS int64
|
||||
|
||||
accountsMu sync.Mutex
|
||||
profilesMu sync.Mutex
|
||||
accountDatasMu sync.Mutex
|
||||
threepidsMu sync.Mutex
|
||||
databaseName string
|
||||
connection cosmosdbapi.CosmosConnection
|
||||
}
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connString := cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
connMap := cosmosdbutil.GetConnectionProperties(string(connString))
|
||||
accountEndpoint := connMap["AccountEndpoint"]
|
||||
accountKey := connMap["AccountKey"]
|
||||
conn := cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
|
||||
|
||||
d := &Database{
|
||||
serverName: serverName,
|
||||
db: db,
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
bcryptCost: bcryptCost,
|
||||
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
serverName: serverName,
|
||||
databaseName: "userapi",
|
||||
connection: conn,
|
||||
// db: db,
|
||||
// writer: sqlutil.NewExclusiveWriter(),
|
||||
// bcryptCost: bcryptCost,
|
||||
// openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
if err = d.accounts.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadIsActive(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// if err = d.accounts.execSchema(db); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// m := sqlutil.NewMigrations()
|
||||
// deltas.LoadIsActive(m)
|
||||
// if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
partitions := sqlutil.PartitionOffsetStatements{}
|
||||
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
|
||||
// partitions := sqlutil.PartitionOffsetStatements{}
|
||||
// if err = partitions.Prepare(db, d.writer, "account"); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
var err error
|
||||
if err = d.accounts.prepare(d, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.accounts.prepare(db, serverName); err != nil {
|
||||
if err = d.profiles.prepare(d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.profiles.prepare(db); err != nil {
|
||||
if err = d.accountDatas.prepare(d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.accountDatas.prepare(db); err != nil {
|
||||
if err = d.threepids.prepare(d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.threepids.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||
if err = d.openIDTokens.prepare(d, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -131,11 +131,11 @@ func (d *Database) GetProfileByLocalpart(
|
|||
func (d *Database) SetAvatarURL(
|
||||
ctx context.Context, localpart string, avatarURL string,
|
||||
) error {
|
||||
d.profilesMu.Lock()
|
||||
defer d.profilesMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL)
|
||||
})
|
||||
// d.profilesMu.Lock()
|
||||
// defer d.profilesMu.Unlock()
|
||||
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
// })
|
||||
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
|
||||
}
|
||||
|
||||
// SetDisplayName updates the display name of the profile associated with the given
|
||||
|
|
@ -143,11 +143,12 @@ func (d *Database) SetAvatarURL(
|
|||
func (d *Database) SetDisplayName(
|
||||
ctx context.Context, localpart string, displayName string,
|
||||
) error {
|
||||
d.profilesMu.Lock()
|
||||
defer d.profilesMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
|
||||
})
|
||||
// d.profilesMu.Lock()
|
||||
// defer d.profilesMu.Unlock()
|
||||
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
// return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
|
||||
// })
|
||||
return d.profiles.setDisplayName(ctx, localpart, displayName)
|
||||
}
|
||||
|
||||
// SetPassword sets the account password to the given hash.
|
||||
|
|
@ -170,22 +171,23 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
|||
// when the first txn upgrades to a write txn. We also need to lock the account creation else we can
|
||||
// race with CreateAccount
|
||||
// We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
|
||||
d.profilesMu.Lock()
|
||||
d.accountDatasMu.Lock()
|
||||
d.accountsMu.Lock()
|
||||
defer d.profilesMu.Unlock()
|
||||
defer d.accountDatasMu.Unlock()
|
||||
defer d.accountsMu.Unlock()
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
var numLocalpart int64
|
||||
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, "", "")
|
||||
return err
|
||||
})
|
||||
|
||||
// d.profilesMu.Lock()
|
||||
// d.accountDatasMu.Lock()
|
||||
// d.accountsMu.Lock()
|
||||
// defer d.profilesMu.Unlock()
|
||||
// defer d.accountDatasMu.Unlock()
|
||||
// defer d.accountsMu.Unlock()
|
||||
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
// })
|
||||
|
||||
var numLocalpart int64
|
||||
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||
acc, err = d.createAccount(ctx, localpart, "", "")
|
||||
return acc, err
|
||||
}
|
||||
|
||||
|
|
@ -196,23 +198,25 @@ func (d *Database) CreateAccount(
|
|||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
||||
) (acc *api.Account, err error) {
|
||||
// Create one account at a time else we can get 'database is locked'.
|
||||
d.profilesMu.Lock()
|
||||
d.accountDatasMu.Lock()
|
||||
d.accountsMu.Lock()
|
||||
defer d.profilesMu.Unlock()
|
||||
defer d.accountDatasMu.Unlock()
|
||||
defer d.accountsMu.Unlock()
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
// d.profilesMu.Lock()
|
||||
// d.accountDatasMu.Lock()
|
||||
// d.accountsMu.Lock()
|
||||
// defer d.profilesMu.Unlock()
|
||||
// defer d.accountDatasMu.Unlock()
|
||||
// defer d.accountsMu.Unlock()
|
||||
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
// acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
||||
// return err
|
||||
// })
|
||||
|
||||
acc, err = d.createAccount(ctx, localpart, plaintextPassword, appserviceID)
|
||||
return acc, err
|
||||
}
|
||||
|
||||
// WARNING! This function assumes that the relevant mutexes have already
|
||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
||||
) (*api.Account, error) {
|
||||
var err error
|
||||
var account *api.Account
|
||||
|
|
@ -224,13 +228,13 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
|
||||
if account, err = d.accounts.insertAccount(ctx, localpart, hash, appserviceID); err != nil {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
||||
if err = d.profiles.insertProfile(ctx, localpart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
|
||||
if err = d.accountDatas.insertAccountData(ctx, localpart, "", "m.push_rules", json.RawMessage(`{
|
||||
"global": {
|
||||
"content": [],
|
||||
"override": [],
|
||||
|
|
@ -252,11 +256,11 @@ func (d *Database) createAccount(
|
|||
func (d *Database) SaveAccountData(
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
d.accountDatasMu.Lock()
|
||||
defer d.accountDatasMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||
})
|
||||
// d.accountDatasMu.Lock()
|
||||
// defer d.accountDatasMu.Unlock()
|
||||
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
// })
|
||||
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
|
||||
}
|
||||
|
||||
// GetAccountData returns account data related to a given localpart
|
||||
|
|
@ -286,7 +290,7 @@ func (d *Database) GetAccountDataByType(
|
|||
func (d *Database) GetNewNumericLocalpart(
|
||||
ctx context.Context,
|
||||
) (int64, error) {
|
||||
return d.accounts.selectNewNumericLocalpart(ctx, nil)
|
||||
return d.accounts.selectNewNumericLocalpart(ctx)
|
||||
}
|
||||
|
||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||
|
|
@ -305,22 +309,23 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use")
|
|||
func (d *Database) SaveThreePIDAssociation(
|
||||
ctx context.Context, threepid, localpart, medium string,
|
||||
) (err error) {
|
||||
d.threepidsMu.Lock()
|
||||
defer d.threepidsMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
user, err := d.threepids.selectLocalpartForThreePID(
|
||||
ctx, txn, threepid, medium,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// d.threepidsMu.Lock()
|
||||
// defer d.threepidsMu.Unlock()
|
||||
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
// })
|
||||
|
||||
if len(user) > 0 {
|
||||
return Err3PIDInUse
|
||||
}
|
||||
user, err := d.threepids.selectLocalpartForThreePID(
|
||||
ctx, threepid, medium,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
|
||||
})
|
||||
if len(user) > 0 {
|
||||
return Err3PIDInUse
|
||||
}
|
||||
|
||||
return d.threepids.insertThreePID(ctx, threepid, medium, localpart)
|
||||
}
|
||||
|
||||
// RemoveThreePIDAssociation removes the association involving a given third-party
|
||||
|
|
@ -330,11 +335,11 @@ func (d *Database) SaveThreePIDAssociation(
|
|||
func (d *Database) RemoveThreePIDAssociation(
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (err error) {
|
||||
d.threepidsMu.Lock()
|
||||
defer d.threepidsMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.threepids.deleteThreePID(ctx, txn, threepid, medium)
|
||||
})
|
||||
// d.threepidsMu.Lock()
|
||||
// defer d.threepidsMu.Unlock()
|
||||
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
// })
|
||||
return d.threepids.deleteThreePID(ctx, threepid, medium)
|
||||
}
|
||||
|
||||
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
||||
|
|
@ -345,7 +350,7 @@ func (d *Database) RemoveThreePIDAssociation(
|
|||
func (d *Database) GetLocalpartForThreePID(
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||
return d.threepids.selectLocalpartForThreePID(ctx, threepid, medium)
|
||||
}
|
||||
|
||||
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
||||
|
|
@ -362,11 +367,11 @@ func (d *Database) GetThreePIDsForLocalpart(
|
|||
// in the database.
|
||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
if err == sql.ErrNoRows {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
response, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||
// if err == sql.ErrNoRows {
|
||||
// return true, nil
|
||||
// }
|
||||
return response == nil, err
|
||||
}
|
||||
|
||||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||
|
|
@ -395,9 +400,9 @@ func (d *Database) CreateOpenIDToken(
|
|||
token, localpart string,
|
||||
) (int64, error) {
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
||||
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
|
||||
})
|
||||
// err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
// })
|
||||
var err = d.openIDTokens.insertToken(ctx, token, localpart, expiresAtMS)
|
||||
return expiresAtMS, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,118 +16,186 @@ package cosmosdb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
)
|
||||
|
||||
const threepidSchema = `
|
||||
-- Stores data about third party identifiers
|
||||
CREATE TABLE IF NOT EXISTS account_threepid (
|
||||
-- The third party identifier
|
||||
threepid TEXT NOT NULL,
|
||||
-- The 3PID medium
|
||||
medium TEXT NOT NULL DEFAULT 'email',
|
||||
-- The localpart of the Matrix user ID associated to this 3PID
|
||||
localpart TEXT NOT NULL,
|
||||
// const threepidSchema = `
|
||||
// -- Stores data about third party identifiers
|
||||
// CREATE TABLE IF NOT EXISTS account_threepid (
|
||||
// -- The third party identifier
|
||||
// threepid TEXT NOT NULL,
|
||||
// -- The 3PID medium
|
||||
// medium TEXT NOT NULL DEFAULT 'email',
|
||||
// -- The localpart of the Matrix user ID associated to this 3PID
|
||||
// localpart TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(threepid, medium)
|
||||
);
|
||||
// PRIMARY KEY(threepid, medium)
|
||||
// );
|
||||
|
||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
|
||||
`
|
||||
|
||||
const selectLocalpartForThreePIDSQL = "" +
|
||||
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
const selectThreePIDsForLocalpartSQL = "" +
|
||||
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
||||
|
||||
const insertThreePIDSQL = "" +
|
||||
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteThreePIDSQL = "" +
|
||||
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
type threepidStatements struct {
|
||||
db *sql.DB
|
||||
selectLocalpartForThreePIDStmt *sql.Stmt
|
||||
selectThreePIDsForLocalpartStmt *sql.Stmt
|
||||
insertThreePIDStmt *sql.Stmt
|
||||
deleteThreePIDStmt *sql.Stmt
|
||||
type ThreePIDObject struct {
|
||||
Localpart string `json:"local_part"`
|
||||
ThreePID string `json:"three_pid"`
|
||||
Medium string `json:"medium"`
|
||||
}
|
||||
|
||||
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(threepidSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
type ThreePIDCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Object ThreePIDObject `json:"_object"`
|
||||
}
|
||||
|
||||
type threepidStatements struct {
|
||||
db *Database
|
||||
tableName string
|
||||
}
|
||||
|
||||
func (s *threepidStatements) prepare(db *Database) (err error) {
|
||||
s.db = db
|
||||
s.tableName = "account_threepid"
|
||||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) selectLocalpartForThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
||||
if err == sql.ErrNoRows {
|
||||
|
||||
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
response := []ThreePIDCosmosData{}
|
||||
var selectLocalPartThreePIDCosmos = "select * from c where c._cn = @x1 and c._object.three_pid = @x2 and c._object.medium = @x3"
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": threepid,
|
||||
"@x3": medium,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectLocalPartThreePIDCosmos, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return "", ex
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return
|
||||
|
||||
return response[0].Object.Localpart, nil
|
||||
}
|
||||
|
||||
func (s *threepidStatements) selectThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||
if err != nil {
|
||||
return
|
||||
|
||||
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
response := []ThreePIDCosmosData{}
|
||||
var selectThreePIDLocalPartCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2"
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
}
|
||||
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(selectThreePIDLocalPartCosmos, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
||||
if ex != nil {
|
||||
return threepids, ex
|
||||
}
|
||||
|
||||
if len(response) == 0 {
|
||||
return threepids, nil
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed")
|
||||
|
||||
threepids = []authtypes.ThreePID{}
|
||||
for rows.Next() {
|
||||
var threepid string
|
||||
var medium string
|
||||
if err = rows.Scan(&threepid, &medium); err != nil {
|
||||
return
|
||||
}
|
||||
for _, item := range response {
|
||||
threepids = append(threepids, authtypes.ThreePID{
|
||||
Address: threepid,
|
||||
Medium: medium,
|
||||
Address: item.Object.ThreePID,
|
||||
Medium: item.Object.Medium,
|
||||
})
|
||||
}
|
||||
return threepids, rows.Err()
|
||||
return threepids, nil
|
||||
}
|
||||
|
||||
func (s *threepidStatements) insertThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||
ctx context.Context, threepid, medium, localpart string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
||||
return err
|
||||
|
||||
// "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
var result = ThreePIDObject{
|
||||
Localpart: localpart,
|
||||
Medium: medium,
|
||||
ThreePID: threepid,
|
||||
}
|
||||
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
|
||||
id := fmt.Sprintf("%s_%s", threepid, medium)
|
||||
var dbData = ThreePIDCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Object: result,
|
||||
}
|
||||
|
||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) deleteThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium)
|
||||
return err
|
||||
ctx context.Context, threepid string, medium string) (err error) {
|
||||
|
||||
// "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
id := fmt.Sprintf("%s_%s", threepid, medium)
|
||||
pk := cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
|
||||
_, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
id,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue