From b696923333c14615d9011ef4dc858d1056b5f0aa Mon Sep 17 00:00:00 2001 From: alexfca <75228224+alexfca@users.noreply.github.com> Date: Wed, 12 May 2021 16:30:49 +1000 Subject: [PATCH] - Implement Cosmos for the devices_table (#4) - Use the ConnectionString in the YAML to include the Tenant - Revert all other non implemented tables back to use SQLLite3 --- appservice/storage/cosmosdb/storage.go | 2 - dendrite-config-cosmosdb.yaml | 4 +- federationsender/storage/cosmosdb/storage.go | 2 - internal/cosmosdbapi/cosmosconfig.go | 6 + internal/cosmosdbapi/tenant.go | 14 - internal/cosmosdbutil/connection.go | 40 +- keyserver/storage/cosmosdb/storage.go | 2 - mediaapi/storage/cosmosdb/storage.go | 2 - roomserver/storage/cosmosdb/storage.go | 2 - setup/kafka/kafka.go | 6 +- signingkeyserver/storage/cosmosdb/keydb.go | 2 - syncapi/storage/cosmosdb/syncserver.go | 2 - .../accounts/cosmosdb/account_data_table.go | 23 +- .../accounts/cosmosdb/accounts_table.go | 60 +- .../storage/accounts/cosmosdb/openid_table.go | 16 +- .../accounts/cosmosdb/profile_table.go | 53 +- userapi/storage/accounts/cosmosdb/storage.go | 9 +- .../accounts/cosmosdb/threepid_table.go | 38 +- .../storage/devices/cosmosdb/devices_table.go | 558 +++++++++++------- userapi/storage/devices/cosmosdb/storage.go | 120 ++-- 20 files changed, 547 insertions(+), 414 deletions(-) create mode 100644 internal/cosmosdbapi/cosmosconfig.go delete mode 100644 internal/cosmosdbapi/tenant.go diff --git a/appservice/storage/cosmosdb/storage.go b/appservice/storage/cosmosdb/storage.go index 2f07167b9..3639010e1 100644 --- a/appservice/storage/cosmosdb/storage.go +++ b/appservice/storage/cosmosdb/storage.go @@ -16,7 +16,6 @@ package cosmosdb import ( - "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "database/sql" @@ -38,7 +37,6 @@ type Database struct { // NewDatabase opens a new database func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { - dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) var result Database var err error if result.db, err = sqlutil.Open(dbProperties); err != nil { diff --git a/dendrite-config-cosmosdb.yaml b/dendrite-config-cosmosdb.yaml index f3be7dcad..189abe766 100644 --- a/dendrite-config-cosmosdb.yaml +++ b/dendrite-config-cosmosdb.yaml @@ -354,12 +354,12 @@ user_api: listen: http://localhost:7781 connect: http://localhost:7781 account_database: - connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" + connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;" max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 device_database: - connection_string: file:userapi_devices.db + connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;" max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 diff --git a/federationsender/storage/cosmosdb/storage.go b/federationsender/storage/cosmosdb/storage.go index fb38d6e6d..da429046b 100644 --- a/federationsender/storage/cosmosdb/storage.go +++ b/federationsender/storage/cosmosdb/storage.go @@ -16,7 +16,6 @@ package cosmosdb import ( - "github.com/matrix-org/dendrite/internal/cosmosdbutil" "database/sql" _ "github.com/mattn/go-sqlite3" @@ -38,7 +37,6 @@ type Database struct { // NewDatabase opens a new database func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) { - dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) var d Database var err error if d.db, err = sqlutil.Open(dbProperties); err != nil { diff --git a/internal/cosmosdbapi/cosmosconfig.go b/internal/cosmosdbapi/cosmosconfig.go new file mode 100644 index 000000000..b6b715303 --- /dev/null +++ b/internal/cosmosdbapi/cosmosconfig.go @@ -0,0 +1,6 @@ +package cosmosdbapi + +type CosmosConfig struct { + DatabaseName string + ContainerName string +} diff --git a/internal/cosmosdbapi/tenant.go b/internal/cosmosdbapi/tenant.go deleted file mode 100644 index d9cb825f2..000000000 --- a/internal/cosmosdbapi/tenant.go +++ /dev/null @@ -1,14 +0,0 @@ -package cosmosdbapi - -type Tenant struct { - DatabaseName string - TenantName string -} - -//TODO: Move into Config or the JWT -func DefaultConfig() Tenant { - return Tenant{ - DatabaseName: "safezone_local", - TenantName: "criticalarc.com", - } -} diff --git a/internal/cosmosdbutil/connection.go b/internal/cosmosdbutil/connection.go index d8b6527a9..5767201b5 100644 --- a/internal/cosmosdbutil/connection.go +++ b/internal/cosmosdbutil/connection.go @@ -1,22 +1,50 @@ package cosmosdbutil import ( - "github.com/matrix-org/dendrite/setup/config" "strings" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/setup/config" ) -func GetConnectionString(d *config.DataSource) config.DataSource { +const accountEndpointName = "AccountEndpoint" +const accountKeyName = "AccountKey" +const databaseName = "DatabaseName" +const containerName = "ContainerName" + +func getConnectionString(d *config.DataSource) config.DataSource { var connString string connString = string(*d) return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1)) } -func GetConnectionProperties(connectionString string) map[string]string { +func getConnectionProperties(connectionString string) map[string]string { connectionItemsRaw := strings.Split(connectionString, ";") connectionItems := map[string]string{} for _, item := range connectionItemsRaw { - itemSplit := strings.SplitN(item, "=", 2) - connectionItems[itemSplit[0]] = itemSplit[1] + if len(item) > 0 { + itemSplit := strings.SplitN(item, "=", 2) + connectionItems[itemSplit[0]] = itemSplit[1] + } } return connectionItems -} \ No newline at end of file +} + +func GetCosmosConnection(d *config.DataSource) cosmosdbapi.CosmosConnection { + connString := getConnectionString(d) + connMap := getConnectionProperties(string(connString)) + accountEndpoint := connMap[accountEndpointName] + accountKey := connMap[accountKeyName] + return cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey) +} + +func GetCosmosConfig(d *config.DataSource) cosmosdbapi.CosmosConfig { + connString := getConnectionString(d) + connMap := getConnectionProperties(string(connString)) + database := connMap[databaseName] + container := connMap[containerName] + return cosmosdbapi.CosmosConfig{ + DatabaseName: database, + ContainerName: container, + } +} diff --git a/keyserver/storage/cosmosdb/storage.go b/keyserver/storage/cosmosdb/storage.go index c4a0c0c97..ba000cb24 100644 --- a/keyserver/storage/cosmosdb/storage.go +++ b/keyserver/storage/cosmosdb/storage.go @@ -15,14 +15,12 @@ package cosmosdb import ( - "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/setup/config" ) func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { - dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err diff --git a/mediaapi/storage/cosmosdb/storage.go b/mediaapi/storage/cosmosdb/storage.go index 43b2879df..b05373868 100644 --- a/mediaapi/storage/cosmosdb/storage.go +++ b/mediaapi/storage/cosmosdb/storage.go @@ -16,7 +16,6 @@ package cosmosdb import ( - "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "database/sql" @@ -37,7 +36,6 @@ type Database struct { // Open opens a postgres database. func Open(dbProperties *config.DatabaseOptions) (*Database, error) { - dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) d := Database{ writer: sqlutil.NewExclusiveWriter(), } diff --git a/roomserver/storage/cosmosdb/storage.go b/roomserver/storage/cosmosdb/storage.go index aa712d07d..bb3f6af2e 100644 --- a/roomserver/storage/cosmosdb/storage.go +++ b/roomserver/storage/cosmosdb/storage.go @@ -16,7 +16,6 @@ package cosmosdb import ( - "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "database/sql" @@ -38,7 +37,6 @@ type Database struct { // Open a sqlite database. func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { - dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) var d Database var db *sql.DB var err error diff --git a/setup/kafka/kafka.go b/setup/kafka/kafka.go index 431da23b6..936115a37 100644 --- a/setup/kafka/kafka.go +++ b/setup/kafka/kafka.go @@ -1,7 +1,6 @@ package kafka import ( - "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/naffka" @@ -47,8 +46,9 @@ func setupNaffka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) { if naffkaInstance != nil { return naffkaInstance, naffkaInstance } - if(cfg.Database.ConnectionString.IsCosmosDB()) { - cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString) + if cfg.Database.ConnectionString.IsCosmosDB() { + //TODO: What do we do for Nafka + // cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString) } naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString)) diff --git a/signingkeyserver/storage/cosmosdb/keydb.go b/signingkeyserver/storage/cosmosdb/keydb.go index 46c95d88a..0f4371bce 100644 --- a/signingkeyserver/storage/cosmosdb/keydb.go +++ b/signingkeyserver/storage/cosmosdb/keydb.go @@ -16,7 +16,6 @@ package cosmosdb import ( - "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "golang.org/x/crypto/ed25519" @@ -45,7 +44,6 @@ func NewDatabase( serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, ) (*Database, error) { - dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err diff --git a/syncapi/storage/cosmosdb/syncserver.go b/syncapi/storage/cosmosdb/syncserver.go index 719c8fdad..7bf1a1387 100644 --- a/syncapi/storage/cosmosdb/syncserver.go +++ b/syncapi/storage/cosmosdb/syncserver.go @@ -16,7 +16,6 @@ package cosmosdb import ( - "github.com/matrix-org/dendrite/internal/cosmosdbutil" "database/sql" // Import the sqlite3 package @@ -41,7 +40,6 @@ type SyncServerDatasource struct { // NewDatabase creates a new sync server database // nolint: gocyclo func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { - dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) var d SyncServerDatasource var err error if d.db, err = sqlutil.Open(dbProperties); err != nil { diff --git a/userapi/storage/accounts/cosmosdb/account_data_table.go b/userapi/storage/accounts/cosmosdb/account_data_table.go index 67a406435..6a471d07a 100644 --- a/userapi/storage/accounts/cosmosdb/account_data_table.go +++ b/userapi/storage/accounts/cosmosdb/account_data_table.go @@ -84,7 +84,6 @@ func (s *accountDataStatements) insertAccountData( Content: content, } - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) id := "" if roomID == "" { @@ -94,9 +93,9 @@ func (s *accountDataStatements) insertAccountData( } var dbData = AccountDataCosmosData{ - Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id), + Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, id), Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Timestamp: time.Now().Unix(), AccountData: result, } @@ -104,8 +103,8 @@ func (s *accountDataStatements) insertAccountData( var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, dbData, options) @@ -120,9 +119,8 @@ func (s *accountDataStatements) selectAccountData( error, ) { // "SELECT room_id, type, content FROM account_data WHERE localpart = $1" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) response := []AccountDataCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, @@ -132,8 +130,8 @@ func (s *accountDataStatements) selectAccountData( var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, query, &response, options) @@ -167,9 +165,8 @@ func (s *accountDataStatements) selectAccountDataByType( var bytes []byte // "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) response := []AccountDataCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, @@ -181,8 +178,8 @@ func (s *accountDataStatements) selectAccountDataByType( var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, query, &response, options) diff --git a/userapi/storage/accounts/cosmosdb/accounts_table.go b/userapi/storage/accounts/cosmosdb/accounts_table.go index 1f8dc866f..d20e01af3 100644 --- a/userapi/storage/accounts/cosmosdb/accounts_table.go +++ b/userapi/storage/accounts/cosmosdb/accounts_table.go @@ -87,26 +87,26 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv return } -func getAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*AccountCosmosData, error) { +func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) { response := AccountCosmosData{} var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, docId, optionsGet, &response) return &response, ex } -func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, account AccountCosmosData) (*AccountCosmosData, error) { +func setAccount(s *accountsStatements, ctx context.Context, pk string, account AccountCosmosData) (*AccountCosmosData, error) { response := AccountCosmosData{} var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, account.Id, &account, optionsReplace) @@ -153,13 +153,12 @@ func (s *accountsStatements) insertAccount( data.PasswordHash = hash data.IsDeactivated = false - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbData = AccountCosmosData{ - Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart), + Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart), Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Timestamp: time.Now().Unix(), Account: data, } @@ -167,8 +166,8 @@ func (s *accountsStatements) insertAccount( var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, dbData, options) @@ -184,19 +183,18 @@ func (s *accountsStatements) updatePassword( ) (err error) { // "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - var response, exGet = getAccount(s, ctx, config, pk, docId) + var response, exGet = getAccount(s, ctx, pk, docId) if exGet != nil { return exGet } response.Account.PasswordHash = passwordHash - var _, exReplace = setAccount(s, ctx, config, pk, *response) + var _, exReplace = setAccount(s, ctx, pk, *response) if exReplace != nil { return exReplace } @@ -208,19 +206,18 @@ func (s *accountsStatements) deactivateAccount( ) (err error) { // "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) - var response, exGet = getAccount(s, ctx, config, pk, docId) + var response, exGet = getAccount(s, ctx, pk, docId) if exGet != nil { return exGet } response.Account.IsDeactivated = true - var _, exReplace = setAccount(s, ctx, config, pk, *response) + var _, exReplace = setAccount(s, ctx, pk, *response) if exReplace != nil { return exReplace } @@ -232,9 +229,8 @@ func (s *accountsStatements) selectPasswordHash( ) (hash string, err error) { // "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) response := []AccountCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, @@ -244,8 +240,8 @@ func (s *accountsStatements) selectPasswordHash( var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, query, &response, options) @@ -271,9 +267,8 @@ func (s *accountsStatements) selectAccountByLocalpart( var acc api.Account // "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) response := []AccountCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, @@ -283,8 +278,8 @@ func (s *accountsStatements) selectAccountByLocalpart( var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, query, &response, options) @@ -309,9 +304,8 @@ func (s *accountsStatements) selectNewNumericLocalpart( ) (id int64, err error) { // "SELECT COUNT(localpart) FROM account_accounts" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) var response []AccountCosmosUserCount params := map[string]interface{}{ "@x1": dbCollectionName, @@ -320,8 +314,8 @@ func (s *accountsStatements) selectNewNumericLocalpart( var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, query, &response, options) diff --git a/userapi/storage/accounts/cosmosdb/openid_table.go b/userapi/storage/accounts/cosmosdb/openid_table.go index 3b62d244b..2567b8857 100644 --- a/userapi/storage/accounts/cosmosdb/openid_table.go +++ b/userapi/storage/accounts/cosmosdb/openid_table.go @@ -85,13 +85,12 @@ func (s *tokenStatements) insertToken( ExpiresAtMS: expiresAtMS, } - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) var dbData = OpenIdTokenCosmosData{ - Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token), + Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Token), Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Timestamp: time.Now().Unix(), OpenIdToken: mapToToken(*result), } @@ -99,8 +98,8 @@ func (s *tokenStatements) insertToken( var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, dbData, options) @@ -120,9 +119,8 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes( var openIDTokenAttrs api.OpenIDTokenAttributes // "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) response := []OpenIdTokenCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, @@ -132,8 +130,8 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes( var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, query, &response, options) diff --git a/userapi/storage/accounts/cosmosdb/profile_table.go b/userapi/storage/accounts/cosmosdb/profile_table.go index 9b47acc09..bb02f4867 100644 --- a/userapi/storage/accounts/cosmosdb/profile_table.go +++ b/userapi/storage/accounts/cosmosdb/profile_table.go @@ -87,25 +87,25 @@ func (s *profilesStatements) prepare(db *Database) (err error) { return } -func getProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*ProfileCosmosData, error) { +func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) { response := ProfileCosmosData{} var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, docId, optionsGet, &response) return &response, ex } -func setProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) { +func setProfile(s *profilesStatements, ctx context.Context, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, profile.Id, &profile, optionsReplace) @@ -121,13 +121,12 @@ func (s *profilesStatements) insertProfile( Localpart: localpart, } - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var dbData = ProfileCosmosData{ - Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart), + Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart), Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Timestamp: time.Now().Unix(), Profile: mapToProfile(*result), } @@ -135,8 +134,8 @@ func (s *profilesStatements) insertProfile( var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, dbData, options) @@ -148,9 +147,8 @@ func (s *profilesStatements) selectProfileByLocalpart( ) (*authtypes.Profile, error) { // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) response := []ProfileCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, @@ -160,8 +158,8 @@ func (s *profilesStatements) selectProfileByLocalpart( var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, query, &response, options) @@ -187,19 +185,18 @@ func (s *profilesStatements) setAvatarURL( ) (err error) { // "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) - var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) - var response, exGet = getProfile(s, ctx, config, pk, docId) + var response, exGet = getProfile(s, ctx, pk, docId) if exGet != nil { return exGet } response.Profile.AvatarURL = avatarURL - var _, exReplace = setProfile(s, ctx, config, pk, *response) + var _, exReplace = setProfile(s, ctx, pk, *response) if exReplace != nil { return exReplace } @@ -211,18 +208,17 @@ func (s *profilesStatements) setDisplayName( ) (err error) { // "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) - var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) - var response, exGet = getProfile(s, ctx, config, pk, docId) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart) + var response, exGet = getProfile(s, ctx, pk, docId) if exGet != nil { return exGet } response.Profile.DisplayName = displayName - var _, exReplace = setProfile(s, ctx, config, pk, *response) + var _, exReplace = setProfile(s, ctx, pk, *response) if exReplace != nil { return exReplace } @@ -235,9 +231,8 @@ func (s *profilesStatements) selectProfilesBySearch( var profiles []authtypes.Profile // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) response := []ProfileCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, @@ -248,8 +243,8 @@ func (s *profilesStatements) selectProfilesBySearch( var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, query, &response, options) diff --git a/userapi/storage/accounts/cosmosdb/storage.go b/userapi/storage/accounts/cosmosdb/storage.go index 3fc2e7d88..20c4d2071 100644 --- a/userapi/storage/accounts/cosmosdb/storage.go +++ b/userapi/storage/accounts/cosmosdb/storage.go @@ -48,20 +48,19 @@ type Database struct { databaseName string connection cosmosdbapi.CosmosConnection + cosmosConfig cosmosdbapi.CosmosConfig } // NewDatabase creates a new accounts and profiles database func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { - connString := cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) - connMap := cosmosdbutil.GetConnectionProperties(string(connString)) - accountEndpoint := connMap["AccountEndpoint"] - accountKey := connMap["AccountKey"] - conn := cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey) + conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString) + config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) d := &Database{ serverName: serverName, databaseName: "userapi", connection: conn, + cosmosConfig: config, // db: db, // writer: sqlutil.NewExclusiveWriter(), // bcryptCost: bcryptCost, diff --git a/userapi/storage/accounts/cosmosdb/threepid_table.go b/userapi/storage/accounts/cosmosdb/threepid_table.go index 94bd67ce3..b8bf12263 100644 --- a/userapi/storage/accounts/cosmosdb/threepid_table.go +++ b/userapi/storage/accounts/cosmosdb/threepid_table.go @@ -74,9 +74,8 @@ func (s *threepidStatements) selectLocalpartForThreePID( ) (localpart string, err error) { // "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) response := []ThreePIDCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, @@ -87,8 +86,8 @@ func (s *threepidStatements) selectLocalpartForThreePID( var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, query, &response, options) @@ -109,9 +108,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( ) (threepids []authtypes.ThreePID, err error) { // "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) - var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) response := []ThreePIDCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, @@ -121,8 +119,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, query, &response, options) @@ -156,14 +154,14 @@ func (s *threepidStatements) insertThreePID( ThreePID: threepid, } - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - id := fmt.Sprintf("%s_%s", threepid, medium) + docId := fmt.Sprintf("%s_%s", threepid, medium) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) var dbData = ThreePIDCosmosData{ - Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id), + Id: cosmosDocId, Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), Timestamp: time.Now().Unix(), ThreePID: result, } @@ -171,8 +169,8 @@ func (s *threepidStatements) insertThreePID( var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - config.DatabaseName, - config.TenantName, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, dbData, options) @@ -186,16 +184,16 @@ func (s *threepidStatements) deleteThreePID( ctx context.Context, threepid string, medium string) (err error) { // "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" - var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - id := fmt.Sprintf("%s_%s", threepid, medium) - pk := cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + docId := fmt.Sprintf("%s_%s", threepid, medium) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) var options = cosmosdbapi.GetDeleteDocumentOptions(pk) _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, - config.DatabaseName, - config.TenantName, - id, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + cosmosDocId, options) if err != nil { diff --git a/userapi/storage/devices/cosmosdb/devices_table.go b/userapi/storage/devices/cosmosdb/devices_table.go index f52e76507..d968c6208 100644 --- a/userapi/storage/devices/cosmosdb/devices_table.go +++ b/userapi/storage/devices/cosmosdb/devices_table.go @@ -16,127 +16,145 @@ package cosmosdb import ( "context" - "database/sql" - "strings" + "errors" + "fmt" "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/gomatrixserverlib" ) -const devicesSchema = ` --- This sequence is used for automatic allocation of session_id. --- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; +// const devicesSchema = ` +// -- This sequence is used for automatic allocation of session_id. +// -- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; --- Stores data about devices. -CREATE TABLE IF NOT EXISTS device_devices ( - access_token TEXT PRIMARY KEY, - session_id INTEGER, - device_id TEXT , - localpart TEXT , - created_ts BIGINT, - display_name TEXT, - last_seen_ts BIGINT, - ip TEXT, - user_agent TEXT, +// -- Stores data about devices. +// CREATE TABLE IF NOT EXISTS device_devices ( +// access_token TEXT PRIMARY KEY, +// session_id INTEGER, +// device_id TEXT , +// localpart TEXT , +// created_ts BIGINT, +// display_name TEXT, +// last_seen_ts BIGINT, +// ip TEXT, +// user_agent TEXT, - UNIQUE (localpart, device_id) -); -` +// UNIQUE (localpart, device_id) +// ); +// ` -const insertDeviceSQL = "" + - "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +type DeviceCosmos struct { + ID string `json:"device_id"` + UserID string `json:"user_id"` + // The access_token granted to this device. + // This uniquely identifies the device from all other devices and clients. + AccessToken string `json:"access_token"` + // The unique ID of the session identified by the access token. + // Can be used as a secure substitution in places where data needs to be + // associated with access tokens. + SessionID int64 `json:"session_id"` + DisplayName string `json:"display_name"` + LastSeenTS int64 `json:"last_seen_ts"` + LastSeenIP string `json:"last_seen_ip"` + Localpart string `json:"local_part"` + UserAgent string `json:"user_agent"` + // If the device is for an appservice user, + // this is the appservice ID. + AppserviceID string `json:"app_service_id"` +} -const selectDevicesCountSQL = "" + - "SELECT COUNT(access_token) FROM device_devices" +type DeviceCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Device DeviceCosmos `json:"mx_userapi_device"` +} -const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" - -const selectDeviceByIDSQL = "" + - "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" - -const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" - -const updateDeviceNameSQL = "" + - "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" - -const deleteDeviceSQL = "" + - "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" - -const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" - -const deleteDevicesSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" - -const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" - -const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" +type DeviceCosmosSessionCount struct { + SessionCount int64 `json:"sessioncount"` +} type devicesStatements struct { - db *sql.DB - writer sqlutil.Writer - insertDeviceStmt *sql.Stmt - selectDevicesCountStmt *sql.Stmt - selectDeviceByTokenStmt *sql.Stmt - selectDeviceByIDStmt *sql.Stmt - selectDevicesByIDStmt *sql.Stmt - selectDevicesByLocalpartStmt *sql.Stmt - updateDeviceNameStmt *sql.Stmt - updateDeviceLastSeenStmt *sql.Stmt - deleteDeviceStmt *sql.Stmt - deleteDevicesByLocalpartStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + db *Database + selectDevicesCountStmt string + selectDeviceByTokenStmt string + // selectDeviceByIDStmt *sql.Stmt + selectDevicesByIDStmt string + selectDevicesByLocalpartStmt string + selectDevicesByLocalpartExceptIDStmt string + serverName gomatrixserverlib.ServerName + tableName string } -func (s *devicesStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(devicesSchema) - return err +func mapFromDevice(db DeviceCosmos) api.Device { + return api.Device{ + AccessToken: db.AccessToken, + AppserviceID: db.AppserviceID, + ID: db.ID, + LastSeenIP: db.LastSeenIP, + LastSeenTS: db.LastSeenTS, + SessionID: db.SessionID, + UserAgent: db.UserAgent, + UserID: db.UserID, + } } -func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { +func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos { + localPart, _ := userutil.ParseUsernameParam(api.UserID, &s.serverName) + return DeviceCosmos{ + AccessToken: api.AccessToken, + AppserviceID: api.AppserviceID, + ID: api.ID, + LastSeenIP: api.LastSeenIP, + LastSeenTS: api.LastSeenTS, + Localpart: localPart, + SessionID: api.SessionID, + UserAgent: api.UserAgent, + UserID: api.UserID, + } +} + +func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) { + response := DeviceCosmosData{} + var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) + var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + docId, + optionsGet, + &response) + return &response, ex +} + +func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, device.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + device.Id, + &device, + optionsReplace) + return &device, ex +} + +func (s *devicesStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db - s.writer = writer - if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { - return - } - if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { - return - } - if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { - return - } - if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { - return - } - if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { - return - } - if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { - return - } - if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { - return - } - if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { - return - } - if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { - return - } - if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { - return - } + s.selectDevicesCountStmt = "select count(c._ts) as sessioncount from c where c._cn = @x1" + s.selectDevicesByLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and ARRAY_CONTAINS(@x3, c.mx_userapi_device.device_id)" + s.selectDevicesByLocalpartExceptIDStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and c.mx_userapi_device.device_id != @x3" + s.selectDeviceByTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.access_token = @x2" + s.selectDevicesByIDStmt = "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_userapi_device.device_id)" s.serverName = server + s.tableName = "device_devices" return } @@ -144,85 +162,219 @@ func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) insertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, + ctx context.Context, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) - insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { + // "SELECT COUNT(access_token) FROM device_devices" + // HACK: Do we need a Cosmos Table for the sequence? + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []DeviceCosmosSessionCount + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(s.selectDevicesCountStmt, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } + sessionID = response[0].SessionCount sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { - return nil, err - } - return &api.Device{ + // "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + + // " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + + data := DeviceCosmos{ ID: id, UserID: userutil.MakeUserID(localpart, s.serverName), AccessToken: accessToken, SessionID: sessionID, LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, + Localpart: localpart, UserAgent: userAgent, - }, nil + } + + // access_token TEXT PRIMARY KEY, + // UNIQUE (localpart, device_id) + // HACK: check for duplicate PK as we are using the UNIQUE key for the DocId + docId := fmt.Sprintf("%s_%s", localpart, id) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + var dbData = DeviceCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName), + Timestamp: time.Now().Unix(), + Device: data, + } + + var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + var _, _, errCreate = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + optionsCreate) + + if errCreate != nil { + return nil, errCreate + } + + var result = mapFromDevice(dbData.Device) + return &result, nil } func (s *devicesStatements) deleteDevice( - ctx context.Context, txn *sql.Tx, id, localpart string, + ctx context.Context, id, localpart string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) + // "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) + docId := fmt.Sprintf("%s_%s", localpart, id) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var options = cosmosdbapi.GetDeleteDocumentOptions(pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + cosmosDocId, + options) + + if err != nil { + return err + } return err } func (s *devicesStatements) deleteDevices( - ctx context.Context, txn *sql.Tx, localpart string, devices []string, + ctx context.Context, localpart string, devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) - prep, err := s.db.Prepare(orig) + // "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []DeviceCosmosData + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": localpart, + "@x3": devices, + } + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) if err != nil { return err } - stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) - params[0] = localpart - for i, v := range devices { - params[i+1] = v + for _, item := range response { + s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart) } - _, err = stmt.ExecContext(ctx, params...) return err } func (s *devicesStatements) deleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, localpart, exceptDeviceID string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + // "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []DeviceCosmosData + exceptDevices := []string{ + exceptDeviceID, + } + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": localpart, + "@x3": exceptDevices, + } + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { + return err + } + for _, item := range response { + s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart) + } return err } func (s *devicesStatements) updateDeviceName( - ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, + ctx context.Context, localpart, deviceID string, displayName *string, ) error { - stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) - return err + // "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + docId := fmt.Sprintf("%s_%s", localpart, deviceID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + var response, exGet = getDevice(s, ctx, pk, cosmosDocId) + if exGet != nil { + return exGet + } + + response.Device.DisplayName = *displayName + + var _, exReplace = setDevice(s, ctx, pk, *response) + if exReplace != nil { + return exReplace + } + return exReplace } func (s *devicesStatements) selectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { - var dev api.Device - var localpart string - stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) - if err == nil { - dev.UserID = userutil.MakeUserID(localpart, s.serverName) - dev.AccessToken = accessToken + // "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []DeviceCosmosData + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": accessToken, } - return &dev, err + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(s.selectDeviceByTokenStmt, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { + return nil, err + } + if len(response) == 0 { + return nil, errors.New(fmt.Sprintf("No Devices found with AccessToken %s", accessToken)) + } + + if err == nil { + result := mapFromDevice(response[0].Device) + return &result, nil + } + return nil, err } // selectDeviceByID retrieves a device from the database with the given user @@ -230,54 +382,48 @@ func (s *devicesStatements) selectDeviceByToken( func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { - var dev api.Device - var displayName sql.NullString - stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) - if err == nil { - dev.ID = deviceID - dev.UserID = userutil.MakeUserID(localpart, s.serverName) - if displayName.Valid { - dev.DisplayName = displayName.String - } + // "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + docId := fmt.Sprintf("%s_%s", localpart, deviceID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + var response, exGet = getDevice(s, ctx, pk, cosmosDocId) + if exGet != nil { + return nil, exGet } - return &dev, err + result := mapFromDevice(response.Device) + return &result, nil } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) - - if err != nil { - return devices, err + // "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []DeviceCosmosData + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": localpart, + "@x3": exceptDeviceID, } - for rows.Next() { - var dev api.Device - var lastseents sql.NullInt64 - var id, displayname, ip, useragent sql.NullString - err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) - if err != nil { - return devices, err - } - if id.Valid { - dev.ID = id.String - } - if displayname.Valid { - dev.DisplayName = displayname.String - } - if lastseents.Valid { - dev.LastSeenTS = lastseents.Int64 - } - if ip.Valid { - dev.LastSeenIP = ip.String - } - if useragent.Valid { - dev.UserAgent = useragent.String - } + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartExceptIDStmt, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { + return nil, err + } + for _, item := range response { + dev := mapFromDevice(item.Device) dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) } @@ -286,37 +432,53 @@ func (s *devicesStatements) selectDevicesByLocalpart( } func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) - iDeviceIDs := make([]interface{}, len(deviceIDs)) - for i := range deviceIDs { - iDeviceIDs[i] = deviceIDs[i] + // "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" + var devices []api.Device + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []DeviceCosmosData + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": deviceIDs, } - rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...) + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(s.selectDevicesByIDStmt, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") - var devices []api.Device - for rows.Next() { - var dev api.Device - var localpart string - var displayName sql.NullString - if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { - return nil, err - } - if displayName.Valid { - dev.DisplayName = displayName.String - } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + for _, item := range response { + dev := mapFromDevice(item.Device) devices = append(devices, dev) } - return devices, rows.Err() + return devices, nil } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 - stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) - return err + + // "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + docId := fmt.Sprintf("%s_%s", localpart, deviceID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + var response, exGet = getDevice(s, ctx, pk, cosmosDocId) + if exGet != nil { + return exGet + } + + response.Device.LastSeenTS = lastSeenTs + + var _, exReplace = setDevice(s, ctx, pk, *response) + if exReplace != nil { + return exReplace + } + return exReplace } diff --git a/userapi/storage/devices/cosmosdb/storage.go b/userapi/storage/devices/cosmosdb/storage.go index d414e9026..a5ddd5977 100644 --- a/userapi/storage/devices/cosmosdb/storage.go +++ b/userapi/storage/devices/cosmosdb/storage.go @@ -15,16 +15,18 @@ package cosmosdb import ( - "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" "crypto/rand" - "database/sql" "encoding/base64" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas" + "github.com/matrix-org/gomatrixserverlib" _ "github.com/mattn/go-sqlite3" @@ -35,35 +37,32 @@ var deviceIDByteLength = 6 // Database represents a device database. type Database struct { - db *sql.DB - writer sqlutil.Writer - devices devicesStatements + writer sqlutil.Writer + devices devicesStatements + connection cosmosdbapi.CosmosConnection + databaseName string + cosmosConfig cosmosdbapi.CosmosConfig + serverName gomatrixserverlib.ServerName } // NewDatabase creates a new device database func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { - dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - writer := sqlutil.NewExclusiveWriter() - d := devicesStatements{} + conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString) + config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) + devices := devicesStatements{} // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.execSchema(db); err != nil { - return nil, err + d := &Database{ + databaseName: "userapi", + devices: devices, + serverName: serverName, + connection: conn, + cosmosConfig: config, } - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - if err = d.prepare(db, writer, serverName); err != nil { - return nil, err - } - return &Database{db, writer, d}, nil + err := d.devices.prepare(d, serverName) + + return d, err } // GetDeviceByAccessToken returns the device matching the given access token. @@ -86,7 +85,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") + return d.devices.selectDevicesByLocalpart(ctx, localpart, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -104,16 +103,14 @@ func (d *Database) CreateDevice( displayName *string, ipAddr, userAgent string, ) (dev *api.Device, returnErr error) { if deviceID != nil { - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } + var err error + // Revoke existing tokens for this device + if err = d.devices.deleteDevice(ctx, *deviceID, localpart); err != nil { + return nil, err + } - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) + dev, err = d.devices.insertDevice(ctx, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return dev, err } else { // We generate device IDs in a loop in case its already taken. // We cap this at going round 5 times to ensure we don't spin forever @@ -124,11 +121,9 @@ func (d *Database) CreateDevice( return } - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) + var err error + dev, err = d.devices.insertDevice(ctx, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return dev, err if returnErr == nil { return } @@ -154,9 +149,7 @@ func generateDeviceID() (string, error) { func (d *Database) UpdateDevice( ctx context.Context, localpart, deviceID string, displayName *string, ) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) + return d.devices.updateDeviceName(ctx, localpart, deviceID, displayName) } // RemoveDevice revokes a device by deleting the entry in the database @@ -166,12 +159,10 @@ func (d *Database) UpdateDevice( func (d *Database) RemoveDevice( ctx context.Context, deviceID, localpart string, ) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) + if err := d.devices.deleteDevice(ctx, deviceID, localpart); err != nil { + return err + } + return nil } // RemoveDevices revokes one or more devices by deleting the entry in the database @@ -181,12 +172,10 @@ func (d *Database) RemoveDevice( func (d *Database) RemoveDevices( ctx context.Context, localpart string, devices []string, ) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) + if err := d.devices.deleteDevices(ctx, localpart, devices); err != nil { + return err + } + return nil } // RemoveAllDevices revokes devices by deleting the entry in the @@ -195,22 +184,17 @@ func (d *Database) RemoveDevices( func (d *Database) RemoveAllDevices( ctx context.Context, localpart, exceptDeviceID string, ) (devices []api.Device, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return + devices, err = d.devices.selectDevicesByLocalpart(ctx, localpart, exceptDeviceID) + if err != nil { + return nil, err + } + if err := d.devices.deleteDevicesByLocalpart(ctx, localpart, exceptDeviceID); err != nil { + return nil, err + } + return devices, nil } // UpdateDeviceLastSeen updates a the last seen timestamp and the ip address func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) + return d.devices.updateDeviceLastSeen(ctx, localpart, deviceID, ipAddr) }