- Implement Cosmos for the devices_table (#4)

- Use the ConnectionString in the YAML to include the Tenant
- Revert all other non implemented tables back to use SQLLite3
This commit is contained in:
alexfca 2021-05-12 16:30:49 +10:00 committed by GitHub
parent dfd5d445ac
commit b696923333
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 547 additions and 414 deletions

View file

@ -16,7 +16,6 @@
package cosmosdb package cosmosdb
import ( import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context" "context"
"database/sql" "database/sql"
@ -38,7 +37,6 @@ type Database struct {
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
var result Database var result Database
var err error var err error
if result.db, err = sqlutil.Open(dbProperties); err != nil { if result.db, err = sqlutil.Open(dbProperties); err != nil {

View file

@ -354,12 +354,12 @@ user_api:
listen: http://localhost:7781 listen: http://localhost:7781
connect: http://localhost:7781 connect: http://localhost:7781
account_database: account_database:
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database: device_database:
connection_string: file:userapi_devices.db connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -16,7 +16,6 @@
package cosmosdb package cosmosdb
import ( import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"database/sql" "database/sql"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -38,7 +37,6 @@ type Database struct {
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
var d Database var d Database
var err error var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {

View file

@ -0,0 +1,6 @@
package cosmosdbapi
type CosmosConfig struct {
DatabaseName string
ContainerName string
}

View file

@ -1,14 +0,0 @@
package cosmosdbapi
type Tenant struct {
DatabaseName string
TenantName string
}
//TODO: Move into Config or the JWT
func DefaultConfig() Tenant {
return Tenant{
DatabaseName: "safezone_local",
TenantName: "criticalarc.com",
}
}

View file

@ -1,22 +1,50 @@
package cosmosdbutil package cosmosdbutil
import ( import (
"github.com/matrix-org/dendrite/setup/config"
"strings" "strings"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/setup/config"
) )
func GetConnectionString(d *config.DataSource) config.DataSource { const accountEndpointName = "AccountEndpoint"
const accountKeyName = "AccountKey"
const databaseName = "DatabaseName"
const containerName = "ContainerName"
func getConnectionString(d *config.DataSource) config.DataSource {
var connString string var connString string
connString = string(*d) connString = string(*d)
return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1)) return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1))
} }
func GetConnectionProperties(connectionString string) map[string]string { func getConnectionProperties(connectionString string) map[string]string {
connectionItemsRaw := strings.Split(connectionString, ";") connectionItemsRaw := strings.Split(connectionString, ";")
connectionItems := map[string]string{} connectionItems := map[string]string{}
for _, item := range connectionItemsRaw { for _, item := range connectionItemsRaw {
itemSplit := strings.SplitN(item, "=", 2) if len(item) > 0 {
connectionItems[itemSplit[0]] = itemSplit[1] itemSplit := strings.SplitN(item, "=", 2)
connectionItems[itemSplit[0]] = itemSplit[1]
}
} }
return connectionItems return connectionItems
} }
func GetCosmosConnection(d *config.DataSource) cosmosdbapi.CosmosConnection {
connString := getConnectionString(d)
connMap := getConnectionProperties(string(connString))
accountEndpoint := connMap[accountEndpointName]
accountKey := connMap[accountKeyName]
return cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
}
func GetCosmosConfig(d *config.DataSource) cosmosdbapi.CosmosConfig {
connString := getConnectionString(d)
connMap := getConnectionProperties(string(connString))
database := connMap[databaseName]
container := connMap[containerName]
return cosmosdbapi.CosmosConfig{
DatabaseName: database,
ContainerName: container,
}
}

View file

@ -15,14 +15,12 @@
package cosmosdb package cosmosdb
import ( import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/keyserver/storage/shared"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -16,7 +16,6 @@
package cosmosdb package cosmosdb
import ( import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context" "context"
"database/sql" "database/sql"
@ -37,7 +36,6 @@ type Database struct {
// Open opens a postgres database. // Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) { func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
d := Database{ d := Database{
writer: sqlutil.NewExclusiveWriter(), writer: sqlutil.NewExclusiveWriter(),
} }

View file

@ -16,7 +16,6 @@
package cosmosdb package cosmosdb
import ( import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context" "context"
"database/sql" "database/sql"
@ -38,7 +37,6 @@ type Database struct {
// Open a sqlite database. // Open a sqlite database.
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
var d Database var d Database
var db *sql.DB var db *sql.DB
var err error var err error

View file

@ -1,7 +1,6 @@
package kafka package kafka
import ( import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/naffka" "github.com/matrix-org/naffka"
@ -47,8 +46,9 @@ func setupNaffka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) {
if naffkaInstance != nil { if naffkaInstance != nil {
return naffkaInstance, naffkaInstance return naffkaInstance, naffkaInstance
} }
if(cfg.Database.ConnectionString.IsCosmosDB()) { if cfg.Database.ConnectionString.IsCosmosDB() {
cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString) //TODO: What do we do for Nafka
// cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString)
} }
naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString)) naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString))

View file

@ -16,7 +16,6 @@
package cosmosdb package cosmosdb
import ( import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context" "context"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@ -45,7 +44,6 @@ func NewDatabase(
serverKey ed25519.PublicKey, serverKey ed25519.PublicKey,
serverKeyID gomatrixserverlib.KeyID, serverKeyID gomatrixserverlib.KeyID,
) (*Database, error) { ) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -16,7 +16,6 @@
package cosmosdb package cosmosdb
import ( import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"database/sql" "database/sql"
// Import the sqlite3 package // Import the sqlite3 package
@ -41,7 +40,6 @@ type SyncServerDatasource struct {
// NewDatabase creates a new sync server database // NewDatabase creates a new sync server database
// nolint: gocyclo // nolint: gocyclo
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
var d SyncServerDatasource var d SyncServerDatasource
var err error var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {

View file

@ -84,7 +84,6 @@ func (s *accountDataStatements) insertAccountData(
Content: content, Content: content,
} }
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
id := "" id := ""
if roomID == "" { if roomID == "" {
@ -94,9 +93,9 @@ func (s *accountDataStatements) insertAccountData(
} }
var dbData = AccountDataCosmosData{ var dbData = AccountDataCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id), Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, id),
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
AccountData: result, AccountData: result,
} }
@ -104,8 +103,8 @@ func (s *accountDataStatements) insertAccountData(
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
dbData, dbData,
options) options)
@ -120,9 +119,8 @@ func (s *accountDataStatements) selectAccountData(
error, error,
) { ) {
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1" // "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountDataCosmosData{} response := []AccountDataCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -132,8 +130,8 @@ func (s *accountDataStatements) selectAccountData(
var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params) var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
query, query,
&response, &response,
options) options)
@ -167,9 +165,8 @@ func (s *accountDataStatements) selectAccountDataByType(
var bytes []byte var bytes []byte
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" // "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountDataCosmosData{} response := []AccountDataCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -181,8 +178,8 @@ func (s *accountDataStatements) selectAccountDataByType(
var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params) var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
query, query,
&response, &response,
options) options)

View file

@ -87,26 +87,26 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv
return return
} }
func getAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*AccountCosmosData, error) { func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) {
response := AccountCosmosData{} response := AccountCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
docId, docId,
optionsGet, optionsGet,
&response) &response)
return &response, ex return &response, ex
} }
func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, account AccountCosmosData) (*AccountCosmosData, error) { func setAccount(s *accountsStatements, ctx context.Context, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
response := AccountCosmosData{} response := AccountCosmosData{}
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag) var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
account.Id, account.Id,
&account, &account,
optionsReplace) optionsReplace)
@ -153,13 +153,12 @@ func (s *accountsStatements) insertAccount(
data.PasswordHash = hash data.PasswordHash = hash
data.IsDeactivated = false data.IsDeactivated = false
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var dbData = AccountCosmosData{ var dbData = AccountCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart), Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
Account: data, Account: data,
} }
@ -167,8 +166,8 @@ func (s *accountsStatements) insertAccount(
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
dbData, dbData,
options) options)
@ -184,19 +183,18 @@ func (s *accountsStatements) updatePassword(
) (err error) { ) (err error) {
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" // "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getAccount(s, ctx, config, pk, docId) var response, exGet = getAccount(s, ctx, pk, docId)
if exGet != nil { if exGet != nil {
return exGet return exGet
} }
response.Account.PasswordHash = passwordHash response.Account.PasswordHash = passwordHash
var _, exReplace = setAccount(s, ctx, config, pk, *response) var _, exReplace = setAccount(s, ctx, pk, *response)
if exReplace != nil { if exReplace != nil {
return exReplace return exReplace
} }
@ -208,19 +206,18 @@ func (s *accountsStatements) deactivateAccount(
) (err error) { ) (err error) {
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" // "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getAccount(s, ctx, config, pk, docId) var response, exGet = getAccount(s, ctx, pk, docId)
if exGet != nil { if exGet != nil {
return exGet return exGet
} }
response.Account.IsDeactivated = true response.Account.IsDeactivated = true
var _, exReplace = setAccount(s, ctx, config, pk, *response) var _, exReplace = setAccount(s, ctx, pk, *response)
if exReplace != nil { if exReplace != nil {
return exReplace return exReplace
} }
@ -232,9 +229,8 @@ func (s *accountsStatements) selectPasswordHash(
) (hash string, err error) { ) (hash string, err error) {
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" // "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountCosmosData{} response := []AccountCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -244,8 +240,8 @@ func (s *accountsStatements) selectPasswordHash(
var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params) var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
query, query,
&response, &response,
options) options)
@ -271,9 +267,8 @@ func (s *accountsStatements) selectAccountByLocalpart(
var acc api.Account var acc api.Account
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" // "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountCosmosData{} response := []AccountCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -283,8 +278,8 @@ func (s *accountsStatements) selectAccountByLocalpart(
var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params) var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
query, query,
&response, &response,
options) options)
@ -309,9 +304,8 @@ func (s *accountsStatements) selectNewNumericLocalpart(
) (id int64, err error) { ) (id int64, err error) {
// "SELECT COUNT(localpart) FROM account_accounts" // "SELECT COUNT(localpart) FROM account_accounts"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []AccountCosmosUserCount var response []AccountCosmosUserCount
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -320,8 +314,8 @@ func (s *accountsStatements) selectNewNumericLocalpart(
var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params) var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
query, query,
&response, &response,
options) options)

View file

@ -85,13 +85,12 @@ func (s *tokenStatements) insertToken(
ExpiresAtMS: expiresAtMS, ExpiresAtMS: expiresAtMS,
} }
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
var dbData = OpenIdTokenCosmosData{ var dbData = OpenIdTokenCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token), Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Token),
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
OpenIdToken: mapToToken(*result), OpenIdToken: mapToToken(*result),
} }
@ -99,8 +98,8 @@ func (s *tokenStatements) insertToken(
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument( var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
dbData, dbData,
options) options)
@ -120,9 +119,8 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
var openIDTokenAttrs api.OpenIDTokenAttributes var openIDTokenAttrs api.OpenIDTokenAttributes
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" // "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []OpenIdTokenCosmosData{} response := []OpenIdTokenCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -132,8 +130,8 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params) var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
query, query,
&response, &response,
options) options)

View file

@ -87,25 +87,25 @@ func (s *profilesStatements) prepare(db *Database) (err error) {
return return
} }
func getProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*ProfileCosmosData, error) { func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) {
response := ProfileCosmosData{} response := ProfileCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
docId, docId,
optionsGet, optionsGet,
&response) &response)
return &response, ex return &response, ex
} }
func setProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) { func setProfile(s *profilesStatements, ctx context.Context, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag) var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
profile.Id, profile.Id,
&profile, &profile,
optionsReplace) optionsReplace)
@ -121,13 +121,12 @@ func (s *profilesStatements) insertProfile(
Localpart: localpart, Localpart: localpart,
} }
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var dbData = ProfileCosmosData{ var dbData = ProfileCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart), Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
Profile: mapToProfile(*result), Profile: mapToProfile(*result),
} }
@ -135,8 +134,8 @@ func (s *profilesStatements) insertProfile(
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
dbData, dbData,
options) options)
@ -148,9 +147,8 @@ func (s *profilesStatements) selectProfileByLocalpart(
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ProfileCosmosData{} response := []ProfileCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -160,8 +158,8 @@ func (s *profilesStatements) selectProfileByLocalpart(
var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params) var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
query, query,
&response, &response,
options) options)
@ -187,19 +185,18 @@ func (s *profilesStatements) setAvatarURL(
) (err error) { ) (err error) {
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" // "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var response, exGet = getProfile(s, ctx, config, pk, docId) var response, exGet = getProfile(s, ctx, pk, docId)
if exGet != nil { if exGet != nil {
return exGet return exGet
} }
response.Profile.AvatarURL = avatarURL response.Profile.AvatarURL = avatarURL
var _, exReplace = setProfile(s, ctx, config, pk, *response) var _, exReplace = setProfile(s, ctx, pk, *response)
if exReplace != nil { if exReplace != nil {
return exReplace return exReplace
} }
@ -211,18 +208,17 @@ func (s *profilesStatements) setDisplayName(
) (err error) { ) (err error) {
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" // "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var response, exGet = getProfile(s, ctx, config, pk, docId) var response, exGet = getProfile(s, ctx, pk, docId)
if exGet != nil { if exGet != nil {
return exGet return exGet
} }
response.Profile.DisplayName = displayName response.Profile.DisplayName = displayName
var _, exReplace = setProfile(s, ctx, config, pk, *response) var _, exReplace = setProfile(s, ctx, pk, *response)
if exReplace != nil { if exReplace != nil {
return exReplace return exReplace
} }
@ -235,9 +231,8 @@ func (s *profilesStatements) selectProfilesBySearch(
var profiles []authtypes.Profile var profiles []authtypes.Profile
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ProfileCosmosData{} response := []ProfileCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -248,8 +243,8 @@ func (s *profilesStatements) selectProfilesBySearch(
var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params) var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
query, query,
&response, &response,
options) options)

View file

@ -48,20 +48,19 @@ type Database struct {
databaseName string databaseName string
connection cosmosdbapi.CosmosConnection connection cosmosdbapi.CosmosConnection
cosmosConfig cosmosdbapi.CosmosConfig
} }
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
connString := cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
connMap := cosmosdbutil.GetConnectionProperties(string(connString)) config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
accountEndpoint := connMap["AccountEndpoint"]
accountKey := connMap["AccountKey"]
conn := cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
d := &Database{ d := &Database{
serverName: serverName, serverName: serverName,
databaseName: "userapi", databaseName: "userapi",
connection: conn, connection: conn,
cosmosConfig: config,
// db: db, // db: db,
// writer: sqlutil.NewExclusiveWriter(), // writer: sqlutil.NewExclusiveWriter(),
// bcryptCost: bcryptCost, // bcryptCost: bcryptCost,

View file

@ -74,9 +74,8 @@ func (s *threepidStatements) selectLocalpartForThreePID(
) (localpart string, err error) { ) (localpart string, err error) {
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" // "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ThreePIDCosmosData{} response := []ThreePIDCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -87,8 +86,8 @@ func (s *threepidStatements) selectLocalpartForThreePID(
var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params) var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
query, query,
&response, &response,
options) options)
@ -109,9 +108,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" // "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ThreePIDCosmosData{} response := []ThreePIDCosmosData{}
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": dbCollectionName, "@x1": dbCollectionName,
@ -121,8 +119,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params) var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
query, query,
&response, &response,
options) options)
@ -156,14 +154,14 @@ func (s *threepidStatements) insertThreePID(
ThreePID: threepid, ThreePID: threepid,
} }
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
id := fmt.Sprintf("%s_%s", threepid, medium) docId := fmt.Sprintf("%s_%s", threepid, medium)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = ThreePIDCosmosData{ var dbData = ThreePIDCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id), Id: cosmosDocId,
Cn: dbCollectionName, Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
ThreePID: result, ThreePID: result,
} }
@ -171,8 +169,8 @@ func (s *threepidStatements) insertThreePID(
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
dbData, dbData,
options) options)
@ -186,16 +184,16 @@ func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) { ctx context.Context, threepid string, medium string) (err error) {
// "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" // "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
id := fmt.Sprintf("%s_%s", threepid, medium) docId := fmt.Sprintf("%s_%s", threepid, medium)
pk := cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk) var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
_, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx, ctx,
config.DatabaseName, s.db.cosmosConfig.DatabaseName,
config.TenantName, s.db.cosmosConfig.ContainerName,
id, cosmosDocId,
options) options)
if err != nil { if err != nil {

View file

@ -16,127 +16,145 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "errors"
"strings" "fmt"
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const devicesSchema = ` // const devicesSchema = `
-- This sequence is used for automatic allocation of session_id. // -- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; // -- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices. // -- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices ( // CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY, // access_token TEXT PRIMARY KEY,
session_id INTEGER, // session_id INTEGER,
device_id TEXT , // device_id TEXT ,
localpart TEXT , // localpart TEXT ,
created_ts BIGINT, // created_ts BIGINT,
display_name TEXT, // display_name TEXT,
last_seen_ts BIGINT, // last_seen_ts BIGINT,
ip TEXT, // ip TEXT,
user_agent TEXT, // user_agent TEXT,
UNIQUE (localpart, device_id) // UNIQUE (localpart, device_id)
); // );
` // `
const insertDeviceSQL = "" + type DeviceCosmos struct {
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + ID string `json:"device_id"`
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" UserID string `json:"user_id"`
// The access_token granted to this device.
// This uniquely identifies the device from all other devices and clients.
AccessToken string `json:"access_token"`
// The unique ID of the session identified by the access token.
// Can be used as a secure substitution in places where data needs to be
// associated with access tokens.
SessionID int64 `json:"session_id"`
DisplayName string `json:"display_name"`
LastSeenTS int64 `json:"last_seen_ts"`
LastSeenIP string `json:"last_seen_ip"`
Localpart string `json:"local_part"`
UserAgent string `json:"user_agent"`
// If the device is for an appservice user,
// this is the appservice ID.
AppserviceID string `json:"app_service_id"`
}
const selectDevicesCountSQL = "" + type DeviceCosmosData struct {
"SELECT COUNT(access_token) FROM device_devices" Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Device DeviceCosmos `json:"mx_userapi_device"`
}
const selectDeviceByTokenSQL = "" + type DeviceCosmosSessionCount struct {
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" SessionCount int64 `json:"sessioncount"`
}
const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
type devicesStatements struct { type devicesStatements struct {
db *sql.DB db *Database
writer sqlutil.Writer selectDevicesCountStmt string
insertDeviceStmt *sql.Stmt selectDeviceByTokenStmt string
selectDevicesCountStmt *sql.Stmt // selectDeviceByIDStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDevicesByIDStmt string
selectDeviceByIDStmt *sql.Stmt selectDevicesByLocalpartStmt string
selectDevicesByIDStmt *sql.Stmt selectDevicesByLocalpartExceptIDStmt string
selectDevicesByLocalpartStmt *sql.Stmt serverName gomatrixserverlib.ServerName
updateDeviceNameStmt *sql.Stmt tableName string
updateDeviceLastSeenStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
} }
func (s *devicesStatements) execSchema(db *sql.DB) error { func mapFromDevice(db DeviceCosmos) api.Device {
_, err := db.Exec(devicesSchema) return api.Device{
return err AccessToken: db.AccessToken,
AppserviceID: db.AppserviceID,
ID: db.ID,
LastSeenIP: db.LastSeenIP,
LastSeenTS: db.LastSeenTS,
SessionID: db.SessionID,
UserAgent: db.UserAgent,
UserID: db.UserID,
}
} }
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos {
localPart, _ := userutil.ParseUsernameParam(api.UserID, &s.serverName)
return DeviceCosmos{
AccessToken: api.AccessToken,
AppserviceID: api.AppserviceID,
ID: api.ID,
LastSeenIP: api.LastSeenIP,
LastSeenTS: api.LastSeenTS,
Localpart: localPart,
SessionID: api.SessionID,
UserAgent: api.UserAgent,
UserID: api.UserID,
}
}
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
response := DeviceCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
docId,
optionsGet,
&response)
return &response, ex
}
func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, device.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
device.Id,
&device,
optionsReplace)
return &device, ex
}
func (s *devicesStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
s.writer = writer s.selectDevicesCountStmt = "select count(c._ts) as sessioncount from c where c._cn = @x1"
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { s.selectDevicesByLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and ARRAY_CONTAINS(@x3, c.mx_userapi_device.device_id)"
return s.selectDevicesByLocalpartExceptIDStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and c.mx_userapi_device.device_id != @x3"
} s.selectDeviceByTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.access_token = @x2"
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { s.selectDevicesByIDStmt = "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_userapi_device.device_id)"
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server s.serverName = server
s.tableName = "device_devices"
return return
} }
@ -144,85 +162,219 @@ func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go
// Returns an error if the user already has a device with the given device ID. // Returns an error if the user already has a device with the given device ID.
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) insertDevice( func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, ctx context.Context, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) { ) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64 var sessionID int64
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) // "SELECT COUNT(access_token) FROM device_devices"
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) // HACK: Do we need a Cosmos Table for the sequence?
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosSessionCount
params := map[string]interface{}{
"@x1": dbCollectionName,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesCountStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err return nil, err
} }
sessionID = response[0].SessionCount
sessionID++ sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { // "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
return nil, err // " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
}
return &api.Device{ data := DeviceCosmos{
ID: id, ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken, AccessToken: accessToken,
SessionID: sessionID, SessionID: sessionID,
LastSeenTS: createdTimeMS, LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr, LastSeenIP: ipAddr,
Localpart: localpart,
UserAgent: userAgent, UserAgent: userAgent,
}, nil }
// access_token TEXT PRIMARY KEY,
// UNIQUE (localpart, device_id)
// HACK: check for duplicate PK as we are using the UNIQUE key for the DocId
docId := fmt.Sprintf("%s_%s", localpart, id)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = DeviceCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(),
Device: data,
}
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, errCreate = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
optionsCreate)
if errCreate != nil {
return nil, errCreate
}
var result = mapFromDevice(dbData.Device)
return &result, nil
} }
func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, id, localpart string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) // "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
_, err := stmt.ExecContext(ctx, id, localpart) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
docId := fmt.Sprintf("%s_%s", localpart, id)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
cosmosDocId,
options)
if err != nil {
return err
}
return err return err
} }
func (s *devicesStatements) deleteDevices( func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string, ctx context.Context, localpart string, devices []string,
) error { ) error {
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) // "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
prep, err := s.db.Prepare(orig) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": devices,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return err return err
} }
stmt := sqlutil.TxStmt(txn, prep) for _, item := range response {
params := make([]interface{}, len(devices)+1) s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
} }
_, err = stmt.ExecContext(ctx, params...)
return err return err
} }
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) // "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
exceptDevices := []string{
exceptDeviceID,
}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": exceptDevices,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return err
}
for _, item := range response {
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
}
return err return err
} }
func (s *devicesStatements) updateDeviceName( func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ctx context.Context, localpart, deviceID string, displayName *string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) // "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
return err var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil {
return exGet
}
response.Device.DisplayName = *displayName
var _, exReplace = setDevice(s, ctx, pk, *response)
if exReplace != nil {
return exReplace
}
return exReplace
} }
func (s *devicesStatements) selectDeviceByToken( func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string, ctx context.Context, accessToken string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device // "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
var localpart string var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
stmt := s.selectDeviceByTokenStmt var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) var response []DeviceCosmosData
if err == nil { params := map[string]interface{}{
dev.UserID = userutil.MakeUserID(localpart, s.serverName) "@x1": dbCollectionName,
dev.AccessToken = accessToken "@x2": accessToken,
} }
return &dev, err
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDeviceByTokenStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
if len(response) == 0 {
return nil, errors.New(fmt.Sprintf("No Devices found with AccessToken %s", accessToken))
}
if err == nil {
result := mapFromDevice(response[0].Device)
return &result, nil
}
return nil, err
} }
// selectDeviceByID retrieves a device from the database with the given user // selectDeviceByID retrieves a device from the database with the given user
@ -230,54 +382,48 @@ func (s *devicesStatements) selectDeviceByToken(
func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device // "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
var displayName sql.NullString var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
stmt := s.selectDeviceByIDStmt var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) docId := fmt.Sprintf("%s_%s", localpart, deviceID)
if err == nil { cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
dev.ID = deviceID var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
dev.UserID = userutil.MakeUserID(localpart, s.serverName) if exGet != nil {
if displayName.Valid { return nil, exGet
dev.DisplayName = displayName.String
}
} }
return &dev, err result := mapFromDevice(response.Device)
return &result, nil
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) // "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
if err != nil { var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
return devices, err var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": exceptDeviceID,
} }
for rows.Next() { var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var dev api.Device var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartExceptIDStmt, params)
var lastseents sql.NullInt64 var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
var id, displayname, ip, useragent sql.NullString ctx,
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) s.db.cosmosConfig.DatabaseName,
if err != nil { s.db.cosmosConfig.ContainerName,
return devices, err query,
} &response,
if id.Valid { optionsQry)
dev.ID = id.String if err != nil {
} return nil, err
if displayname.Valid { }
dev.DisplayName = displayname.String
}
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
if ip.Valid {
dev.LastSeenIP = ip.String
}
if useragent.Valid {
dev.UserAgent = useragent.String
}
for _, item := range response {
dev := mapFromDevice(item.Device)
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev) devices = append(devices, dev)
} }
@ -286,37 +432,53 @@ func (s *devicesStatements) selectDevicesByLocalpart(
} }
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) // "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
iDeviceIDs := make([]interface{}, len(deviceIDs)) var devices []api.Device
for i := range deviceIDs { var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
iDeviceIDs[i] = deviceIDs[i] var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": deviceIDs,
} }
rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...) var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByIDStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") for _, item := range response {
var devices []api.Device dev := mapFromDevice(item.Device)
for rows.Next() {
var dev api.Device
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev) devices = append(devices, dev)
} }
return devices, rows.Err() return devices, nil
} }
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000 lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) // "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
return err var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil {
return exGet
}
response.Device.LastSeenTS = lastSeenTs
var _, exReplace = setDevice(s, ctx, pk, *response)
if exReplace != nil {
return exReplace
}
return exReplace
} }

View file

@ -15,16 +15,18 @@
package cosmosdb package cosmosdb
import ( import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context" "context"
"crypto/rand" "crypto/rand"
"database/sql"
"encoding/base64" "encoding/base64"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -35,35 +37,32 @@ var deviceIDByteLength = 6
// Database represents a device database. // Database represents a device database.
type Database struct { type Database struct {
db *sql.DB writer sqlutil.Writer
writer sqlutil.Writer devices devicesStatements
devices devicesStatements connection cosmosdbapi.CosmosConnection
databaseName string
cosmosConfig cosmosdbapi.CosmosConfig
serverName gomatrixserverlib.ServerName
} }
// NewDatabase creates a new device database // NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
db, err := sqlutil.Open(dbProperties) config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
if err != nil { devices := devicesStatements{}
return nil, err
}
writer := sqlutil.NewExclusiveWriter()
d := devicesStatements{}
// Create tables before executing migrations so we don't fail if the table is missing, // Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns // and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil { d := &Database{
return nil, err databaseName: "userapi",
devices: devices,
serverName: serverName,
connection: conn,
cosmosConfig: config,
} }
m := sqlutil.NewMigrations() err := d.devices.prepare(d, serverName)
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil { return d, err
return nil, err
}
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
return &Database{db, writer, d}, nil
} }
// GetDeviceByAccessToken returns the device matching the given access token. // GetDeviceByAccessToken returns the device matching the given access token.
@ -86,7 +85,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart( func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") return d.devices.selectDevicesByLocalpart(ctx, localpart, "")
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -104,16 +103,14 @@ func (d *Database) CreateDevice(
displayName *string, ipAddr, userAgent string, displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) { ) (dev *api.Device, returnErr error) {
if deviceID != nil { if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var err error
var err error // Revoke existing tokens for this device
// Revoke existing tokens for this device if err = d.devices.deleteDevice(ctx, *deviceID, localpart); err != nil {
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { return nil, err
return err }
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) dev, err = d.devices.insertDevice(ctx, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err return dev, err
})
} else { } else {
// We generate device IDs in a loop in case its already taken. // We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever // We cap this at going round 5 times to ensure we don't spin forever
@ -124,11 +121,9 @@ func (d *Database) CreateDevice(
return return
} }
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var err error
var err error dev, err = d.devices.insertDevice(ctx, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) return dev, err
return err
})
if returnErr == nil { if returnErr == nil {
return return
} }
@ -154,9 +149,7 @@ func generateDeviceID() (string, error) {
func (d *Database) UpdateDevice( func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string, ctx context.Context, localpart, deviceID string, displayName *string,
) error { ) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { return d.devices.updateDeviceName(ctx, localpart, deviceID, displayName)
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
} }
// RemoveDevice revokes a device by deleting the entry in the database // RemoveDevice revokes a device by deleting the entry in the database
@ -166,12 +159,10 @@ func (d *Database) UpdateDevice(
func (d *Database) RemoveDevice( func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string, ctx context.Context, deviceID, localpart string,
) error { ) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevice(ctx, deviceID, localpart); err != nil {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { return err
return err }
} return nil
return nil
})
} }
// RemoveDevices revokes one or more devices by deleting the entry in the database // RemoveDevices revokes one or more devices by deleting the entry in the database
@ -181,12 +172,10 @@ func (d *Database) RemoveDevice(
func (d *Database) RemoveDevices( func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string, ctx context.Context, localpart string, devices []string,
) error { ) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevices(ctx, localpart, devices); err != nil {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { return err
return err }
} return nil
return nil
})
} }
// RemoveAllDevices revokes devices by deleting the entry in the // RemoveAllDevices revokes devices by deleting the entry in the
@ -195,22 +184,17 @@ func (d *Database) RemoveDevices(
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string, ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) { ) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { devices, err = d.devices.selectDevicesByLocalpart(ctx, localpart, exceptDeviceID)
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) if err != nil {
if err != nil { return nil, err
return err }
} if err := d.devices.deleteDevicesByLocalpart(ctx, localpart, exceptDeviceID); err != nil {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { return nil, err
return err }
} return devices, nil
return nil
})
return
} }
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address // UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { return d.devices.updateDeviceLastSeen(ctx, localpart, deviceID, ipAddr)
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
} }