- Add GO CosmosDB library github.com/vippsas/go-cosmosdb

- Update YAML file to use file: everywhere except for Accounts
- Use the CosmosDB conn string in the YAML
- Add cosmosdbapi package to wrap the external package
- Add Tenant.go to store the tenancy settings - to be removed when tenancy is implemented
- Update the 5 tables to use the internal CosmosDBAPI package instead of SQL
- Remove sql from storage.go and other files
This commit is contained in:
Alex Flatow 2021-05-11 09:11:33 +10:00
parent 234a89db5d
commit a5ddb710d8
17 changed files with 1053 additions and 498 deletions

15
.vscode/launch.json vendored
View file

@ -28,6 +28,19 @@
"${workspaceFolder}\\bin\\dendrite.yaml", "${workspaceFolder}\\bin\\dendrite.yaml",
"clientapi", "clientapi",
] ]
} },
{
"name": "Launch Package Monolith - CosmosDB",
"type": "go",
"request": "launch",
"mode": "debug",
"program": "${workspaceFolder}\\cmd\\dendrite-monolith-server",
"args": [
"-config",
"${workspaceFolder}\\dendrite-config-cosmosdb.yaml",
//Uncomment below to expose internal api's
// "--api",
// "true"
]}
] ]
} }

View file

@ -90,7 +90,7 @@ global:
# Naffka database options. Not required when using Kafka. # Naffka database options. Not required when using Kafka.
naffka_database: naffka_database:
connection_string: cosmosdb:naffka.db connection_string: file:naffka.db
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -122,7 +122,7 @@ app_service_api:
listen: http://localhost:7777 listen: http://localhost:7777
connect: http://localhost:7777 connect: http://localhost:7777
database: database:
connection_string: cosmosdb:appservice.db connection_string: file:appservice.db
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -202,7 +202,7 @@ federation_sender:
listen: http://localhost:7775 listen: http://localhost:7775
connect: http://localhost:7775 connect: http://localhost:7775
database: database:
connection_string: cosmosdb:federationsender.db connection_string: file:federationsender.db
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -228,7 +228,7 @@ key_server:
listen: http://localhost:7779 listen: http://localhost:7779
connect: http://localhost:7779 connect: http://localhost:7779
database: database:
connection_string: cosmosdb:keyserver.db connection_string: file:keyserver.db
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -241,7 +241,7 @@ media_api:
external_api: external_api:
listen: http://[::]:8074 listen: http://[::]:8074
database: database:
connection_string: cosmosdb:mediaapi.db connection_string: file:mediaapi.db
max_open_conns: 5 max_open_conns: 5
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -280,7 +280,7 @@ mscs:
# - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946) # - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946)
mscs: [] mscs: []
database: database:
connection_string: cosmosdb:mscs.db connection_string: file:mscs.db
max_open_conns: 5 max_open_conns: 5
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -291,7 +291,7 @@ room_server:
listen: http://localhost:7770 listen: http://localhost:7770
connect: http://localhost:7770 connect: http://localhost:7770
database: database:
connection_string: cosmosdb:roomserver.db connection_string: file:roomserver.db
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -302,7 +302,7 @@ signing_key_server:
listen: http://localhost:7780 listen: http://localhost:7780
connect: http://localhost:7780 connect: http://localhost:7780
database: database:
connection_string: cosmosdb:signingkeyserver.db connection_string: file:signingkeyserver.db
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -331,7 +331,7 @@ sync_api:
external_api: external_api:
listen: http://[::]:8073 listen: http://[::]:8073
database: database:
connection_string: cosmosdb:syncapi.db connection_string: file:syncapi.db
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
@ -354,12 +354,12 @@ user_api:
listen: http://localhost:7781 listen: http://localhost:7781
connect: http://localhost:7781 connect: http://localhost:7781
account_database: account_database:
connection_string: cosmosdb:userapi_accounts.db connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database: device_database:
connection_string: cosmosdb:userapi_devices.db connection_string: file:userapi_devices.db
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

1
go.mod
View file

@ -37,6 +37,7 @@ require (
github.com/tidwall/sjson v1.1.5 github.com/tidwall/sjson v1.1.5
github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-client-go v2.25.0+incompatible
github.com/uber/jaeger-lib v2.4.0+incompatible github.com/uber/jaeger-lib v2.4.0+incompatible
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d // indirect
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20210218094457-e77ca8019daa github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20210218094457-e77ca8019daa
go.uber.org/atomic v1.7.0 go.uber.org/atomic v1.7.0
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83

4
go.sum
View file

@ -36,6 +36,7 @@ github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@ -176,6 +177,7 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
github.com/gofrs/uuid v3.1.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
@ -987,6 +989,8 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d h1:MZRYOouO0snrQyBAf4Wljc3qqaispjzMOhFRQgWfKMo=
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d/go.mod h1:ldPlejlc7ZyiP0QQWGwL9CoZLvEjhD9yzpz0ct7+sXo=
github.com/vishvananda/netlink v1.0.0/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= github.com/vishvananda/netlink v1.0.0/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
github.com/vishvananda/netns v0.0.0-20190625233234-7109fa855b0f/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= github.com/vishvananda/netns v0.0.0-20190625233234-7109fa855b0f/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k= github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k=

View file

@ -0,0 +1,24 @@
package cosmosdbapi
import (
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
)
type CosmosConnection struct {
Url string
Key string
}
func GetCosmosConnection(accountEndpoint string, accountKey string) CosmosConnection {
return CosmosConnection{
Url: accountEndpoint,
Key: accountKey,
}
}
func GetClient(conn CosmosConnection) *cosmosapi.Client {
cfg := cosmosapi.Config{
MasterKey: conn.Key,
}
return cosmosapi.New(conn.Url, cfg, nil, nil)
}

View file

@ -0,0 +1,10 @@
package cosmosdbapi
import (
"fmt"
)
func GetCollectionName(databaseName string, tableName string) string {
return fmt.Sprintf("matrix_%s_%s", databaseName, tableName)
}

View file

@ -0,0 +1,14 @@
package cosmosdbapi
import (
"fmt"
)
func GetDocumentId(tenantName string, collectionName string, id string) string {
return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, id)
}
func GetPartitionKey(tenantName string, collectionName string) string {
return fmt.Sprintf("%s,%s", collectionName, tenantName)
}

View file

@ -0,0 +1,46 @@
package cosmosdbapi
import (
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
)
func GetCreateDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
return cosmosapi.CreateDocumentOptions{
IsUpsert: false,
PartitionKeyValue: pk,
}
}
func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
return cosmosapi.CreateDocumentOptions{
IsUpsert: true,
PartitionKeyValue: pk,
}
}
func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions {
return cosmosapi.QueryDocumentsOptions{
PartitionKeyValue: pk,
IsQuery: true,
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
}
}
func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions {
return cosmosapi.GetDocumentOptions{
PartitionKeyValue: pk,
}
}
func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions {
return cosmosapi.ReplaceDocumentOptions{
PartitionKeyValue: pk,
IfMatch: etag,
}
}
func GetDeleteDocumentOptions(pk string) cosmosapi.DeleteDocumentOptions {
return cosmosapi.DeleteDocumentOptions{
PartitionKeyValue: pk,
}
}

View file

@ -0,0 +1,20 @@
package cosmosdbapi
import (
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
)
func GetQuery(qry string, params map[string]interface{}) cosmosapi.Query {
qryParams := []cosmosapi.QueryParam{}
for key, value := range params {
qryParam := cosmosapi.QueryParam {
Name: key,
Value: value,
}
qryParams = append(qryParams, qryParam)
}
return cosmosapi.Query {
Query: qry,
Params: qryParams,
}
}

View file

@ -0,0 +1,14 @@
package cosmosdbapi
type Tenant struct {
DatabaseName string
TenantName string
}
//TODO: Move into Config or the JWT
func DefaultConfig() Tenant {
return Tenant{
DatabaseName: "safezone_local",
TenantName: "criticalarc.com",
}
}

View file

@ -8,5 +8,15 @@ import (
func GetConnectionString(d *config.DataSource) config.DataSource { func GetConnectionString(d *config.DataSource) config.DataSource {
var connString string var connString string
connString = string(*d) connString = string(*d)
return config.DataSource(strings.Replace(connString, "cosmosdb:", "file:", 1)) return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1))
}
func GetConnectionProperties(connectionString string) map[string]string {
connectionItemsRaw := strings.Split(connectionString, ";")
connectionItems := map[string]string{}
for _, item := range connectionItemsRaw {
itemSplit := strings.SplitN(item, "=", 2)
connectionItems[itemSplit[0]] = itemSplit[1]
}
return connectionItems
} }

View file

@ -16,68 +16,94 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
) )
const accountDataSchema = ` // const accountDataSchema = `
-- Stores data about accounts data. // -- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data ( // CREATE TABLE IF NOT EXISTS account_data (
-- The Matrix user ID localpart for this account // -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL, // localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room) // -- The room ID for this data (empty string if not specific to a room)
room_id TEXT, // room_id TEXT,
-- The account data type // -- The account data type
type TEXT NOT NULL, // type TEXT NOT NULL,
-- The account data content // -- The account data content
content TEXT NOT NULL, // content TEXT NOT NULL,
PRIMARY KEY(localpart, room_id, type) // PRIMARY KEY(localpart, room_id, type)
); // );
` // `
const insertAccountDataSQL = ` type AccountCosmosAccountData struct {
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) Id string `json:"id"`
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 Pk string `json:"_pk"`
` Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Object AccountData `json:"_object"`
}
const selectAccountDataSQL = "" + type AccountData struct {
"SELECT room_id, type, content FROM account_data WHERE localpart = $1" LocalPart string `json:"local_part"`
RoomId string `json:"room_id"`
const selectAccountDataByTypeSQL = "" + Type string `json:"type"`
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" Content []byte `json:"content"`
}
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB db *Database
insertAccountDataStmt *sql.Stmt tableName string
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func (s *accountDataStatements) prepare(db *Database) (err error) {
s.db = db s.db = db
_, err = db.Exec(accountDataSchema) s.tableName = "account_data"
if err != nil {
return
}
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
return
}
if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
return
}
if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil {
return
}
return return
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
// INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
// ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
var result = AccountData{
LocalPart: localpart,
RoomId: roomID,
Type: dataType,
Content: content,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
id := ""
if roomID == "" {
id = fmt.Sprintf("%s_%s", result.LocalPart, result.Type)
} else {
id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type)
}
var dbData = AccountCosmosAccountData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Timestamp: time.Now().Unix(),
Object: result,
}
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
dbData,
options)
return err return err
} }
@ -88,30 +114,43 @@ func (s *accountDataStatements) selectAccountData(
/* rooms */ map[string]map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage,
error, error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) // "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
if err != nil { var config = cosmosdbapi.DefaultConfig()
return nil, nil, err var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []AccountCosmosAccountData{}
var selectAccountDataCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectAccountDataCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, nil, ex
} }
global := map[string]json.RawMessage{} global := map[string]json.RawMessage{}
rooms := map[string]map[string]json.RawMessage{} rooms := map[string]map[string]json.RawMessage{}
for rows.Next() { for i := 0; i < len(response); i++ {
var roomID string var row = response[i]
var dataType string var roomID = row.Object.RoomId
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return nil, nil, err
}
if roomID != "" { if roomID != "" {
if _, ok := rooms[roomID]; !ok { if _, ok := rooms[row.Object.RoomId]; !ok {
rooms[roomID] = map[string]json.RawMessage{} rooms[roomID] = map[string]json.RawMessage{}
} }
rooms[roomID][dataType] = content rooms[roomID][row.Object.Type] = row.Object.Content
} else { } else {
global[dataType] = content global[row.Object.Type] = row.Object.Content
} }
} }
@ -122,13 +161,39 @@ func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
var bytes []byte var bytes []byte
stmt := s.selectAccountDataByTypeStmt
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { // "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
if err == sql.ErrNoRows { var config = cosmosdbapi.DefaultConfig()
return nil, nil var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []AccountCosmosAccountData{}
var selectAccountDataCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2 and c._object.room_id = @x3 and c._object.type = @x4"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": roomID,
"@x4": dataType,
} }
return var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectAccountDataCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, ex
} }
if len(response) == 0 {
return data, nil
}
bytes = response[0].Object.Content
data = json.RawMessage(bytes) data = json.RawMessage(bytes)
return return
} }

View file

@ -16,159 +16,264 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "errors"
"fmt"
"time" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
) )
const accountsSchema = ` // const accountsSchema = `
-- Stores data about accounts. // -- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts ( // CREATE TABLE IF NOT EXISTS account_accounts (
-- The Matrix user ID localpart for this account // -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, // localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution). // -- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL, // created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account. // -- The password hash for this account. Can be NULL if this is a passwordless account.
password_hash TEXT, // password_hash TEXT,
-- Identifies which application service this account belongs to, if any. // -- Identifies which application service this account belongs to, if any.
appservice_id TEXT, // appservice_id TEXT,
-- If the account is currently active // -- If the account is currently active
is_deactivated BOOLEAN DEFAULT 0 // is_deactivated BOOLEAN DEFAULT 0
-- TODO: // -- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? // -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
); // );
` // `
const insertAccountSQL = "" + type AccountExtended struct {
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" IsDeactivated bool `json:"is_deactivated"`
PasswordHash string `json:"password_hash"`
Created int64 `json:"created_ts"`
}
const updatePasswordSQL = "" + type AccountCosmosData struct {
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Object api.Account `json:"_object"`
ObjectExtended AccountExtended `json:"_object_extended"`
}
const deactivateAccountSQL = "" + type AccountCosmosUserCount struct {
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" UserCount int64 `json:"usercount"`
}
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *Database
insertAccountStmt *sql.Stmt tableName string
updatePasswordStmt *sql.Stmt
deactivateAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *accountsStatements) execSchema(db *sql.DB) error { func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
_, err := db.Exec(accountsSchema)
return err
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { s.tableName = "account_accounts"
return
}
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
return
}
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
func getAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*AccountCosmosData, error) {
response := AccountCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
ctx,
config.DatabaseName,
config.TenantName,
docId,
optionsGet,
&response)
return &response, ex
}
func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
response := AccountCosmosData{}
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
config.DatabaseName,
config.TenantName,
account.Id,
&account,
optionsReplace)
return &response, ex
}
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ctx context.Context, localpart, hash, appserviceID string,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt // stmt := s.insertAccountStmt
var err error var result = api.Account{
if appserviceID == "" {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
}
if err != nil {
return nil, err
}
return &api.Account{
Localpart: localpart, Localpart: localpart,
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName, ServerName: s.serverName,
AppServiceID: appserviceID, AppServiceID: appserviceID,
}, nil }
var extended = AccountExtended{
IsDeactivated: false,
PasswordHash: hash,
Created: createdTimeMS,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var dbData = AccountCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Timestamp: time.Now().Unix(),
Object: result,
ObjectExtended: extended,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
dbData,
options)
if err != nil {
return nil, err
}
return &result, nil
} }
func (s *accountsStatements) updatePassword( func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string, ctx context.Context, localpart, passwordHash string,
) (err error) { ) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var response, exGet = getAccount(s, ctx, config, pk, docId)
if exGet != nil {
return exGet
}
response.ObjectExtended.PasswordHash = passwordHash
var _, exReplace = setAccount(s, ctx, config, pk, *response)
if exReplace != nil {
return exReplace
}
return return
} }
func (s *accountsStatements) deactivateAccount( func (s *accountsStatements) deactivateAccount(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (err error) { ) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var response, exGet = getAccount(s, ctx, config, pk, docId)
if exGet != nil {
return exGet
}
response.ObjectExtended.IsDeactivated = true
var _, exReplace = setAccount(s, ctx, config, pk, *response)
if exReplace != nil {
return exReplace
}
return return
} }
func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (hash string, err error) { ) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return // "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []AccountCosmosData{}
var selectPasswordHashCosmos = "select * from c where c._cn = @x1 and c._object.Localpart = @x2 and c._object_extended.is_deactivated = false"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectPasswordHashCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return "", ex
}
if len(response) == 0 {
return "", errors.New(fmt.Sprintf("Localpart %s not found", localpart))
}
if len(response) != 1 {
return "", errors.New(fmt.Sprintf("Localpart %s has multiple entries", localpart))
}
return response[0].ObjectExtended.PasswordHash, nil
} }
func (s *accountsStatements) selectAccountByLocalpart( func (s *accountsStatements) selectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account var acc api.Account
stmt := s.selectAccountByLocalpartStmt // "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) var config = cosmosdbapi.DefaultConfig()
if err != nil { var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
if err != sql.ErrNoRows { var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
log.WithError(err).Error("Unable to retrieve user from the db") response := []AccountCosmosData{}
var selectPasswordHashCosmos = "select * from c where c._cn = @x1 and c._object.Localpart = @x2"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
} }
return nil, err var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
} var query = cosmosdbapi.GetQuery(selectPasswordHashCosmos, params)
if appserviceIDPtr.Valid { var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
acc.AppServiceID = appserviceIDPtr.String ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, ex
} }
if len(response) == 0 {
return nil, nil
}
acc = response[0].Object
acc.UserID = userutil.MakeUserID(localpart, s.serverName) acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName acc.ServerName = s.serverName
@ -176,12 +281,31 @@ func (s *accountsStatements) selectAccountByLocalpart(
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx, ctx context.Context,
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt
if txn != nil { // "SELECT COUNT(localpart) FROM account_accounts"
stmt = sqlutil.TxStmt(txn, stmt) var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var response []AccountCosmosUserCount
var selectCountCosmos = "select count(c._ts) as usercount from c where c._cn = @x1"
params := map[string]interface{}{
"@x1": dbCollectionName,
} }
err = stmt.QueryRowContext(ctx).Scan(&id) var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
return var query = cosmosdbapi.GetQuery(selectCountCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return -1, ex
}
return int64(response[0].UserCount), nil
} }

View file

@ -1,13 +1,12 @@
package cosmosdb package cosmosdb
import ( import (
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"context" "context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
) )
const openIDTokenSchema = ` const openIDTokenSchema = `
@ -21,32 +20,24 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
token_expires_at_ms BIGINT NOT NULL token_expires_at_ms BIGINT NOT NULL
); );
` `
type OpenIdTokenCosmosData struct {
const insertTokenSQL = "" + Id string `json:"id"`
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" Pk string `json:"_pk"`
Cn string `json:"_cn"`
const selectTokenSQL = "" + ETag string `json:"_etag"`
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" Timestamp int64 `json:"_ts"`
Object *api.OpenIDToken `json:"_object"`
}
type tokenStatements struct { type tokenStatements struct {
db *sql.DB db *Database
insertTokenStmt *sql.Stmt tableName string
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
_, err = db.Exec(openIDTokenSchema) s.tableName = "open_id_tokens"
if err != nil {
return err
}
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
return
}
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -55,12 +46,40 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
// Returns new token, otherwise returns error if the token already exists. // Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken( func (s *tokenStatements) insertToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx,
token, localpart string, token, localpart string,
expiresAtMS int64, expiresAtMS int64,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) // "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
var result = &api.OpenIDToken{
UserID: localpart,
Token: token,
ExpiresAtMS: expiresAtMS,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
var dbData = OpenIdTokenCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Timestamp: time.Now().Unix(),
Object: result,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
dbData,
options)
if ex != nil {
return ex
}
return return
} }
@ -71,16 +90,39 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID, // "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
&openIDTokenAttrs.ExpiresAtMS, var config = cosmosdbapi.DefaultConfig()
) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
if err != nil { var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
if err != sql.ErrNoRows { response := []OpenIdTokenCosmosData{}
log.WithError(err).Error("Unable to retrieve token from the db") var selectOpenIdTokenCosmos = "select * from c where c._cn = @x1 and c._object.Token = @x2"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": token,
} }
return nil, err var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectOpenIdTokenCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, ex
} }
if(len(response) == 0) {
return nil, nil
}
var openIdToken = response[0].Object
openIDTokenAttrs = api.OpenIDTokenAttributes{
UserID: openIdToken.UserID,
ExpiresAtMS: openIdToken.ExpiresAtMS,
}
return &openIDTokenAttrs, nil return &openIDTokenAttrs, nil
} }

View file

@ -16,107 +16,186 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "errors"
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const profilesSchema = ` // const profilesSchema = `
-- Stores data about accounts profiles. // -- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles ( // CREATE TABLE IF NOT EXISTS account_profiles (
-- The Matrix user ID localpart for this account // -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, // localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account // -- The display name for this account
display_name TEXT, // display_name TEXT,
-- The URL of the avatar for this account // -- The URL of the avatar for this account
avatar_url TEXT // avatar_url TEXT
); // );
` // `
const insertProfileSQL = "" + type ProfileCosmosData struct {
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" Id string `json:"id"`
Pk string `json:"_pk"`
const selectProfileByLocalpartSQL = "" + Cn string `json:"_cn"`
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
const setAvatarURLSQL = "" + Object authtypes.Profile `json:"_object"`
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" }
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct { type profilesStatements struct {
db *sql.DB db *Database
insertProfileStmt *sql.Stmt tableName string
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
setDisplayNameStmt *sql.Stmt
selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func (s *profilesStatements) prepare(db *Database) (err error) {
s.db = db s.db = db
_, err = db.Exec(profilesSchema) s.tableName = "account_profiles"
if err != nil {
return return
} }
if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil {
return func getProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*ProfileCosmosData, error) {
response := ProfileCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
ctx,
config.DatabaseName,
config.TenantName,
docId,
optionsGet,
&response)
return &response, ex
} }
if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil {
return func setProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
} var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag)
if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
return ctx,
} config.DatabaseName,
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { config.TenantName,
return profile.Id,
} &profile,
if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { optionsReplace)
return return &profile, ex
}
return
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) insertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, localpart string,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
// "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
var result = &authtypes.Profile{
Localpart: localpart,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var dbData = ProfileCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Timestamp: time.Now().Unix(),
Object: *result,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
dbData,
options)
return err return err
} }
func (s *profilesStatements) selectProfileByLocalpart( func (s *profilesStatements) selectProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, var config = cosmosdbapi.DefaultConfig()
) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
if err != nil { var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
return nil, err response := []ProfileCosmosData{}
var selectProfileByLocalpartCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
} }
return &profile, nil var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectProfileByLocalpartCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, ex
}
if len(response) == 0 {
return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(response)))
}
if len(response) != 1 {
return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(response)))
}
return &response[0].Object, nil
} }
func (s *profilesStatements) setAvatarURL( func (s *profilesStatements) setAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, localpart string, avatarURL string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
_, err = stmt.ExecContext(ctx, avatarURL, localpart) // "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var response, exGet = getProfile(s, ctx, config, pk, docId)
if exGet != nil {
return exGet
}
response.Object.AvatarURL = avatarURL
var _, exReplace = setProfile(s, ctx, config, pk, *response)
if exReplace != nil {
return exReplace
}
return return
} }
func (s *profilesStatements) setDisplayName( func (s *profilesStatements) setDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, localpart string, displayName string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
_, err = stmt.ExecContext(ctx, displayName, localpart) // "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var response, exGet = getProfile(s, ctx, config, pk, docId)
if exGet != nil {
return exGet
}
response.Object.DisplayName = displayName
var _, exReplace = setProfile(s, ctx, config, pk, *response)
if exReplace != nil {
return exReplace
}
return return
} }
@ -124,20 +203,36 @@ func (s *profilesStatements) selectProfilesBySearch(
ctx context.Context, searchString string, limit int, ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) { ) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile var profiles []authtypes.Profile
// The fmt.Sprintf directive below is building a parameter for the
// "LIKE" condition in the SQL query. %% escapes the % char, so the // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
// statement in the end will look like "LIKE %searchString%". var config = cosmosdbapi.DefaultConfig()
rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
if err != nil { var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
return nil, err response := []ProfileCosmosData{}
var selectProfileByLocalpartCosmos = "select top @x3 * from c where c._cn = @x1 and contains(c._object.local_part, @x2)"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": searchString,
"@x3": limit,
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
for rows.Next() { var query = cosmosdbapi.GetQuery(selectProfileByLocalpartCosmos, params)
var profile authtypes.Profile var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { ctx,
return nil, err config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return nil, ex
} }
profiles = append(profiles, profile)
for i := 0; i < len(response); i++ {
var responseData = response[i]
profiles = append(profiles, responseData.Object)
} }
return profiles, nil return profiles, nil
} }

View file

@ -15,29 +15,27 @@
package cosmosdb package cosmosdb
import ( import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"strconv" "strconv"
"sync"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
// "sync"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
// Database represents an account database // Database represents an account database
type Database struct { type Database struct {
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
accounts accountsStatements accounts accountsStatements
profiles profilesStatements profiles profilesStatements
@ -48,55 +46,57 @@ type Database struct {
bcryptCost int bcryptCost int
openIDTokenLifetimeMS int64 openIDTokenLifetimeMS int64
accountsMu sync.Mutex databaseName string
profilesMu sync.Mutex connection cosmosdbapi.CosmosConnection
accountDatasMu sync.Mutex
threepidsMu sync.Mutex
} }
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) connString := cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
db, err := sqlutil.Open(dbProperties) connMap := cosmosdbutil.GetConnectionProperties(string(connString))
if err != nil { accountEndpoint := connMap["AccountEndpoint"]
return nil, err accountKey := connMap["AccountKey"]
} conn := cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
d := &Database{ d := &Database{
serverName: serverName, serverName: serverName,
db: db, databaseName: "userapi",
writer: sqlutil.NewExclusiveWriter(), connection: conn,
bcryptCost: bcryptCost, // db: db,
openIDTokenLifetimeMS: openIDTokenLifetimeMS, // writer: sqlutil.NewExclusiveWriter(),
// bcryptCost: bcryptCost,
// openIDTokenLifetimeMS: openIDTokenLifetimeMS,
} }
// Create tables before executing migrations so we don't fail if the table is missing, // Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns // and THEN prepare statements so we don't fail due to referencing new columns
if err = d.accounts.execSchema(db); err != nil { // if err = d.accounts.execSchema(db); err != nil {
return nil, err // return nil, err
} // }
m := sqlutil.NewMigrations() // m := sqlutil.NewMigrations()
deltas.LoadIsActive(m) // deltas.LoadIsActive(m)
if err = m.RunDeltas(db, dbProperties); err != nil { // if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err // return nil, err
} // }
partitions := sqlutil.PartitionOffsetStatements{} // partitions := sqlutil.PartitionOffsetStatements{}
if err = partitions.Prepare(db, d.writer, "account"); err != nil { // if err = partitions.Prepare(db, d.writer, "account"); err != nil {
// return nil, err
// }
var err error
if err = d.accounts.prepare(d, serverName); err != nil {
return nil, err return nil, err
} }
if err = d.accounts.prepare(db, serverName); err != nil { if err = d.profiles.prepare(d); err != nil {
return nil, err return nil, err
} }
if err = d.profiles.prepare(db); err != nil { if err = d.accountDatas.prepare(d); err != nil {
return nil, err return nil, err
} }
if err = d.accountDatas.prepare(db); err != nil { if err = d.threepids.prepare(d); err != nil {
return nil, err return nil, err
} }
if err = d.threepids.prepare(db); err != nil { if err = d.openIDTokens.prepare(d, serverName); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err return nil, err
} }
@ -131,11 +131,11 @@ func (d *Database) GetProfileByLocalpart(
func (d *Database) SetAvatarURL( func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string, ctx context.Context, localpart string, avatarURL string,
) error { ) error {
d.profilesMu.Lock() // d.profilesMu.Lock()
defer d.profilesMu.Unlock() // defer d.profilesMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { // return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL) // })
}) return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
} }
// SetDisplayName updates the display name of the profile associated with the given // SetDisplayName updates the display name of the profile associated with the given
@ -143,11 +143,12 @@ func (d *Database) SetAvatarURL(
func (d *Database) SetDisplayName( func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string, ctx context.Context, localpart string, displayName string,
) error { ) error {
d.profilesMu.Lock() // d.profilesMu.Lock()
defer d.profilesMu.Unlock() // defer d.profilesMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { // return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.profiles.setDisplayName(ctx, txn, localpart, displayName) // return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
}) // })
return d.profiles.setDisplayName(ctx, localpart, displayName)
} }
// SetPassword sets the account password to the given hash. // SetPassword sets the account password to the given hash.
@ -170,22 +171,23 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
// when the first txn upgrades to a write txn. We also need to lock the account creation else we can // when the first txn upgrades to a write txn. We also need to lock the account creation else we can
// race with CreateAccount // race with CreateAccount
// We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
d.profilesMu.Lock()
d.accountDatasMu.Lock() // d.profilesMu.Lock()
d.accountsMu.Lock() // d.accountDatasMu.Lock()
defer d.profilesMu.Unlock() // d.accountsMu.Lock()
defer d.accountDatasMu.Unlock() // defer d.profilesMu.Unlock()
defer d.accountsMu.Unlock() // defer d.accountDatasMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { // defer d.accountsMu.Unlock()
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// })
var numLocalpart int64 var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx)
if err != nil { if err != nil {
return err return nil, err
} }
localpart := strconv.FormatInt(numLocalpart, 10) localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "") acc, err = d.createAccount(ctx, localpart, "", "")
return err
})
return acc, err return acc, err
} }
@ -196,23 +198,25 @@ func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *api.Account, err error) { ) (acc *api.Account, err error) {
// Create one account at a time else we can get 'database is locked'. // Create one account at a time else we can get 'database is locked'.
d.profilesMu.Lock() // d.profilesMu.Lock()
d.accountDatasMu.Lock() // d.accountDatasMu.Lock()
d.accountsMu.Lock() // d.accountsMu.Lock()
defer d.profilesMu.Unlock() // defer d.profilesMu.Unlock()
defer d.accountDatasMu.Unlock() // defer d.accountDatasMu.Unlock()
defer d.accountsMu.Unlock() // defer d.accountsMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { // err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) // acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err // return err
}) // })
return
acc, err = d.createAccount(ctx, localpart, plaintextPassword, appserviceID)
return acc, err
} }
// WARNING! This function assumes that the relevant mutexes have already // WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount( func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (*api.Account, error) { ) (*api.Account, error) {
var err error var err error
var account *api.Account var account *api.Account
@ -224,13 +228,13 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { if account, err = d.accounts.insertAccount(ctx, localpart, hash, appserviceID); err != nil {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { if err = d.profiles.insertProfile(ctx, localpart); err != nil {
return nil, err return nil, err
} }
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ if err = d.accountDatas.insertAccountData(ctx, localpart, "", "m.push_rules", json.RawMessage(`{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -252,11 +256,11 @@ func (d *Database) createAccount(
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
d.accountDatasMu.Lock() // d.accountDatasMu.Lock()
defer d.accountDatasMu.Unlock() // defer d.accountDatasMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { // return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) // })
}) return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
} }
// GetAccountData returns account data related to a given localpart // GetAccountData returns account data related to a given localpart
@ -286,7 +290,7 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart( func (d *Database) GetNewNumericLocalpart(
ctx context.Context, ctx context.Context,
) (int64, error) { ) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil) return d.accounts.selectNewNumericLocalpart(ctx)
} }
func (d *Database) hashPassword(plaintext string) (hash string, err error) { func (d *Database) hashPassword(plaintext string) (hash string, err error) {
@ -305,11 +309,13 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use")
func (d *Database) SaveThreePIDAssociation( func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string, ctx context.Context, threepid, localpart, medium string,
) (err error) { ) (err error) {
d.threepidsMu.Lock() // d.threepidsMu.Lock()
defer d.threepidsMu.Unlock() // defer d.threepidsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { // return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// })
user, err := d.threepids.selectLocalpartForThreePID( user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium, ctx, threepid, medium,
) )
if err != nil { if err != nil {
return err return err
@ -319,8 +325,7 @@ func (d *Database) SaveThreePIDAssociation(
return Err3PIDInUse return Err3PIDInUse
} }
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) return d.threepids.insertThreePID(ctx, threepid, medium, localpart)
})
} }
// RemoveThreePIDAssociation removes the association involving a given third-party // RemoveThreePIDAssociation removes the association involving a given third-party
@ -330,11 +335,11 @@ func (d *Database) SaveThreePIDAssociation(
func (d *Database) RemoveThreePIDAssociation( func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string, ctx context.Context, threepid string, medium string,
) (err error) { ) (err error) {
d.threepidsMu.Lock() // d.threepidsMu.Lock()
defer d.threepidsMu.Unlock() // defer d.threepidsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { // return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.threepids.deleteThreePID(ctx, txn, threepid, medium) // })
}) return d.threepids.deleteThreePID(ctx, threepid, medium)
} }
// GetLocalpartForThreePID looks up the localpart associated with a given third-party // GetLocalpartForThreePID looks up the localpart associated with a given third-party
@ -345,7 +350,7 @@ func (d *Database) RemoveThreePIDAssociation(
func (d *Database) GetLocalpartForThreePID( func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string, ctx context.Context, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) return d.threepids.selectLocalpartForThreePID(ctx, threepid, medium)
} }
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with // GetThreePIDsForLocalpart looks up the third-party identifiers associated with
@ -362,11 +367,11 @@ func (d *Database) GetThreePIDsForLocalpart(
// in the database. // in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken. // If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart) response, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows { // if err == sql.ErrNoRows {
return true, nil // return true, nil
} // }
return false, err return response == nil, err
} }
// GetAccountByLocalpart returns the account associated with the given localpart. // GetAccountByLocalpart returns the account associated with the given localpart.
@ -395,9 +400,9 @@ func (d *Database) CreateOpenIDToken(
token, localpart string, token, localpart string,
) (int64, error) { ) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { // err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) // })
}) var err = d.openIDTokens.insertToken(ctx, token, localpart, expiresAtMS)
return expiresAtMS, err return expiresAtMS, err
} }

View file

@ -16,118 +16,186 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "fmt"
"time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
const threepidSchema = ` // const threepidSchema = `
-- Stores data about third party identifiers // -- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid ( // CREATE TABLE IF NOT EXISTS account_threepid (
-- The third party identifier // -- The third party identifier
threepid TEXT NOT NULL, // threepid TEXT NOT NULL,
-- The 3PID medium // -- The 3PID medium
medium TEXT NOT NULL DEFAULT 'email', // medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID // -- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL, // localpart TEXT NOT NULL,
PRIMARY KEY(threepid, medium) // PRIMARY KEY(threepid, medium)
); // );
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); type ThreePIDObject struct {
` Localpart string `json:"local_part"`
ThreePID string `json:"three_pid"`
Medium string `json:"medium"`
}
const selectLocalpartForThreePIDSQL = "" + type ThreePIDCosmosData struct {
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" Id string `json:"id"`
Pk string `json:"_pk"`
const selectThreePIDsForLocalpartSQL = "" + Cn string `json:"_cn"`
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1" ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
const insertThreePIDSQL = "" + Object ThreePIDObject `json:"_object"`
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" }
const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
type threepidStatements struct { type threepidStatements struct {
db *sql.DB db *Database
selectLocalpartForThreePIDStmt *sql.Stmt tableName string
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
deleteThreePIDStmt *sql.Stmt
} }
func (s *threepidStatements) prepare(db *sql.DB) (err error) { func (s *threepidStatements) prepare(db *Database) (err error) {
s.db = db s.db = db
_, err = db.Exec(threepidSchema) s.tableName = "account_threepid"
if err != nil {
return
}
if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil {
return
}
if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil {
return
}
if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil {
return
}
if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil {
return
}
return return
} }
func (s *threepidStatements) selectLocalpartForThreePID( func (s *threepidStatements) selectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string, ctx context.Context, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) // "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
if err == sql.ErrNoRows { var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []ThreePIDCosmosData{}
var selectLocalPartThreePIDCosmos = "select * from c where c._cn = @x1 and c._object.three_pid = @x2 and c._object.medium = @x3"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": threepid,
"@x3": medium,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectLocalPartThreePIDCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return "", ex
}
if len(response) == 0 {
return "", nil return "", nil
} }
return
return response[0].Object.Localpart, nil
} }
func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) selectThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
if err != nil { // "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
return var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
response := []ThreePIDCosmosData{}
var selectThreePIDLocalPartCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2"
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectThreePIDLocalPartCosmos, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
query,
&response,
options)
if ex != nil {
return threepids, ex
}
if len(response) == 0 {
return threepids, nil
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed")
threepids = []authtypes.ThreePID{} threepids = []authtypes.ThreePID{}
for rows.Next() { for _, item := range response {
var threepid string
var medium string
if err = rows.Scan(&threepid, &medium); err != nil {
return
}
threepids = append(threepids, authtypes.ThreePID{ threepids = append(threepids, authtypes.ThreePID{
Address: threepid, Address: item.Object.ThreePID,
Medium: medium, Medium: item.Object.Medium,
}) })
} }
return threepids, rows.Err() return threepids, nil
} }
func (s *threepidStatements) insertThreePID( func (s *threepidStatements) insertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ctx context.Context, threepid, medium, localpart string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart) // "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
var result = ThreePIDObject{
Localpart: localpart,
Medium: medium,
ThreePID: threepid,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
id := fmt.Sprintf("%s_%s", threepid, medium)
var dbData = ThreePIDCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Timestamp: time.Now().Unix(),
Object: result,
}
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
dbData,
options)
if err != nil {
return err return err
} }
return
}
func (s *threepidStatements) deleteThreePID( func (s *threepidStatements) deleteThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { ctx context.Context, threepid string, medium string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium) // "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
id := fmt.Sprintf("%s_%s", threepid, medium)
pk := cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
_, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
config.DatabaseName,
config.TenantName,
id,
options)
if err != nil {
return err return err
} }
return
}