mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-27 16:53:10 -06:00
Merge pull request #3 from criticalarc/af/CA-5532-AddCosmosAccounts
Add CosmosDB backend for the Accounts DB in Dendrite
This commit is contained in:
commit
dfd5d445ac
15
.vscode/launch.json
vendored
15
.vscode/launch.json
vendored
|
|
@ -28,6 +28,19 @@
|
||||||
"${workspaceFolder}\\bin\\dendrite.yaml",
|
"${workspaceFolder}\\bin\\dendrite.yaml",
|
||||||
"clientapi",
|
"clientapi",
|
||||||
]
|
]
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"name": "Launch Package Monolith - CosmosDB",
|
||||||
|
"type": "go",
|
||||||
|
"request": "launch",
|
||||||
|
"mode": "debug",
|
||||||
|
"program": "${workspaceFolder}\\cmd\\dendrite-monolith-server",
|
||||||
|
"args": [
|
||||||
|
"-config",
|
||||||
|
"${workspaceFolder}\\dendrite-config-cosmosdb.yaml",
|
||||||
|
//Uncomment below to expose internal api's
|
||||||
|
// "--api",
|
||||||
|
// "true"
|
||||||
|
]}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
@ -90,7 +90,7 @@ global:
|
||||||
|
|
||||||
# Naffka database options. Not required when using Kafka.
|
# Naffka database options. Not required when using Kafka.
|
||||||
naffka_database:
|
naffka_database:
|
||||||
connection_string: cosmosdb:naffka.db
|
connection_string: file:naffka.db
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
@ -122,7 +122,7 @@ app_service_api:
|
||||||
listen: http://localhost:7777
|
listen: http://localhost:7777
|
||||||
connect: http://localhost:7777
|
connect: http://localhost:7777
|
||||||
database:
|
database:
|
||||||
connection_string: cosmosdb:appservice.db
|
connection_string: file:appservice.db
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
@ -202,7 +202,7 @@ federation_sender:
|
||||||
listen: http://localhost:7775
|
listen: http://localhost:7775
|
||||||
connect: http://localhost:7775
|
connect: http://localhost:7775
|
||||||
database:
|
database:
|
||||||
connection_string: cosmosdb:federationsender.db
|
connection_string: file:federationsender.db
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
@ -228,7 +228,7 @@ key_server:
|
||||||
listen: http://localhost:7779
|
listen: http://localhost:7779
|
||||||
connect: http://localhost:7779
|
connect: http://localhost:7779
|
||||||
database:
|
database:
|
||||||
connection_string: cosmosdb:keyserver.db
|
connection_string: file:keyserver.db
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
@ -241,7 +241,7 @@ media_api:
|
||||||
external_api:
|
external_api:
|
||||||
listen: http://[::]:8074
|
listen: http://[::]:8074
|
||||||
database:
|
database:
|
||||||
connection_string: cosmosdb:mediaapi.db
|
connection_string: file:mediaapi.db
|
||||||
max_open_conns: 5
|
max_open_conns: 5
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
@ -280,7 +280,7 @@ mscs:
|
||||||
# - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946)
|
# - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946)
|
||||||
mscs: []
|
mscs: []
|
||||||
database:
|
database:
|
||||||
connection_string: cosmosdb:mscs.db
|
connection_string: file:mscs.db
|
||||||
max_open_conns: 5
|
max_open_conns: 5
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
@ -291,7 +291,7 @@ room_server:
|
||||||
listen: http://localhost:7770
|
listen: http://localhost:7770
|
||||||
connect: http://localhost:7770
|
connect: http://localhost:7770
|
||||||
database:
|
database:
|
||||||
connection_string: cosmosdb:roomserver.db
|
connection_string: file:roomserver.db
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
@ -302,7 +302,7 @@ signing_key_server:
|
||||||
listen: http://localhost:7780
|
listen: http://localhost:7780
|
||||||
connect: http://localhost:7780
|
connect: http://localhost:7780
|
||||||
database:
|
database:
|
||||||
connection_string: cosmosdb:signingkeyserver.db
|
connection_string: file:signingkeyserver.db
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
@ -331,7 +331,7 @@ sync_api:
|
||||||
external_api:
|
external_api:
|
||||||
listen: http://[::]:8073
|
listen: http://[::]:8073
|
||||||
database:
|
database:
|
||||||
connection_string: cosmosdb:syncapi.db
|
connection_string: file:syncapi.db
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
@ -354,12 +354,12 @@ user_api:
|
||||||
listen: http://localhost:7781
|
listen: http://localhost:7781
|
||||||
connect: http://localhost:7781
|
connect: http://localhost:7781
|
||||||
account_database:
|
account_database:
|
||||||
connection_string: cosmosdb:userapi_accounts.db
|
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
device_database:
|
device_database:
|
||||||
connection_string: cosmosdb:userapi_devices.db
|
connection_string: file:userapi_devices.db
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
|
||||||
1
go.mod
1
go.mod
|
|
@ -37,6 +37,7 @@ require (
|
||||||
github.com/tidwall/sjson v1.1.5
|
github.com/tidwall/sjson v1.1.5
|
||||||
github.com/uber/jaeger-client-go v2.25.0+incompatible
|
github.com/uber/jaeger-client-go v2.25.0+incompatible
|
||||||
github.com/uber/jaeger-lib v2.4.0+incompatible
|
github.com/uber/jaeger-lib v2.4.0+incompatible
|
||||||
|
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d // indirect
|
||||||
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20210218094457-e77ca8019daa
|
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20210218094457-e77ca8019daa
|
||||||
go.uber.org/atomic v1.7.0
|
go.uber.org/atomic v1.7.0
|
||||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
||||||
|
|
|
||||||
4
go.sum
4
go.sum
|
|
@ -36,6 +36,7 @@ github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/
|
||||||
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
|
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
|
||||||
github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c=
|
github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c=
|
||||||
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
||||||
|
github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ=
|
||||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||||
|
|
@ -176,6 +177,7 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me
|
||||||
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||||
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||||
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||||
|
github.com/gofrs/uuid v3.1.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||||
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
|
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
|
||||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
|
|
@ -987,6 +989,8 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU
|
||||||
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
|
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
|
||||||
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
|
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
|
||||||
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
||||||
|
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d h1:MZRYOouO0snrQyBAf4Wljc3qqaispjzMOhFRQgWfKMo=
|
||||||
|
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d/go.mod h1:ldPlejlc7ZyiP0QQWGwL9CoZLvEjhD9yzpz0ct7+sXo=
|
||||||
github.com/vishvananda/netlink v1.0.0/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
github.com/vishvananda/netlink v1.0.0/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
||||||
github.com/vishvananda/netns v0.0.0-20190625233234-7109fa855b0f/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
|
github.com/vishvananda/netns v0.0.0-20190625233234-7109fa855b0f/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
|
||||||
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k=
|
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k=
|
||||||
|
|
|
||||||
24
internal/cosmosdbapi/client.go
Normal file
24
internal/cosmosdbapi/client.go
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
package cosmosdbapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CosmosConnection struct {
|
||||||
|
Url string
|
||||||
|
Key string
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCosmosConnection(accountEndpoint string, accountKey string) CosmosConnection {
|
||||||
|
return CosmosConnection{
|
||||||
|
Url: accountEndpoint,
|
||||||
|
Key: accountKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetClient(conn CosmosConnection) *cosmosapi.Client {
|
||||||
|
cfg := cosmosapi.Config{
|
||||||
|
MasterKey: conn.Key,
|
||||||
|
}
|
||||||
|
return cosmosapi.New(conn.Url, cfg, nil, nil)
|
||||||
|
}
|
||||||
10
internal/cosmosdbapi/collection.go
Normal file
10
internal/cosmosdbapi/collection.go
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
package cosmosdbapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetCollectionName(databaseName string, tableName string) string {
|
||||||
|
return fmt.Sprintf("matrix_%s_%s", databaseName, tableName)
|
||||||
|
}
|
||||||
14
internal/cosmosdbapi/document.go
Normal file
14
internal/cosmosdbapi/document.go
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
package cosmosdbapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetDocumentId(tenantName string, collectionName string, id string) string {
|
||||||
|
return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPartitionKey(tenantName string, collectionName string) string {
|
||||||
|
return fmt.Sprintf("%s,%s", collectionName, tenantName)
|
||||||
|
}
|
||||||
46
internal/cosmosdbapi/documentoperations.go
Normal file
46
internal/cosmosdbapi/documentoperations.go
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
package cosmosdbapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetCreateDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
|
||||||
|
return cosmosapi.CreateDocumentOptions{
|
||||||
|
IsUpsert: false,
|
||||||
|
PartitionKeyValue: pk,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions {
|
||||||
|
return cosmosapi.CreateDocumentOptions{
|
||||||
|
IsUpsert: true,
|
||||||
|
PartitionKeyValue: pk,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions {
|
||||||
|
return cosmosapi.QueryDocumentsOptions{
|
||||||
|
PartitionKeyValue: pk,
|
||||||
|
IsQuery: true,
|
||||||
|
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions {
|
||||||
|
return cosmosapi.GetDocumentOptions{
|
||||||
|
PartitionKeyValue: pk,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions {
|
||||||
|
return cosmosapi.ReplaceDocumentOptions{
|
||||||
|
PartitionKeyValue: pk,
|
||||||
|
IfMatch: etag,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDeleteDocumentOptions(pk string) cosmosapi.DeleteDocumentOptions {
|
||||||
|
return cosmosapi.DeleteDocumentOptions{
|
||||||
|
PartitionKeyValue: pk,
|
||||||
|
}
|
||||||
|
}
|
||||||
20
internal/cosmosdbapi/query.go
Normal file
20
internal/cosmosdbapi/query.go
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
package cosmosdbapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetQuery(qry string, params map[string]interface{}) cosmosapi.Query {
|
||||||
|
qryParams := []cosmosapi.QueryParam{}
|
||||||
|
for key, value := range params {
|
||||||
|
qryParam := cosmosapi.QueryParam {
|
||||||
|
Name: key,
|
||||||
|
Value: value,
|
||||||
|
}
|
||||||
|
qryParams = append(qryParams, qryParam)
|
||||||
|
}
|
||||||
|
return cosmosapi.Query {
|
||||||
|
Query: qry,
|
||||||
|
Params: qryParams,
|
||||||
|
}
|
||||||
|
}
|
||||||
14
internal/cosmosdbapi/tenant.go
Normal file
14
internal/cosmosdbapi/tenant.go
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
package cosmosdbapi
|
||||||
|
|
||||||
|
type Tenant struct {
|
||||||
|
DatabaseName string
|
||||||
|
TenantName string
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: Move into Config or the JWT
|
||||||
|
func DefaultConfig() Tenant {
|
||||||
|
return Tenant{
|
||||||
|
DatabaseName: "safezone_local",
|
||||||
|
TenantName: "criticalarc.com",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -8,5 +8,15 @@ import (
|
||||||
func GetConnectionString(d *config.DataSource) config.DataSource {
|
func GetConnectionString(d *config.DataSource) config.DataSource {
|
||||||
var connString string
|
var connString string
|
||||||
connString = string(*d)
|
connString = string(*d)
|
||||||
return config.DataSource(strings.Replace(connString, "cosmosdb:", "file:", 1))
|
return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetConnectionProperties(connectionString string) map[string]string {
|
||||||
|
connectionItemsRaw := strings.Split(connectionString, ";")
|
||||||
|
connectionItems := map[string]string{}
|
||||||
|
for _, item := range connectionItemsRaw {
|
||||||
|
itemSplit := strings.SplitN(item, "=", 2)
|
||||||
|
connectionItems[itemSplit[0]] = itemSplit[1]
|
||||||
|
}
|
||||||
|
return connectionItems
|
||||||
}
|
}
|
||||||
|
|
@ -16,68 +16,99 @@ package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountDataSchema = `
|
// const accountDataSchema = `
|
||||||
-- Stores data about accounts data.
|
// -- Stores data about accounts data.
|
||||||
CREATE TABLE IF NOT EXISTS account_data (
|
// CREATE TABLE IF NOT EXISTS account_data (
|
||||||
-- The Matrix user ID localpart for this account
|
// -- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL,
|
// localpart TEXT NOT NULL,
|
||||||
-- The room ID for this data (empty string if not specific to a room)
|
// -- The room ID for this data (empty string if not specific to a room)
|
||||||
room_id TEXT,
|
// room_id TEXT,
|
||||||
-- The account data type
|
// -- The account data type
|
||||||
type TEXT NOT NULL,
|
// type TEXT NOT NULL,
|
||||||
-- The account data content
|
// -- The account data content
|
||||||
content TEXT NOT NULL,
|
// content TEXT NOT NULL,
|
||||||
|
|
||||||
PRIMARY KEY(localpart, room_id, type)
|
// PRIMARY KEY(localpart, room_id, type)
|
||||||
);
|
// );
|
||||||
`
|
// `
|
||||||
|
|
||||||
const insertAccountDataSQL = `
|
type AccountDataCosmosData struct {
|
||||||
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
Id string `json:"id"`
|
||||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
Pk string `json:"_pk"`
|
||||||
`
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
const selectAccountDataSQL = "" +
|
Timestamp int64 `json:"_ts"`
|
||||||
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
AccountData AccountDataCosmos `json:"mx_userapi_accountdata"`
|
||||||
|
|
||||||
const selectAccountDataByTypeSQL = "" +
|
|
||||||
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
|
||||||
|
|
||||||
type accountDataStatements struct {
|
|
||||||
db *sql.DB
|
|
||||||
insertAccountDataStmt *sql.Stmt
|
|
||||||
selectAccountDataStmt *sql.Stmt
|
|
||||||
selectAccountDataByTypeStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
type AccountDataCosmos struct {
|
||||||
|
LocalPart string `json:"local_part"`
|
||||||
|
RoomId string `json:"room_id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Content []byte `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type accountDataStatements struct {
|
||||||
|
db *Database
|
||||||
|
// insertAccountDataStmt *sql.Stmt
|
||||||
|
selectAccountDataStmt string
|
||||||
|
selectAccountDataByTypeStmt string
|
||||||
|
tableName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountDataStatements) prepare(db *Database) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
_, err = db.Exec(accountDataSchema)
|
s.selectAccountDataStmt = "select * from c where c._cn = @x1 and c.mx_userapi_accountdata.local_part = @x2"
|
||||||
if err != nil {
|
s.selectAccountDataByTypeStmt = "select * from c where c._cn = @x1 and c.mx_userapi_accountdata.local_part = @x2 and c.mx_userapi_accountdata.room_id = @x3 and c.mx_userapi_accountdata.type = @x4"
|
||||||
return
|
s.tableName = "account_data"
|
||||||
}
|
|
||||||
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) insertAccountData(
|
func (s *accountDataStatements) insertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
|
||||||
|
// INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||||
|
// ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||||
|
var result = AccountDataCosmos{
|
||||||
|
LocalPart: localpart,
|
||||||
|
RoomId: roomID,
|
||||||
|
Type: dataType,
|
||||||
|
Content: content,
|
||||||
|
}
|
||||||
|
|
||||||
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||||
|
id := ""
|
||||||
|
if roomID == "" {
|
||||||
|
id = fmt.Sprintf("%s_%s", result.LocalPart, result.Type)
|
||||||
|
} else {
|
||||||
|
id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dbData = AccountDataCosmosData{
|
||||||
|
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
AccountData: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||||
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -88,30 +119,42 @@ func (s *accountDataStatements) selectAccountData(
|
||||||
/* rooms */ map[string]map[string]json.RawMessage,
|
/* rooms */ map[string]map[string]json.RawMessage,
|
||||||
error,
|
error,
|
||||||
) {
|
) {
|
||||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||||
if err != nil {
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
return nil, nil, err
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
|
response := []AccountDataCosmosData{}
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": localpart,
|
||||||
|
}
|
||||||
|
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return nil, nil, ex
|
||||||
}
|
}
|
||||||
|
|
||||||
global := map[string]json.RawMessage{}
|
global := map[string]json.RawMessage{}
|
||||||
rooms := map[string]map[string]json.RawMessage{}
|
rooms := map[string]map[string]json.RawMessage{}
|
||||||
|
|
||||||
for rows.Next() {
|
for i := 0; i < len(response); i++ {
|
||||||
var roomID string
|
var row = response[i]
|
||||||
var dataType string
|
var roomID = row.AccountData.RoomId
|
||||||
var content []byte
|
|
||||||
|
|
||||||
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if roomID != "" {
|
if roomID != "" {
|
||||||
if _, ok := rooms[roomID]; !ok {
|
if _, ok := rooms[row.AccountData.RoomId]; !ok {
|
||||||
rooms[roomID] = map[string]json.RawMessage{}
|
rooms[roomID] = map[string]json.RawMessage{}
|
||||||
}
|
}
|
||||||
rooms[roomID][dataType] = content
|
rooms[roomID][row.AccountData.Type] = row.AccountData.Content
|
||||||
} else {
|
} else {
|
||||||
global[dataType] = content
|
global[row.AccountData.Type] = row.AccountData.Content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -122,13 +165,38 @@ func (s *accountDataStatements) selectAccountDataByType(
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
) (data json.RawMessage, err error) {
|
) (data json.RawMessage, err error) {
|
||||||
var bytes []byte
|
var bytes []byte
|
||||||
stmt := s.selectAccountDataByTypeStmt
|
|
||||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||||
if err == sql.ErrNoRows {
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
return nil, nil
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||||
}
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
return
|
response := []AccountDataCosmosData{}
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": localpart,
|
||||||
|
"@x3": roomID,
|
||||||
|
"@x4": dataType,
|
||||||
}
|
}
|
||||||
|
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return nil, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response) == 0 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes = response[0].AccountData.Content
|
||||||
|
|
||||||
data = json.RawMessage(bytes)
|
data = json.RawMessage(bytes)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,159 +16,288 @@ package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountsSchema = `
|
// const accountsSchema = `
|
||||||
-- Stores data about accounts.
|
// -- Stores data about accounts.
|
||||||
CREATE TABLE IF NOT EXISTS account_accounts (
|
// CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- The Matrix user ID localpart for this account
|
// -- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
// localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
// -- When this account was first created, as a unix timestamp (ms resolution).
|
||||||
created_ts BIGINT NOT NULL,
|
// created_ts BIGINT NOT NULL,
|
||||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
// -- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||||
password_hash TEXT,
|
// password_hash TEXT,
|
||||||
-- Identifies which application service this account belongs to, if any.
|
// -- Identifies which application service this account belongs to, if any.
|
||||||
appservice_id TEXT,
|
// appservice_id TEXT,
|
||||||
-- If the account is currently active
|
// -- If the account is currently active
|
||||||
is_deactivated BOOLEAN DEFAULT 0
|
// is_deactivated BOOLEAN DEFAULT 0
|
||||||
-- TODO:
|
// -- TODO:
|
||||||
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
// -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
// );
|
||||||
`
|
// `
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
type AccountCosmos struct {
|
||||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
UserID string `json:"user_id"`
|
||||||
|
Localpart string `json:"local_part"`
|
||||||
|
ServerName gomatrixserverlib.ServerName `json:"server_name"`
|
||||||
|
AppServiceID string `json:"app_service_id"`
|
||||||
|
IsDeactivated bool `json:"is_deactivated"`
|
||||||
|
PasswordHash string `json:"password_hash"`
|
||||||
|
Created int64 `json:"created_ts"`
|
||||||
|
}
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
type AccountCosmosData struct {
|
||||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
Account AccountCosmos `json:"mx_userapi_account"`
|
||||||
|
}
|
||||||
|
|
||||||
const deactivateAccountSQL = "" +
|
type AccountCosmosUserCount struct {
|
||||||
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
UserCount int64 `json:"usercount"`
|
||||||
|
}
|
||||||
const selectAccountByLocalpartSQL = "" +
|
|
||||||
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
|
||||||
|
|
||||||
const selectPasswordHashSQL = "" +
|
|
||||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
|
||||||
"SELECT COUNT(localpart) FROM account_accounts"
|
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *Database
|
||||||
insertAccountStmt *sql.Stmt
|
selectAccountByLocalpartStmt string
|
||||||
updatePasswordStmt *sql.Stmt
|
selectPasswordHashStmt string
|
||||||
deactivateAccountStmt *sql.Stmt
|
selectNewNumericLocalpartStmt string
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
tableName string
|
||||||
selectPasswordHashStmt *sql.Stmt
|
|
||||||
selectNewNumericLocalpartStmt *sql.Stmt
|
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) execSchema(db *sql.DB) error {
|
func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
|
||||||
_, err := db.Exec(accountsSchema)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
|
||||||
s.db = db
|
s.db = db
|
||||||
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
|
s.selectPasswordHashStmt = "select * from c where c._cn = @x1 and c.mx_userapi_account.local_part = @x2 and c.mx_userapi_account.is_deactivated = false"
|
||||||
return
|
s.selectAccountByLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_account.local_part = @x2"
|
||||||
}
|
s.selectNewNumericLocalpartStmt = "select count(c._ts) as usercount from c where c._cn = @x1"
|
||||||
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
|
s.tableName = "account_accounts"
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.serverName = server
|
s.serverName = server
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*AccountCosmosData, error) {
|
||||||
|
response := AccountCosmosData{}
|
||||||
|
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
docId,
|
||||||
|
optionsGet,
|
||||||
|
&response)
|
||||||
|
return &response, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
|
||||||
|
response := AccountCosmosData{}
|
||||||
|
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag)
|
||||||
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
account.Id,
|
||||||
|
&account,
|
||||||
|
optionsReplace)
|
||||||
|
return &response, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapFromAccount(db AccountCosmos) api.Account {
|
||||||
|
return api.Account{
|
||||||
|
AppServiceID: db.AppServiceID,
|
||||||
|
Localpart: db.Localpart,
|
||||||
|
ServerName: db.ServerName,
|
||||||
|
UserID: db.UserID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapToAccount(api api.Account) AccountCosmos {
|
||||||
|
return AccountCosmos{
|
||||||
|
AppServiceID: api.AppServiceID,
|
||||||
|
Localpart: api.Localpart,
|
||||||
|
ServerName: api.ServerName,
|
||||||
|
UserID: api.UserID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) insertAccount(
|
func (s *accountsStatements) insertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
ctx context.Context, localpart, hash, appserviceID string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := s.insertAccountStmt
|
// stmt := s.insertAccountStmt
|
||||||
|
|
||||||
var err error
|
var result = api.Account{
|
||||||
if appserviceID == "" {
|
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
|
||||||
} else {
|
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &api.Account{
|
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||||
ServerName: s.serverName,
|
ServerName: s.serverName,
|
||||||
AppServiceID: appserviceID,
|
AppServiceID: appserviceID,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
//Add the extra properties not on the API
|
||||||
|
var data = mapToAccount(result)
|
||||||
|
data.Created = createdTimeMS
|
||||||
|
data.PasswordHash = hash
|
||||||
|
data.IsDeactivated = false
|
||||||
|
|
||||||
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
|
|
||||||
|
var dbData = AccountCosmosData{
|
||||||
|
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
Account: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) updatePassword(
|
func (s *accountsStatements) updatePassword(
|
||||||
ctx context.Context, localpart, passwordHash string,
|
ctx context.Context, localpart, passwordHash string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
|
||||||
|
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||||
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
|
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
|
|
||||||
|
var response, exGet = getAccount(s, ctx, config, pk, docId)
|
||||||
|
if exGet != nil {
|
||||||
|
return exGet
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Account.PasswordHash = passwordHash
|
||||||
|
|
||||||
|
var _, exReplace = setAccount(s, ctx, config, pk, *response)
|
||||||
|
if exReplace != nil {
|
||||||
|
return exReplace
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) deactivateAccount(
|
func (s *accountsStatements) deactivateAccount(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
|
||||||
|
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||||
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
|
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
|
|
||||||
|
var response, exGet = getAccount(s, ctx, config, pk, docId)
|
||||||
|
if exGet != nil {
|
||||||
|
return exGet
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Account.IsDeactivated = true
|
||||||
|
|
||||||
|
var _, exReplace = setAccount(s, ctx, config, pk, *response)
|
||||||
|
if exReplace != nil {
|
||||||
|
return exReplace
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectPasswordHash(
|
func (s *accountsStatements) selectPasswordHash(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (hash string, err error) {
|
) (hash string, err error) {
|
||||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
|
||||||
return
|
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||||
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
|
response := []AccountCosmosData{}
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": localpart,
|
||||||
|
}
|
||||||
|
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return "", ex
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response) == 0 {
|
||||||
|
return "", errors.New(fmt.Sprintf("Localpart %s not found", localpart))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response) != 1 {
|
||||||
|
return "", errors.New(fmt.Sprintf("Localpart %s has multiple entries", localpart))
|
||||||
|
}
|
||||||
|
|
||||||
|
return response[0].Account.PasswordHash, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectAccountByLocalpart(
|
func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var appserviceIDPtr sql.NullString
|
|
||||||
var acc api.Account
|
var acc api.Account
|
||||||
|
|
||||||
stmt := s.selectAccountByLocalpartStmt
|
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
||||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
if err != nil {
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
if err != sql.ErrNoRows {
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
response := []AccountCosmosData{}
|
||||||
}
|
params := map[string]interface{}{
|
||||||
return nil, err
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": localpart,
|
||||||
}
|
}
|
||||||
if appserviceIDPtr.Valid {
|
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
acc.AppServiceID = appserviceIDPtr.String
|
var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return nil, ex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(response) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
acc = mapFromAccount(response[0].Account)
|
||||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
acc.ServerName = s.serverName
|
acc.ServerName = s.serverName
|
||||||
|
|
||||||
|
|
@ -176,12 +305,30 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectNewNumericLocalpart(
|
func (s *accountsStatements) selectNewNumericLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context,
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
stmt := s.selectNewNumericLocalpartStmt
|
|
||||||
if txn != nil {
|
// "SELECT COUNT(localpart) FROM account_accounts"
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
|
var response []AccountCosmosUserCount
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
return
|
var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return -1, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
return int64(response[0].UserCount), nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,51 +2,70 @@ package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const openIDTokenSchema = `
|
// const openIDTokenSchema = `
|
||||||
-- Stores data about accounts.
|
// -- Stores data about accounts.
|
||||||
CREATE TABLE IF NOT EXISTS open_id_tokens (
|
// CREATE TABLE IF NOT EXISTS open_id_tokens (
|
||||||
-- The value of the token issued to a user
|
// -- The value of the token issued to a user
|
||||||
token TEXT NOT NULL PRIMARY KEY,
|
// token TEXT NOT NULL PRIMARY KEY,
|
||||||
-- The Matrix user ID for this account
|
// -- The Matrix user ID for this account
|
||||||
localpart TEXT NOT NULL,
|
// localpart TEXT NOT NULL,
|
||||||
-- When the token expires, as a unix timestamp (ms resolution).
|
// -- When the token expires, as a unix timestamp (ms resolution).
|
||||||
token_expires_at_ms BIGINT NOT NULL
|
// token_expires_at_ms BIGINT NOT NULL
|
||||||
);
|
// );
|
||||||
`
|
// `
|
||||||
|
|
||||||
const insertTokenSQL = "" +
|
// OpenIDToken represents an OpenID token
|
||||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
type OpenIDTokenCosmos struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
ExpiresAtMS int64 `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
const selectTokenSQL = "" +
|
type OpenIdTokenCosmosData struct {
|
||||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
OpenIdToken OpenIDTokenCosmos `json:"mx_userapi_openidtoken"`
|
||||||
|
}
|
||||||
|
|
||||||
type tokenStatements struct {
|
type tokenStatements struct {
|
||||||
db *sql.DB
|
db *Database
|
||||||
insertTokenStmt *sql.Stmt
|
// insertTokenStmt *sql.Stmt
|
||||||
selectTokenStmt *sql.Stmt
|
selectTokenStmt string
|
||||||
|
tableName string
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func mapFromToken(db OpenIDTokenCosmos) api.OpenIDToken {
|
||||||
|
return api.OpenIDToken{
|
||||||
|
ExpiresAtMS: db.ExpiresAtMS,
|
||||||
|
Token: db.Token,
|
||||||
|
UserID: db.UserID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapToToken(api api.OpenIDToken) OpenIDTokenCosmos {
|
||||||
|
return OpenIDTokenCosmos{
|
||||||
|
ExpiresAtMS: api.ExpiresAtMS,
|
||||||
|
Token: api.Token,
|
||||||
|
UserID: api.UserID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
_, err = db.Exec(openIDTokenSchema)
|
s.selectTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_openidtoken.token = @x2"
|
||||||
if err != nil {
|
s.tableName = "open_id_tokens"
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.serverName = server
|
s.serverName = server
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -55,12 +74,40 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
|
||||||
// Returns new token, otherwise returns error if the token already exists.
|
// Returns new token, otherwise returns error if the token already exists.
|
||||||
func (s *tokenStatements) insertToken(
|
func (s *tokenStatements) insertToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx,
|
|
||||||
token, localpart string,
|
token, localpart string,
|
||||||
expiresAtMS int64,
|
expiresAtMS int64,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
|
||||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
// "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||||
|
var result = &api.OpenIDToken{
|
||||||
|
UserID: localpart,
|
||||||
|
Token: token,
|
||||||
|
ExpiresAtMS: expiresAtMS,
|
||||||
|
}
|
||||||
|
|
||||||
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||||
|
|
||||||
|
var dbData = OpenIdTokenCosmosData{
|
||||||
|
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token),
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
OpenIdToken: mapToToken(*result),
|
||||||
|
}
|
||||||
|
|
||||||
|
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return ex
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -71,16 +118,38 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
||||||
token string,
|
token string,
|
||||||
) (*api.OpenIDTokenAttributes, error) {
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
|
||||||
&openIDTokenAttrs.UserID,
|
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||||
&openIDTokenAttrs.ExpiresAtMS,
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||||
if err != nil {
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
if err != sql.ErrNoRows {
|
response := []OpenIdTokenCosmosData{}
|
||||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
params := map[string]interface{}{
|
||||||
}
|
"@x1": dbCollectionName,
|
||||||
return nil, err
|
"@x2": token,
|
||||||
|
}
|
||||||
|
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return nil, ex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(response) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var openIdToken = response[0].OpenIdToken
|
||||||
|
openIDTokenAttrs = api.OpenIDTokenAttributes{
|
||||||
|
UserID: openIdToken.UserID,
|
||||||
|
ExpiresAtMS: openIdToken.ExpiresAtMS,
|
||||||
|
}
|
||||||
return &openIDTokenAttrs, nil
|
return &openIDTokenAttrs, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,107 +16,216 @@ package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
// const profilesSchema = `
|
||||||
-- Stores data about accounts profiles.
|
// -- Stores data about accounts profiles.
|
||||||
CREATE TABLE IF NOT EXISTS account_profiles (
|
// CREATE TABLE IF NOT EXISTS account_profiles (
|
||||||
-- The Matrix user ID localpart for this account
|
// -- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
// localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
-- The display name for this account
|
// -- The display name for this account
|
||||||
display_name TEXT,
|
// display_name TEXT,
|
||||||
-- The URL of the avatar for this account
|
// -- The URL of the avatar for this account
|
||||||
avatar_url TEXT
|
// avatar_url TEXT
|
||||||
);
|
// );
|
||||||
`
|
// `
|
||||||
|
|
||||||
const insertProfileSQL = "" +
|
// Profile represents the profile for a Matrix account.
|
||||||
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
type ProfileCosmos struct {
|
||||||
|
Localpart string `json:"local_part"`
|
||||||
const selectProfileByLocalpartSQL = "" +
|
DisplayName string `json:"display_name"`
|
||||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
AvatarURL string `json:"avatar_url"`
|
||||||
|
|
||||||
const setAvatarURLSQL = "" +
|
|
||||||
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
|
||||||
|
|
||||||
const setDisplayNameSQL = "" +
|
|
||||||
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
|
||||||
|
|
||||||
const selectProfilesBySearchSQL = "" +
|
|
||||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
|
||||||
|
|
||||||
type profilesStatements struct {
|
|
||||||
db *sql.DB
|
|
||||||
insertProfileStmt *sql.Stmt
|
|
||||||
selectProfileByLocalpartStmt *sql.Stmt
|
|
||||||
setAvatarURLStmt *sql.Stmt
|
|
||||||
setDisplayNameStmt *sql.Stmt
|
|
||||||
selectProfilesBySearchStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
type ProfileCosmosData struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
Profile ProfileCosmos `json:"mx_userapi_profile"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type profilesStatements struct {
|
||||||
|
db *Database
|
||||||
|
// insertProfileStmt *sql.Stmt
|
||||||
|
selectProfileByLocalpartStmt string
|
||||||
|
// setAvatarURLStmt *sql.Stmt
|
||||||
|
// setDisplayNameStmt *sql.Stmt
|
||||||
|
selectProfilesBySearchStmt string
|
||||||
|
tableName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapFromProfile(db ProfileCosmos) authtypes.Profile {
|
||||||
|
return authtypes.Profile{
|
||||||
|
AvatarURL: db.AvatarURL,
|
||||||
|
DisplayName: db.DisplayName,
|
||||||
|
Localpart: db.Localpart,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapToProfile(api authtypes.Profile) ProfileCosmos {
|
||||||
|
return ProfileCosmos{
|
||||||
|
AvatarURL: api.AvatarURL,
|
||||||
|
DisplayName: api.DisplayName,
|
||||||
|
Localpart: api.Localpart,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *profilesStatements) prepare(db *Database) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
_, err = db.Exec(profilesSchema)
|
s.selectProfileByLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_profile.local_part = @x2"
|
||||||
if err != nil {
|
s.selectProfilesBySearchStmt = "select top @x3 * from c where c._cn = @x1 and contains(c.mx_userapi_profile.local_part, @x2)"
|
||||||
return
|
s.tableName = "account_profiles"
|
||||||
}
|
|
||||||
if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*ProfileCosmosData, error) {
|
||||||
|
response := ProfileCosmosData{}
|
||||||
|
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
docId,
|
||||||
|
optionsGet,
|
||||||
|
&response)
|
||||||
|
return &response, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
func setProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
|
||||||
|
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag)
|
||||||
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
profile.Id,
|
||||||
|
&profile,
|
||||||
|
optionsReplace)
|
||||||
|
return &profile, ex
|
||||||
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) insertProfile(
|
func (s *profilesStatements) insertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
|
||||||
|
// "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||||
|
var result = &authtypes.Profile{
|
||||||
|
Localpart: localpart,
|
||||||
|
}
|
||||||
|
|
||||||
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||||
|
|
||||||
|
var dbData = ProfileCosmosData{
|
||||||
|
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
Profile: mapToProfile(*result),
|
||||||
|
}
|
||||||
|
|
||||||
|
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) selectProfileByLocalpart(
|
func (s *profilesStatements) selectProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
var profile authtypes.Profile
|
|
||||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
||||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||||
if err != nil {
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
return nil, err
|
response := []ProfileCosmosData{}
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": localpart,
|
||||||
}
|
}
|
||||||
return &profile, nil
|
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return nil, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response) == 0 {
|
||||||
|
return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(response)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response) != 1 {
|
||||||
|
return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(response)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result = mapFromProfile(response[0].Profile)
|
||||||
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) setAvatarURL(
|
func (s *profilesStatements) setAvatarURL(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
ctx context.Context, localpart string, avatarURL string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
|
||||||
_, err = stmt.ExecContext(ctx, avatarURL, localpart)
|
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||||
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
|
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||||
|
|
||||||
|
var response, exGet = getProfile(s, ctx, config, pk, docId)
|
||||||
|
if exGet != nil {
|
||||||
|
return exGet
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Profile.AvatarURL = avatarURL
|
||||||
|
|
||||||
|
var _, exReplace = setProfile(s, ctx, config, pk, *response)
|
||||||
|
if exReplace != nil {
|
||||||
|
return exReplace
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) setDisplayName(
|
func (s *profilesStatements) setDisplayName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
ctx context.Context, localpart string, displayName string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
|
||||||
_, err = stmt.ExecContext(ctx, displayName, localpart)
|
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||||
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
|
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||||
|
var response, exGet = getProfile(s, ctx, config, pk, docId)
|
||||||
|
if exGet != nil {
|
||||||
|
return exGet
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Profile.DisplayName = displayName
|
||||||
|
|
||||||
|
var _, exReplace = setProfile(s, ctx, config, pk, *response)
|
||||||
|
if exReplace != nil {
|
||||||
|
return exReplace
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -124,20 +233,35 @@ func (s *profilesStatements) selectProfilesBySearch(
|
||||||
ctx context.Context, searchString string, limit int,
|
ctx context.Context, searchString string, limit int,
|
||||||
) ([]authtypes.Profile, error) {
|
) ([]authtypes.Profile, error) {
|
||||||
var profiles []authtypes.Profile
|
var profiles []authtypes.Profile
|
||||||
// The fmt.Sprintf directive below is building a parameter for the
|
|
||||||
// "LIKE" condition in the SQL query. %% escapes the % char, so the
|
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||||
// statement in the end will look like "LIKE %searchString%".
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||||
if err != nil {
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
return nil, err
|
response := []ProfileCosmosData{}
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": searchString,
|
||||||
|
"@x3": limit,
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
for rows.Next() {
|
var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params)
|
||||||
var profile authtypes.Profile
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
ctx,
|
||||||
return nil, err
|
config.DatabaseName,
|
||||||
}
|
config.TenantName,
|
||||||
profiles = append(profiles, profile)
|
query,
|
||||||
|
&response,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return nil, ex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(response); i++ {
|
||||||
|
var responseData = response[i]
|
||||||
|
profiles = append(profiles, mapFromProfile(responseData.Profile))
|
||||||
|
}
|
||||||
|
|
||||||
return profiles, nil
|
return profiles, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,29 +15,27 @@
|
||||||
package cosmosdb
|
package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
|
|
||||||
|
// "sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database represents an account database
|
// Database represents an account database
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db *sql.DB
|
|
||||||
writer sqlutil.Writer
|
|
||||||
|
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
accounts accountsStatements
|
accounts accountsStatements
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
|
|
@ -48,55 +46,57 @@ type Database struct {
|
||||||
bcryptCost int
|
bcryptCost int
|
||||||
openIDTokenLifetimeMS int64
|
openIDTokenLifetimeMS int64
|
||||||
|
|
||||||
accountsMu sync.Mutex
|
databaseName string
|
||||||
profilesMu sync.Mutex
|
connection cosmosdbapi.CosmosConnection
|
||||||
accountDatasMu sync.Mutex
|
|
||||||
threepidsMu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
||||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
connString := cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||||
db, err := sqlutil.Open(dbProperties)
|
connMap := cosmosdbutil.GetConnectionProperties(string(connString))
|
||||||
if err != nil {
|
accountEndpoint := connMap["AccountEndpoint"]
|
||||||
return nil, err
|
accountKey := connMap["AccountKey"]
|
||||||
}
|
conn := cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
|
||||||
|
|
||||||
d := &Database{
|
d := &Database{
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
db: db,
|
databaseName: "userapi",
|
||||||
writer: sqlutil.NewExclusiveWriter(),
|
connection: conn,
|
||||||
bcryptCost: bcryptCost,
|
// db: db,
|
||||||
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
// writer: sqlutil.NewExclusiveWriter(),
|
||||||
|
// bcryptCost: bcryptCost,
|
||||||
|
// openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||||
if err = d.accounts.execSchema(db); err != nil {
|
// if err = d.accounts.execSchema(db); err != nil {
|
||||||
return nil, err
|
// return nil, err
|
||||||
}
|
// }
|
||||||
m := sqlutil.NewMigrations()
|
// m := sqlutil.NewMigrations()
|
||||||
deltas.LoadIsActive(m)
|
// deltas.LoadIsActive(m)
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
// if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
return nil, err
|
// return nil, err
|
||||||
}
|
// }
|
||||||
|
|
||||||
partitions := sqlutil.PartitionOffsetStatements{}
|
// partitions := sqlutil.PartitionOffsetStatements{}
|
||||||
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
|
// if err = partitions.Prepare(db, d.writer, "account"); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
var err error
|
||||||
|
if err = d.accounts.prepare(d, serverName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.accounts.prepare(db, serverName); err != nil {
|
if err = d.profiles.prepare(d); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.profiles.prepare(db); err != nil {
|
if err = d.accountDatas.prepare(d); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.accountDatas.prepare(db); err != nil {
|
if err = d.threepids.prepare(d); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.threepids.prepare(db); err != nil {
|
if err = d.openIDTokens.prepare(d, serverName); err != nil {
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -131,11 +131,11 @@ func (d *Database) GetProfileByLocalpart(
|
||||||
func (d *Database) SetAvatarURL(
|
func (d *Database) SetAvatarURL(
|
||||||
ctx context.Context, localpart string, avatarURL string,
|
ctx context.Context, localpart string, avatarURL string,
|
||||||
) error {
|
) error {
|
||||||
d.profilesMu.Lock()
|
// d.profilesMu.Lock()
|
||||||
defer d.profilesMu.Unlock()
|
// defer d.profilesMu.Unlock()
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL)
|
// })
|
||||||
})
|
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetDisplayName updates the display name of the profile associated with the given
|
// SetDisplayName updates the display name of the profile associated with the given
|
||||||
|
|
@ -143,11 +143,12 @@ func (d *Database) SetAvatarURL(
|
||||||
func (d *Database) SetDisplayName(
|
func (d *Database) SetDisplayName(
|
||||||
ctx context.Context, localpart string, displayName string,
|
ctx context.Context, localpart string, displayName string,
|
||||||
) error {
|
) error {
|
||||||
d.profilesMu.Lock()
|
// d.profilesMu.Lock()
|
||||||
defer d.profilesMu.Unlock()
|
// defer d.profilesMu.Unlock()
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
|
// return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
|
||||||
})
|
// })
|
||||||
|
return d.profiles.setDisplayName(ctx, localpart, displayName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPassword sets the account password to the given hash.
|
// SetPassword sets the account password to the given hash.
|
||||||
|
|
@ -170,22 +171,23 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
||||||
// when the first txn upgrades to a write txn. We also need to lock the account creation else we can
|
// when the first txn upgrades to a write txn. We also need to lock the account creation else we can
|
||||||
// race with CreateAccount
|
// race with CreateAccount
|
||||||
// We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
|
// We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
|
||||||
d.profilesMu.Lock()
|
|
||||||
d.accountDatasMu.Lock()
|
// d.profilesMu.Lock()
|
||||||
d.accountsMu.Lock()
|
// d.accountDatasMu.Lock()
|
||||||
defer d.profilesMu.Unlock()
|
// d.accountsMu.Lock()
|
||||||
defer d.accountDatasMu.Unlock()
|
// defer d.profilesMu.Unlock()
|
||||||
defer d.accountsMu.Unlock()
|
// defer d.accountDatasMu.Unlock()
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
// defer d.accountsMu.Unlock()
|
||||||
var numLocalpart int64
|
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
|
// })
|
||||||
if err != nil {
|
|
||||||
return err
|
var numLocalpart int64
|
||||||
}
|
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx)
|
||||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
if err != nil {
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, "", "")
|
return nil, err
|
||||||
return err
|
}
|
||||||
})
|
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||||
|
acc, err = d.createAccount(ctx, localpart, "", "")
|
||||||
return acc, err
|
return acc, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -196,23 +198,25 @@ func (d *Database) CreateAccount(
|
||||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
||||||
) (acc *api.Account, err error) {
|
) (acc *api.Account, err error) {
|
||||||
// Create one account at a time else we can get 'database is locked'.
|
// Create one account at a time else we can get 'database is locked'.
|
||||||
d.profilesMu.Lock()
|
// d.profilesMu.Lock()
|
||||||
d.accountDatasMu.Lock()
|
// d.accountDatasMu.Lock()
|
||||||
d.accountsMu.Lock()
|
// d.accountsMu.Lock()
|
||||||
defer d.profilesMu.Unlock()
|
// defer d.profilesMu.Unlock()
|
||||||
defer d.accountDatasMu.Unlock()
|
// defer d.accountDatasMu.Unlock()
|
||||||
defer d.accountsMu.Unlock()
|
// defer d.accountsMu.Unlock()
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
// acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
||||||
return err
|
// return err
|
||||||
})
|
// })
|
||||||
return
|
|
||||||
|
acc, err = d.createAccount(ctx, localpart, plaintextPassword, appserviceID)
|
||||||
|
return acc, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// WARNING! This function assumes that the relevant mutexes have already
|
// WARNING! This function assumes that the relevant mutexes have already
|
||||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||||
func (d *Database) createAccount(
|
func (d *Database) createAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
|
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var err error
|
var err error
|
||||||
var account *api.Account
|
var account *api.Account
|
||||||
|
|
@ -224,13 +228,13 @@ func (d *Database) createAccount(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
|
if account, err = d.accounts.insertAccount(ctx, localpart, hash, appserviceID); err != nil {
|
||||||
return nil, sqlutil.ErrUserExists
|
return nil, sqlutil.ErrUserExists
|
||||||
}
|
}
|
||||||
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
if err = d.profiles.insertProfile(ctx, localpart); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
|
if err = d.accountDatas.insertAccountData(ctx, localpart, "", "m.push_rules", json.RawMessage(`{
|
||||||
"global": {
|
"global": {
|
||||||
"content": [],
|
"content": [],
|
||||||
"override": [],
|
"override": [],
|
||||||
|
|
@ -252,11 +256,11 @@ func (d *Database) createAccount(
|
||||||
func (d *Database) SaveAccountData(
|
func (d *Database) SaveAccountData(
|
||||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) error {
|
) error {
|
||||||
d.accountDatasMu.Lock()
|
// d.accountDatasMu.Lock()
|
||||||
defer d.accountDatasMu.Unlock()
|
// defer d.accountDatasMu.Unlock()
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
// })
|
||||||
})
|
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountData returns account data related to a given localpart
|
// GetAccountData returns account data related to a given localpart
|
||||||
|
|
@ -286,7 +290,7 @@ func (d *Database) GetAccountDataByType(
|
||||||
func (d *Database) GetNewNumericLocalpart(
|
func (d *Database) GetNewNumericLocalpart(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
) (int64, error) {
|
) (int64, error) {
|
||||||
return d.accounts.selectNewNumericLocalpart(ctx, nil)
|
return d.accounts.selectNewNumericLocalpart(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||||
|
|
@ -305,22 +309,23 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use")
|
||||||
func (d *Database) SaveThreePIDAssociation(
|
func (d *Database) SaveThreePIDAssociation(
|
||||||
ctx context.Context, threepid, localpart, medium string,
|
ctx context.Context, threepid, localpart, medium string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
d.threepidsMu.Lock()
|
// d.threepidsMu.Lock()
|
||||||
defer d.threepidsMu.Unlock()
|
// defer d.threepidsMu.Unlock()
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
user, err := d.threepids.selectLocalpartForThreePID(
|
// })
|
||||||
ctx, txn, threepid, medium,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(user) > 0 {
|
user, err := d.threepids.selectLocalpartForThreePID(
|
||||||
return Err3PIDInUse
|
ctx, threepid, medium,
|
||||||
}
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
|
if len(user) > 0 {
|
||||||
})
|
return Err3PIDInUse
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.threepids.insertThreePID(ctx, threepid, medium, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveThreePIDAssociation removes the association involving a given third-party
|
// RemoveThreePIDAssociation removes the association involving a given third-party
|
||||||
|
|
@ -330,11 +335,11 @@ func (d *Database) SaveThreePIDAssociation(
|
||||||
func (d *Database) RemoveThreePIDAssociation(
|
func (d *Database) RemoveThreePIDAssociation(
|
||||||
ctx context.Context, threepid string, medium string,
|
ctx context.Context, threepid string, medium string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
d.threepidsMu.Lock()
|
// d.threepidsMu.Lock()
|
||||||
defer d.threepidsMu.Unlock()
|
// defer d.threepidsMu.Unlock()
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
// return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
return d.threepids.deleteThreePID(ctx, txn, threepid, medium)
|
// })
|
||||||
})
|
return d.threepids.deleteThreePID(ctx, threepid, medium)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
||||||
|
|
@ -345,7 +350,7 @@ func (d *Database) RemoveThreePIDAssociation(
|
||||||
func (d *Database) GetLocalpartForThreePID(
|
func (d *Database) GetLocalpartForThreePID(
|
||||||
ctx context.Context, threepid string, medium string,
|
ctx context.Context, threepid string, medium string,
|
||||||
) (localpart string, err error) {
|
) (localpart string, err error) {
|
||||||
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
|
return d.threepids.selectLocalpartForThreePID(ctx, threepid, medium)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
||||||
|
|
@ -362,11 +367,11 @@ func (d *Database) GetThreePIDsForLocalpart(
|
||||||
// in the database.
|
// in the database.
|
||||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||||
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
response, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
if err == sql.ErrNoRows {
|
// if err == sql.ErrNoRows {
|
||||||
return true, nil
|
// return true, nil
|
||||||
}
|
// }
|
||||||
return false, err
|
return response == nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||||
|
|
@ -395,9 +400,9 @@ func (d *Database) CreateOpenIDToken(
|
||||||
token, localpart string,
|
token, localpart string,
|
||||||
) (int64, error) {
|
) (int64, error) {
|
||||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
||||||
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
// err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
|
// })
|
||||||
})
|
var err = d.openIDTokens.insertToken(ctx, token, localpart, expiresAtMS)
|
||||||
return expiresAtMS, err
|
return expiresAtMS, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,118 +16,190 @@ package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
const threepidSchema = `
|
// const threepidSchema = `
|
||||||
-- Stores data about third party identifiers
|
// -- Stores data about third party identifiers
|
||||||
CREATE TABLE IF NOT EXISTS account_threepid (
|
// CREATE TABLE IF NOT EXISTS account_threepid (
|
||||||
-- The third party identifier
|
// -- The third party identifier
|
||||||
threepid TEXT NOT NULL,
|
// threepid TEXT NOT NULL,
|
||||||
-- The 3PID medium
|
// -- The 3PID medium
|
||||||
medium TEXT NOT NULL DEFAULT 'email',
|
// medium TEXT NOT NULL DEFAULT 'email',
|
||||||
-- The localpart of the Matrix user ID associated to this 3PID
|
// -- The localpart of the Matrix user ID associated to this 3PID
|
||||||
localpart TEXT NOT NULL,
|
// localpart TEXT NOT NULL,
|
||||||
|
|
||||||
PRIMARY KEY(threepid, medium)
|
// PRIMARY KEY(threepid, medium)
|
||||||
);
|
// );
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
|
type ThreePIDCosmos struct {
|
||||||
`
|
Localpart string `json:"local_part"`
|
||||||
|
ThreePID string `json:"three_pid"`
|
||||||
const selectLocalpartForThreePIDSQL = "" +
|
Medium string `json:"medium"`
|
||||||
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
|
||||||
|
|
||||||
const selectThreePIDsForLocalpartSQL = "" +
|
|
||||||
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
|
||||||
|
|
||||||
const insertThreePIDSQL = "" +
|
|
||||||
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
|
||||||
|
|
||||||
const deleteThreePIDSQL = "" +
|
|
||||||
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
|
||||||
|
|
||||||
type threepidStatements struct {
|
|
||||||
db *sql.DB
|
|
||||||
selectLocalpartForThreePIDStmt *sql.Stmt
|
|
||||||
selectThreePIDsForLocalpartStmt *sql.Stmt
|
|
||||||
insertThreePIDStmt *sql.Stmt
|
|
||||||
deleteThreePIDStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
type ThreePIDCosmosData struct {
|
||||||
s.db = db
|
Id string `json:"id"`
|
||||||
_, err = db.Exec(threepidSchema)
|
Pk string `json:"_pk"`
|
||||||
if err != nil {
|
Cn string `json:"_cn"`
|
||||||
return
|
ETag string `json:"_etag"`
|
||||||
}
|
Timestamp int64 `json:"_ts"`
|
||||||
if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil {
|
ThreePID ThreePIDCosmos `json:"mx_userapi_threepid"`
|
||||||
return
|
}
|
||||||
}
|
|
||||||
if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
|
type threepidStatements struct {
|
||||||
|
db *Database
|
||||||
|
selectLocalpartForThreePIDStmt string
|
||||||
|
selectThreePIDsForLocalpartStmt string
|
||||||
|
// insertThreePIDStmt *sql.Stmt
|
||||||
|
// deleteThreePIDStmt *sql.Stmt
|
||||||
|
tableName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *threepidStatements) prepare(db *Database) (err error) {
|
||||||
|
s.db = db
|
||||||
|
s.selectLocalpartForThreePIDStmt = "select * from c where c._cn = @x1 and c.mx_userapi_threepid.three_pid = @x2 and c.mx_userapi_threepid.medium = @x3"
|
||||||
|
s.selectThreePIDsForLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_threepid.local_part = @x2"
|
||||||
|
s.tableName = "account_threepid"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) selectLocalpartForThreePID(
|
func (s *threepidStatements) selectLocalpartForThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
ctx context.Context, threepid string, medium string,
|
||||||
) (localpart string, err error) {
|
) (localpart string, err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
|
||||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||||
if err == sql.ErrNoRows {
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
|
response := []ThreePIDCosmosData{}
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": threepid,
|
||||||
|
"@x3": medium,
|
||||||
|
}
|
||||||
|
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return "", ex
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response) == 0 {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
return
|
|
||||||
|
return response[0].ThreePID.Localpart, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) selectThreePIDsForLocalpart(
|
func (s *threepidStatements) selectThreePIDsForLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
|
||||||
if err != nil {
|
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
||||||
return
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
|
response := []ThreePIDCosmosData{}
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": localpart,
|
||||||
|
}
|
||||||
|
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if ex != nil {
|
||||||
|
return threepids, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response) == 0 {
|
||||||
|
return threepids, nil
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed")
|
|
||||||
|
|
||||||
threepids = []authtypes.ThreePID{}
|
threepids = []authtypes.ThreePID{}
|
||||||
for rows.Next() {
|
for _, item := range response {
|
||||||
var threepid string
|
|
||||||
var medium string
|
|
||||||
if err = rows.Scan(&threepid, &medium); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
threepids = append(threepids, authtypes.ThreePID{
|
threepids = append(threepids, authtypes.ThreePID{
|
||||||
Address: threepid,
|
Address: item.ThreePID.ThreePID,
|
||||||
Medium: medium,
|
Medium: item.ThreePID.Medium,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return threepids, rows.Err()
|
return threepids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) insertThreePID(
|
func (s *threepidStatements) insertThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
ctx context.Context, threepid, medium, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
|
||||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
// "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||||
return err
|
var result = ThreePIDCosmos{
|
||||||
|
Localpart: localpart,
|
||||||
|
Medium: medium,
|
||||||
|
ThreePID: threepid,
|
||||||
|
}
|
||||||
|
|
||||||
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
|
|
||||||
|
id := fmt.Sprintf("%s_%s", threepid, medium)
|
||||||
|
var dbData = ThreePIDCosmosData{
|
||||||
|
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
ThreePID: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
|
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) deleteThreePID(
|
func (s *threepidStatements) deleteThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
|
ctx context.Context, threepid string, medium string) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
|
|
||||||
_, err = stmt.ExecContext(ctx, threepid, medium)
|
// "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||||
return err
|
var config = cosmosdbapi.DefaultConfig()
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
|
id := fmt.Sprintf("%s_%s", threepid, medium)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||||
|
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
|
||||||
|
_, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||||
|
ctx,
|
||||||
|
config.DatabaseName,
|
||||||
|
config.TenantName,
|
||||||
|
id,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue