- Implement Cosmos for the devices_table (#4)

- Use the ConnectionString in the YAML to include the Tenant
- Revert all other non implemented tables back to use SQLLite3
This commit is contained in:
alexfca 2021-05-12 16:30:49 +10:00 committed by GitHub
parent dfd5d445ac
commit b696923333
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 547 additions and 414 deletions

View file

@ -16,7 +16,6 @@
package cosmosdb
import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context"
"database/sql"
@ -38,7 +37,6 @@ type Database struct {
// NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
var result Database
var err error
if result.db, err = sqlutil.Open(dbProperties); err != nil {

View file

@ -354,12 +354,12 @@ user_api:
listen: http://localhost:7781
connect: http://localhost:7781
account_database:
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
device_database:
connection_string: file:userapi_devices.db
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1

View file

@ -16,7 +16,6 @@
package cosmosdb
import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"database/sql"
_ "github.com/mattn/go-sqlite3"
@ -38,7 +37,6 @@ type Database struct {
// NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
var d Database
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {

View file

@ -0,0 +1,6 @@
package cosmosdbapi
type CosmosConfig struct {
DatabaseName string
ContainerName string
}

View file

@ -1,14 +0,0 @@
package cosmosdbapi
type Tenant struct {
DatabaseName string
TenantName string
}
//TODO: Move into Config or the JWT
func DefaultConfig() Tenant {
return Tenant{
DatabaseName: "safezone_local",
TenantName: "criticalarc.com",
}
}

View file

@ -1,22 +1,50 @@
package cosmosdbutil
import (
"github.com/matrix-org/dendrite/setup/config"
"strings"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/setup/config"
)
func GetConnectionString(d *config.DataSource) config.DataSource {
const accountEndpointName = "AccountEndpoint"
const accountKeyName = "AccountKey"
const databaseName = "DatabaseName"
const containerName = "ContainerName"
func getConnectionString(d *config.DataSource) config.DataSource {
var connString string
connString = string(*d)
return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1))
}
func GetConnectionProperties(connectionString string) map[string]string {
func getConnectionProperties(connectionString string) map[string]string {
connectionItemsRaw := strings.Split(connectionString, ";")
connectionItems := map[string]string{}
for _, item := range connectionItemsRaw {
itemSplit := strings.SplitN(item, "=", 2)
connectionItems[itemSplit[0]] = itemSplit[1]
if len(item) > 0 {
itemSplit := strings.SplitN(item, "=", 2)
connectionItems[itemSplit[0]] = itemSplit[1]
}
}
return connectionItems
}
}
func GetCosmosConnection(d *config.DataSource) cosmosdbapi.CosmosConnection {
connString := getConnectionString(d)
connMap := getConnectionProperties(string(connString))
accountEndpoint := connMap[accountEndpointName]
accountKey := connMap[accountKeyName]
return cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
}
func GetCosmosConfig(d *config.DataSource) cosmosdbapi.CosmosConfig {
connString := getConnectionString(d)
connMap := getConnectionProperties(string(connString))
database := connMap[databaseName]
container := connMap[containerName]
return cosmosdbapi.CosmosConfig{
DatabaseName: database,
ContainerName: container,
}
}

View file

@ -15,14 +15,12 @@
package cosmosdb
import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
)
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err

View file

@ -16,7 +16,6 @@
package cosmosdb
import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context"
"database/sql"
@ -37,7 +36,6 @@ type Database struct {
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
d := Database{
writer: sqlutil.NewExclusiveWriter(),
}

View file

@ -16,7 +16,6 @@
package cosmosdb
import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context"
"database/sql"
@ -38,7 +37,6 @@ type Database struct {
// Open a sqlite database.
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
var d Database
var db *sql.DB
var err error

View file

@ -1,7 +1,6 @@
package kafka
import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/naffka"
@ -47,8 +46,9 @@ func setupNaffka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) {
if naffkaInstance != nil {
return naffkaInstance, naffkaInstance
}
if(cfg.Database.ConnectionString.IsCosmosDB()) {
cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString)
if cfg.Database.ConnectionString.IsCosmosDB() {
//TODO: What do we do for Nafka
// cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString)
}
naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString))

View file

@ -16,7 +16,6 @@
package cosmosdb
import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context"
"golang.org/x/crypto/ed25519"
@ -45,7 +44,6 @@ func NewDatabase(
serverKey ed25519.PublicKey,
serverKeyID gomatrixserverlib.KeyID,
) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err

View file

@ -16,7 +16,6 @@
package cosmosdb
import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"database/sql"
// Import the sqlite3 package
@ -41,7 +40,6 @@ type SyncServerDatasource struct {
// NewDatabase creates a new sync server database
// nolint: gocyclo
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
var d SyncServerDatasource
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {

View file

@ -84,7 +84,6 @@ func (s *accountDataStatements) insertAccountData(
Content: content,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
id := ""
if roomID == "" {
@ -94,9 +93,9 @@ func (s *accountDataStatements) insertAccountData(
}
var dbData = AccountDataCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, id),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(),
AccountData: result,
}
@ -104,8 +103,8 @@ func (s *accountDataStatements) insertAccountData(
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
@ -120,9 +119,8 @@ func (s *accountDataStatements) selectAccountData(
error,
) {
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountDataCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -132,8 +130,8 @@ func (s *accountDataStatements) selectAccountData(
var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
@ -167,9 +165,8 @@ func (s *accountDataStatements) selectAccountDataByType(
var bytes []byte
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountDataCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -181,8 +178,8 @@ func (s *accountDataStatements) selectAccountDataByType(
var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)

View file

@ -87,26 +87,26 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv
return
}
func getAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*AccountCosmosData, error) {
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) {
response := AccountCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
docId,
optionsGet,
&response)
return &response, ex
}
func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
func setAccount(s *accountsStatements, ctx context.Context, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
response := AccountCosmosData{}
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
account.Id,
&account,
optionsReplace)
@ -153,13 +153,12 @@ func (s *accountsStatements) insertAccount(
data.PasswordHash = hash
data.IsDeactivated = false
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var dbData = AccountCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(),
Account: data,
}
@ -167,8 +166,8 @@ func (s *accountsStatements) insertAccount(
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
@ -184,19 +183,18 @@ func (s *accountsStatements) updatePassword(
) (err error) {
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getAccount(s, ctx, config, pk, docId)
var response, exGet = getAccount(s, ctx, pk, docId)
if exGet != nil {
return exGet
}
response.Account.PasswordHash = passwordHash
var _, exReplace = setAccount(s, ctx, config, pk, *response)
var _, exReplace = setAccount(s, ctx, pk, *response)
if exReplace != nil {
return exReplace
}
@ -208,19 +206,18 @@ func (s *accountsStatements) deactivateAccount(
) (err error) {
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response, exGet = getAccount(s, ctx, config, pk, docId)
var response, exGet = getAccount(s, ctx, pk, docId)
if exGet != nil {
return exGet
}
response.Account.IsDeactivated = true
var _, exReplace = setAccount(s, ctx, config, pk, *response)
var _, exReplace = setAccount(s, ctx, pk, *response)
if exReplace != nil {
return exReplace
}
@ -232,9 +229,8 @@ func (s *accountsStatements) selectPasswordHash(
) (hash string, err error) {
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -244,8 +240,8 @@ func (s *accountsStatements) selectPasswordHash(
var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
@ -271,9 +267,8 @@ func (s *accountsStatements) selectAccountByLocalpart(
var acc api.Account
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []AccountCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -283,8 +278,8 @@ func (s *accountsStatements) selectAccountByLocalpart(
var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
@ -309,9 +304,8 @@ func (s *accountsStatements) selectNewNumericLocalpart(
) (id int64, err error) {
// "SELECT COUNT(localpart) FROM account_accounts"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []AccountCosmosUserCount
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -320,8 +314,8 @@ func (s *accountsStatements) selectNewNumericLocalpart(
var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)

View file

@ -85,13 +85,12 @@ func (s *tokenStatements) insertToken(
ExpiresAtMS: expiresAtMS,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
var dbData = OpenIdTokenCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token),
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Token),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(),
OpenIdToken: mapToToken(*result),
}
@ -99,8 +98,8 @@ func (s *tokenStatements) insertToken(
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
@ -120,9 +119,8 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
var openIDTokenAttrs api.OpenIDTokenAttributes
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []OpenIdTokenCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -132,8 +130,8 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)

View file

@ -87,25 +87,25 @@ func (s *profilesStatements) prepare(db *Database) (err error) {
return
}
func getProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*ProfileCosmosData, error) {
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) {
response := ProfileCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
docId,
optionsGet,
&response)
return &response, ex
}
func setProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
func setProfile(s *profilesStatements, ctx context.Context, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
profile.Id,
&profile,
optionsReplace)
@ -121,13 +121,12 @@ func (s *profilesStatements) insertProfile(
Localpart: localpart,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var dbData = ProfileCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(),
Profile: mapToProfile(*result),
}
@ -135,8 +134,8 @@ func (s *profilesStatements) insertProfile(
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
@ -148,9 +147,8 @@ func (s *profilesStatements) selectProfileByLocalpart(
) (*authtypes.Profile, error) {
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ProfileCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -160,8 +158,8 @@ func (s *profilesStatements) selectProfileByLocalpart(
var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
@ -187,19 +185,18 @@ func (s *profilesStatements) setAvatarURL(
) (err error) {
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var response, exGet = getProfile(s, ctx, config, pk, docId)
var response, exGet = getProfile(s, ctx, pk, docId)
if exGet != nil {
return exGet
}
response.Profile.AvatarURL = avatarURL
var _, exReplace = setProfile(s, ctx, config, pk, *response)
var _, exReplace = setProfile(s, ctx, pk, *response)
if exReplace != nil {
return exReplace
}
@ -211,18 +208,17 @@ func (s *profilesStatements) setDisplayName(
) (err error) {
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
var response, exGet = getProfile(s, ctx, config, pk, docId)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
var response, exGet = getProfile(s, ctx, pk, docId)
if exGet != nil {
return exGet
}
response.Profile.DisplayName = displayName
var _, exReplace = setProfile(s, ctx, config, pk, *response)
var _, exReplace = setProfile(s, ctx, pk, *response)
if exReplace != nil {
return exReplace
}
@ -235,9 +231,8 @@ func (s *profilesStatements) selectProfilesBySearch(
var profiles []authtypes.Profile
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ProfileCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -248,8 +243,8 @@ func (s *profilesStatements) selectProfilesBySearch(
var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)

View file

@ -48,20 +48,19 @@ type Database struct {
databaseName string
connection cosmosdbapi.CosmosConnection
cosmosConfig cosmosdbapi.CosmosConfig
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
connString := cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
connMap := cosmosdbutil.GetConnectionProperties(string(connString))
accountEndpoint := connMap["AccountEndpoint"]
accountKey := connMap["AccountKey"]
conn := cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
d := &Database{
serverName: serverName,
databaseName: "userapi",
connection: conn,
cosmosConfig: config,
// db: db,
// writer: sqlutil.NewExclusiveWriter(),
// bcryptCost: bcryptCost,

View file

@ -74,9 +74,8 @@ func (s *threepidStatements) selectLocalpartForThreePID(
) (localpart string, err error) {
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ThreePIDCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -87,8 +86,8 @@ func (s *threepidStatements) selectLocalpartForThreePID(
var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
@ -109,9 +108,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
) (threepids []authtypes.ThreePID, err error) {
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response := []ThreePIDCosmosData{}
params := map[string]interface{}{
"@x1": dbCollectionName,
@ -121,8 +119,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
@ -156,14 +154,14 @@ func (s *threepidStatements) insertThreePID(
ThreePID: threepid,
}
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
id := fmt.Sprintf("%s_%s", threepid, medium)
docId := fmt.Sprintf("%s_%s", threepid, medium)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = ThreePIDCosmosData{
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(),
ThreePID: result,
}
@ -171,8 +169,8 @@ func (s *threepidStatements) insertThreePID(
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
config.DatabaseName,
config.TenantName,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
@ -186,16 +184,16 @@ func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) {
// "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
var config = cosmosdbapi.DefaultConfig()
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
id := fmt.Sprintf("%s_%s", threepid, medium)
pk := cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", threepid, medium)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
_, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
config.DatabaseName,
config.TenantName,
id,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
cosmosDocId,
options)
if err != nil {

View file

@ -16,127 +16,145 @@ package cosmosdb
import (
"context"
"database/sql"
"strings"
"errors"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
)
const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
// const devicesSchema = `
// -- This sequence is used for automatic allocation of session_id.
// -- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
// -- Stores data about devices.
// CREATE TABLE IF NOT EXISTS device_devices (
// access_token TEXT PRIMARY KEY,
// session_id INTEGER,
// device_id TEXT ,
// localpart TEXT ,
// created_ts BIGINT,
// display_name TEXT,
// last_seen_ts BIGINT,
// ip TEXT,
// user_agent TEXT,
UNIQUE (localpart, device_id)
);
`
// UNIQUE (localpart, device_id)
// );
// `
const insertDeviceSQL = "" +
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
type DeviceCosmos struct {
ID string `json:"device_id"`
UserID string `json:"user_id"`
// The access_token granted to this device.
// This uniquely identifies the device from all other devices and clients.
AccessToken string `json:"access_token"`
// The unique ID of the session identified by the access token.
// Can be used as a secure substitution in places where data needs to be
// associated with access tokens.
SessionID int64 `json:"session_id"`
DisplayName string `json:"display_name"`
LastSeenTS int64 `json:"last_seen_ts"`
LastSeenIP string `json:"last_seen_ip"`
Localpart string `json:"local_part"`
UserAgent string `json:"user_agent"`
// If the device is for an appservice user,
// this is the appservice ID.
AppserviceID string `json:"app_service_id"`
}
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices"
type DeviceCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Device DeviceCosmos `json:"mx_userapi_device"`
}
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
type DeviceCosmosSessionCount struct {
SessionCount int64 `json:"sessioncount"`
}
type devicesStatements struct {
db *sql.DB
writer sqlutil.Writer
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
updateDeviceLastSeenStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
db *Database
selectDevicesCountStmt string
selectDeviceByTokenStmt string
// selectDeviceByIDStmt *sql.Stmt
selectDevicesByIDStmt string
selectDevicesByLocalpartStmt string
selectDevicesByLocalpartExceptIDStmt string
serverName gomatrixserverlib.ServerName
tableName string
}
func (s *devicesStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(devicesSchema)
return err
func mapFromDevice(db DeviceCosmos) api.Device {
return api.Device{
AccessToken: db.AccessToken,
AppserviceID: db.AppserviceID,
ID: db.ID,
LastSeenIP: db.LastSeenIP,
LastSeenTS: db.LastSeenTS,
SessionID: db.SessionID,
UserAgent: db.UserAgent,
UserID: db.UserID,
}
}
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos {
localPart, _ := userutil.ParseUsernameParam(api.UserID, &s.serverName)
return DeviceCosmos{
AccessToken: api.AccessToken,
AppserviceID: api.AppserviceID,
ID: api.ID,
LastSeenIP: api.LastSeenIP,
LastSeenTS: api.LastSeenTS,
Localpart: localPart,
SessionID: api.SessionID,
UserAgent: api.UserAgent,
UserID: api.UserID,
}
}
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
response := DeviceCosmosData{}
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
docId,
optionsGet,
&response)
return &response, ex
}
func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, device.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
device.Id,
&device,
optionsReplace)
return &device, ex
}
func (s *devicesStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = writer
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.selectDevicesCountStmt = "select count(c._ts) as sessioncount from c where c._cn = @x1"
s.selectDevicesByLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and ARRAY_CONTAINS(@x3, c.mx_userapi_device.device_id)"
s.selectDevicesByLocalpartExceptIDStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and c.mx_userapi_device.device_id != @x3"
s.selectDeviceByTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.access_token = @x2"
s.selectDevicesByIDStmt = "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_userapi_device.device_id)"
s.serverName = server
s.tableName = "device_devices"
return
}
@ -144,85 +162,219 @@ func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
ctx context.Context, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
// "SELECT COUNT(access_token) FROM device_devices"
// HACK: Do we need a Cosmos Table for the sequence?
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosSessionCount
params := map[string]interface{}{
"@x1": dbCollectionName,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesCountStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
sessionID = response[0].SessionCount
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
return &api.Device{
// "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
// " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
data := DeviceCosmos{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
SessionID: sessionID,
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
Localpart: localpart,
UserAgent: userAgent,
}, nil
}
// access_token TEXT PRIMARY KEY,
// UNIQUE (localpart, device_id)
// HACK: check for duplicate PK as we are using the UNIQUE key for the DocId
docId := fmt.Sprintf("%s_%s", localpart, id)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = DeviceCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
Timestamp: time.Now().Unix(),
Device: data,
}
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
var _, _, errCreate = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
optionsCreate)
if errCreate != nil {
return nil, errCreate
}
var result = mapFromDevice(dbData.Device)
return &result, nil
}
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
ctx context.Context, id, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
// "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
docId := fmt.Sprintf("%s_%s", localpart, id)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
cosmosDocId,
options)
if err != nil {
return err
}
return err
}
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
ctx context.Context, localpart string, devices []string,
) error {
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
prep, err := s.db.Prepare(orig)
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": devices,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return err
}
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
for _, item := range response {
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
ctx context.Context, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
exceptDevices := []string{
exceptDeviceID,
}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": exceptDevices,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return err
}
for _, item := range response {
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
}
return err
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
// "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil {
return exGet
}
response.Device.DisplayName = *displayName
var _, exReplace = setDevice(s, ctx, pk, *response)
if exReplace != nil {
return exReplace
}
return exReplace
}
func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*api.Device, error) {
var dev api.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.AccessToken = accessToken
// "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": accessToken,
}
return &dev, err
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDeviceByTokenStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
if len(response) == 0 {
return nil, errors.New(fmt.Sprintf("No Devices found with AccessToken %s", accessToken))
}
if err == nil {
result := mapFromDevice(response[0].Device)
return &result, nil
}
return nil, err
}
// selectDeviceByID retrieves a device from the database with the given user
@ -230,54 +382,48 @@ func (s *devicesStatements) selectDeviceByToken(
func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName sql.NullString
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
// "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil {
return nil, exGet
}
return &dev, err
result := mapFromDevice(response.Device)
return &result, nil
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
ctx context.Context, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err
// "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": localpart,
"@x3": exceptDeviceID,
}
for rows.Next() {
var dev api.Device
var lastseents sql.NullInt64
var id, displayname, ip, useragent sql.NullString
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
if err != nil {
return devices, err
}
if id.Valid {
dev.ID = id.String
}
if displayname.Valid {
dev.DisplayName = displayname.String
}
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
if ip.Valid {
dev.LastSeenIP = ip.String
}
if useragent.Valid {
dev.UserAgent = useragent.String
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartExceptIDStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
for _, item := range response {
dev := mapFromDevice(item.Device)
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
@ -286,37 +432,53 @@ func (s *devicesStatements) selectDevicesByLocalpart(
}
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
iDeviceIDs := make([]interface{}, len(deviceIDs))
for i := range deviceIDs {
iDeviceIDs[i] = deviceIDs[i]
// "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
var devices []api.Device
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": deviceIDs,
}
rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesByIDStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
for rows.Next() {
var dev api.Device
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
for _, item := range response {
dev := mapFromDevice(item.Device)
devices = append(devices, dev)
}
return devices, rows.Err()
return devices, nil
}
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
return err
// "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
if exGet != nil {
return exGet
}
response.Device.LastSeenTS = lastSeenTs
var _, exReplace = setDevice(s, ctx, pk, *response)
if exReplace != nil {
return exReplace
}
return exReplace
}

View file

@ -15,16 +15,18 @@
package cosmosdb
import (
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
@ -35,35 +37,32 @@ var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
writer sqlutil.Writer
devices devicesStatements
writer sqlutil.Writer
devices devicesStatements
connection cosmosdbapi.CosmosConnection
databaseName string
cosmosConfig cosmosdbapi.CosmosConfig
serverName gomatrixserverlib.ServerName
}
// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) {
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
writer := sqlutil.NewExclusiveWriter()
d := devicesStatements{}
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
devices := devicesStatements{}
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
if err = d.execSchema(db); err != nil {
return nil, err
d := &Database{
databaseName: "userapi",
devices: devices,
serverName: serverName,
connection: conn,
cosmosConfig: config,
}
m := sqlutil.NewMigrations()
deltas.LoadLastSeenTSIP(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
return &Database{db, writer, d}, nil
err := d.devices.prepare(d, serverName)
return d, err
}
// GetDeviceByAccessToken returns the device matching the given access token.
@ -86,7 +85,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
return d.devices.selectDevicesByLocalpart(ctx, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -104,16 +103,14 @@ func (d *Database) CreateDevice(
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, *deviceID, localpart); err != nil {
return nil, err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
dev, err = d.devices.insertDevice(ctx, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return dev, err
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
@ -124,11 +121,9 @@ func (d *Database) CreateDevice(
return
}
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
var err error
dev, err = d.devices.insertDevice(ctx, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return dev, err
if returnErr == nil {
return
}
@ -154,9 +149,7 @@ func generateDeviceID() (string, error) {
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
return d.devices.updateDeviceName(ctx, localpart, deviceID, displayName)
}
// RemoveDevice revokes a device by deleting the entry in the database
@ -166,12 +159,10 @@ func (d *Database) UpdateDevice(
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
if err := d.devices.deleteDevice(ctx, deviceID, localpart); err != nil {
return err
}
return nil
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
@ -181,12 +172,10 @@ func (d *Database) RemoveDevice(
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
if err := d.devices.deleteDevices(ctx, localpart, devices); err != nil {
return err
}
return nil
}
// RemoveAllDevices revokes devices by deleting the entry in the
@ -195,22 +184,17 @@ func (d *Database) RemoveDevices(
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
devices, err = d.devices.selectDevicesByLocalpart(ctx, localpart, exceptDeviceID)
if err != nil {
return nil, err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, localpart, exceptDeviceID); err != nil {
return nil, err
}
return devices, nil
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
return d.devices.updateDeviceLastSeen(ctx, localpart, deviceID, ipAddr)
}