mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 08:13:09 -06:00
- Implement Cosmos for the devices_table (#4)
- Use the ConnectionString in the YAML to include the Tenant - Revert all other non implemented tables back to use SQLLite3
This commit is contained in:
parent
dfd5d445ac
commit
b696923333
|
|
@ -16,7 +16,6 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
|
|
@ -38,7 +37,6 @@ type Database struct {
|
|||
|
||||
// NewDatabase opens a new database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
var result Database
|
||||
var err error
|
||||
if result.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
|
|
|
|||
|
|
@ -354,12 +354,12 @@ user_api:
|
|||
listen: http://localhost:7781
|
||||
connect: http://localhost:7781
|
||||
account_database:
|
||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
|
||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
device_database:
|
||||
connection_string: file:userapi_devices.db
|
||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
|
@ -38,7 +37,6 @@ type Database struct {
|
|||
|
||||
// NewDatabase opens a new database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
var d Database
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
|
|
|
|||
6
internal/cosmosdbapi/cosmosconfig.go
Normal file
6
internal/cosmosdbapi/cosmosconfig.go
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
package cosmosdbapi
|
||||
|
||||
type CosmosConfig struct {
|
||||
DatabaseName string
|
||||
ContainerName string
|
||||
}
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
package cosmosdbapi
|
||||
|
||||
type Tenant struct {
|
||||
DatabaseName string
|
||||
TenantName string
|
||||
}
|
||||
|
||||
//TODO: Move into Config or the JWT
|
||||
func DefaultConfig() Tenant {
|
||||
return Tenant{
|
||||
DatabaseName: "safezone_local",
|
||||
TenantName: "criticalarc.com",
|
||||
}
|
||||
}
|
||||
|
|
@ -1,22 +1,50 @@
|
|||
package cosmosdbutil
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
func GetConnectionString(d *config.DataSource) config.DataSource {
|
||||
const accountEndpointName = "AccountEndpoint"
|
||||
const accountKeyName = "AccountKey"
|
||||
const databaseName = "DatabaseName"
|
||||
const containerName = "ContainerName"
|
||||
|
||||
func getConnectionString(d *config.DataSource) config.DataSource {
|
||||
var connString string
|
||||
connString = string(*d)
|
||||
return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1))
|
||||
}
|
||||
|
||||
func GetConnectionProperties(connectionString string) map[string]string {
|
||||
func getConnectionProperties(connectionString string) map[string]string {
|
||||
connectionItemsRaw := strings.Split(connectionString, ";")
|
||||
connectionItems := map[string]string{}
|
||||
for _, item := range connectionItemsRaw {
|
||||
itemSplit := strings.SplitN(item, "=", 2)
|
||||
connectionItems[itemSplit[0]] = itemSplit[1]
|
||||
if len(item) > 0 {
|
||||
itemSplit := strings.SplitN(item, "=", 2)
|
||||
connectionItems[itemSplit[0]] = itemSplit[1]
|
||||
}
|
||||
}
|
||||
return connectionItems
|
||||
}
|
||||
}
|
||||
|
||||
func GetCosmosConnection(d *config.DataSource) cosmosdbapi.CosmosConnection {
|
||||
connString := getConnectionString(d)
|
||||
connMap := getConnectionProperties(string(connString))
|
||||
accountEndpoint := connMap[accountEndpointName]
|
||||
accountKey := connMap[accountKeyName]
|
||||
return cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
|
||||
}
|
||||
|
||||
func GetCosmosConfig(d *config.DataSource) cosmosdbapi.CosmosConfig {
|
||||
connString := getConnectionString(d)
|
||||
connMap := getConnectionProperties(string(connString))
|
||||
database := connMap[databaseName]
|
||||
container := connMap[containerName]
|
||||
return cosmosdbapi.CosmosConfig{
|
||||
DatabaseName: database,
|
||||
ContainerName: container,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,14 +15,12 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
|
|
@ -37,7 +36,6 @@ type Database struct {
|
|||
|
||||
// Open opens a postgres database.
|
||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
d := Database{
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
|
|
@ -38,7 +37,6 @@ type Database struct {
|
|||
|
||||
// Open a sqlite database.
|
||||
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
var d Database
|
||||
var db *sql.DB
|
||||
var err error
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package kafka
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/naffka"
|
||||
|
|
@ -47,8 +46,9 @@ func setupNaffka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) {
|
|||
if naffkaInstance != nil {
|
||||
return naffkaInstance, naffkaInstance
|
||||
}
|
||||
if(cfg.Database.ConnectionString.IsCosmosDB()) {
|
||||
cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString)
|
||||
if cfg.Database.ConnectionString.IsCosmosDB() {
|
||||
//TODO: What do we do for Nafka
|
||||
// cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString)
|
||||
}
|
||||
|
||||
naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString))
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"context"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
|
@ -45,7 +44,6 @@ func NewDatabase(
|
|||
serverKey ed25519.PublicKey,
|
||||
serverKeyID gomatrixserverlib.KeyID,
|
||||
) (*Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"database/sql"
|
||||
|
||||
// Import the sqlite3 package
|
||||
|
|
@ -41,7 +40,6 @@ type SyncServerDatasource struct {
|
|||
// NewDatabase creates a new sync server database
|
||||
// nolint: gocyclo
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
var d SyncServerDatasource
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ func (s *accountDataStatements) insertAccountData(
|
|||
Content: content,
|
||||
}
|
||||
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||
id := ""
|
||||
if roomID == "" {
|
||||
|
|
@ -94,9 +93,9 @@ func (s *accountDataStatements) insertAccountData(
|
|||
}
|
||||
|
||||
var dbData = AccountDataCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
|
||||
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, id),
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
AccountData: result,
|
||||
}
|
||||
|
|
@ -104,8 +103,8 @@ func (s *accountDataStatements) insertAccountData(
|
|||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
|
|
@ -120,9 +119,8 @@ func (s *accountDataStatements) selectAccountData(
|
|||
error,
|
||||
) {
|
||||
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []AccountDataCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -132,8 +130,8 @@ func (s *accountDataStatements) selectAccountData(
|
|||
var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
|
@ -167,9 +165,8 @@ func (s *accountDataStatements) selectAccountDataByType(
|
|||
var bytes []byte
|
||||
|
||||
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []AccountDataCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -181,8 +178,8 @@ func (s *accountDataStatements) selectAccountDataByType(
|
|||
var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
|
|
|||
|
|
@ -87,26 +87,26 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv
|
|||
return
|
||||
}
|
||||
|
||||
func getAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*AccountCosmosData, error) {
|
||||
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) {
|
||||
response := AccountCosmosData{}
|
||||
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
docId,
|
||||
optionsGet,
|
||||
&response)
|
||||
return &response, ex
|
||||
}
|
||||
|
||||
func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
|
||||
func setAccount(s *accountsStatements, ctx context.Context, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
|
||||
response := AccountCosmosData{}
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
account.Id,
|
||||
&account,
|
||||
optionsReplace)
|
||||
|
|
@ -153,13 +153,12 @@ func (s *accountsStatements) insertAccount(
|
|||
data.PasswordHash = hash
|
||||
data.IsDeactivated = false
|
||||
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
|
||||
var dbData = AccountCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
|
||||
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Account: data,
|
||||
}
|
||||
|
|
@ -167,8 +166,8 @@ func (s *accountsStatements) insertAccount(
|
|||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
|
|
@ -184,19 +183,18 @@ func (s *accountsStatements) updatePassword(
|
|||
) (err error) {
|
||||
|
||||
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var response, exGet = getAccount(s, ctx, config, pk, docId)
|
||||
var response, exGet = getAccount(s, ctx, pk, docId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
||||
response.Account.PasswordHash = passwordHash
|
||||
|
||||
var _, exReplace = setAccount(s, ctx, config, pk, *response)
|
||||
var _, exReplace = setAccount(s, ctx, pk, *response)
|
||||
if exReplace != nil {
|
||||
return exReplace
|
||||
}
|
||||
|
|
@ -208,19 +206,18 @@ func (s *accountsStatements) deactivateAccount(
|
|||
) (err error) {
|
||||
|
||||
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
var response, exGet = getAccount(s, ctx, config, pk, docId)
|
||||
var response, exGet = getAccount(s, ctx, pk, docId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
||||
response.Account.IsDeactivated = true
|
||||
|
||||
var _, exReplace = setAccount(s, ctx, config, pk, *response)
|
||||
var _, exReplace = setAccount(s, ctx, pk, *response)
|
||||
if exReplace != nil {
|
||||
return exReplace
|
||||
}
|
||||
|
|
@ -232,9 +229,8 @@ func (s *accountsStatements) selectPasswordHash(
|
|||
) (hash string, err error) {
|
||||
|
||||
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []AccountCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -244,8 +240,8 @@ func (s *accountsStatements) selectPasswordHash(
|
|||
var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
|
@ -271,9 +267,8 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
|||
var acc api.Account
|
||||
|
||||
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []AccountCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -283,8 +278,8 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
|||
var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
|
@ -309,9 +304,8 @@ func (s *accountsStatements) selectNewNumericLocalpart(
|
|||
) (id int64, err error) {
|
||||
|
||||
// "SELECT COUNT(localpart) FROM account_accounts"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []AccountCosmosUserCount
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -320,8 +314,8 @@ func (s *accountsStatements) selectNewNumericLocalpart(
|
|||
var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
|
|
|||
|
|
@ -85,13 +85,12 @@ func (s *tokenStatements) insertToken(
|
|||
ExpiresAtMS: expiresAtMS,
|
||||
}
|
||||
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||
|
||||
var dbData = OpenIdTokenCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token),
|
||||
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Token),
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
OpenIdToken: mapToToken(*result),
|
||||
}
|
||||
|
|
@ -99,8 +98,8 @@ func (s *tokenStatements) insertToken(
|
|||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
|
|
@ -120,9 +119,8 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
|||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
|
||||
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []OpenIdTokenCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -132,8 +130,8 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
|||
var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
|
|
|||
|
|
@ -87,25 +87,25 @@ func (s *profilesStatements) prepare(db *Database) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func getProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*ProfileCosmosData, error) {
|
||||
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) {
|
||||
response := ProfileCosmosData{}
|
||||
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
docId,
|
||||
optionsGet,
|
||||
&response)
|
||||
return &response, ex
|
||||
}
|
||||
|
||||
func setProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
|
||||
func setProfile(s *profilesStatements, ctx context.Context, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
profile.Id,
|
||||
&profile,
|
||||
optionsReplace)
|
||||
|
|
@ -121,13 +121,12 @@ func (s *profilesStatements) insertProfile(
|
|||
Localpart: localpart,
|
||||
}
|
||||
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
|
||||
var dbData = ProfileCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
|
||||
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Profile: mapToProfile(*result),
|
||||
}
|
||||
|
|
@ -135,8 +134,8 @@ func (s *profilesStatements) insertProfile(
|
|||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
|
|
@ -148,9 +147,8 @@ func (s *profilesStatements) selectProfileByLocalpart(
|
|||
) (*authtypes.Profile, error) {
|
||||
|
||||
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []ProfileCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -160,8 +158,8 @@ func (s *profilesStatements) selectProfileByLocalpart(
|
|||
var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
|
@ -187,19 +185,18 @@ func (s *profilesStatements) setAvatarURL(
|
|||
) (err error) {
|
||||
|
||||
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||
|
||||
var response, exGet = getProfile(s, ctx, config, pk, docId)
|
||||
var response, exGet = getProfile(s, ctx, pk, docId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
||||
response.Profile.AvatarURL = avatarURL
|
||||
|
||||
var _, exReplace = setProfile(s, ctx, config, pk, *response)
|
||||
var _, exReplace = setProfile(s, ctx, pk, *response)
|
||||
if exReplace != nil {
|
||||
return exReplace
|
||||
}
|
||||
|
|
@ -211,18 +208,17 @@ func (s *profilesStatements) setDisplayName(
|
|||
) (err error) {
|
||||
|
||||
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
||||
var response, exGet = getProfile(s, ctx, config, pk, docId)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||
var response, exGet = getProfile(s, ctx, pk, docId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
||||
response.Profile.DisplayName = displayName
|
||||
|
||||
var _, exReplace = setProfile(s, ctx, config, pk, *response)
|
||||
var _, exReplace = setProfile(s, ctx, pk, *response)
|
||||
if exReplace != nil {
|
||||
return exReplace
|
||||
}
|
||||
|
|
@ -235,9 +231,8 @@ func (s *profilesStatements) selectProfilesBySearch(
|
|||
var profiles []authtypes.Profile
|
||||
|
||||
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []ProfileCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -248,8 +243,8 @@ func (s *profilesStatements) selectProfilesBySearch(
|
|||
var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
|
|
|||
|
|
@ -48,20 +48,19 @@ type Database struct {
|
|||
|
||||
databaseName string
|
||||
connection cosmosdbapi.CosmosConnection
|
||||
cosmosConfig cosmosdbapi.CosmosConfig
|
||||
}
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
||||
connString := cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
connMap := cosmosdbutil.GetConnectionProperties(string(connString))
|
||||
accountEndpoint := connMap["AccountEndpoint"]
|
||||
accountKey := connMap["AccountKey"]
|
||||
conn := cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
|
||||
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
|
||||
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
|
||||
|
||||
d := &Database{
|
||||
serverName: serverName,
|
||||
databaseName: "userapi",
|
||||
connection: conn,
|
||||
cosmosConfig: config,
|
||||
// db: db,
|
||||
// writer: sqlutil.NewExclusiveWriter(),
|
||||
// bcryptCost: bcryptCost,
|
||||
|
|
|
|||
|
|
@ -74,9 +74,8 @@ func (s *threepidStatements) selectLocalpartForThreePID(
|
|||
) (localpart string, err error) {
|
||||
|
||||
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []ThreePIDCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -87,8 +86,8 @@ func (s *threepidStatements) selectLocalpartForThreePID(
|
|||
var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
|
@ -109,9 +108,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
|
|||
) (threepids []authtypes.ThreePID, err error) {
|
||||
|
||||
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
response := []ThreePIDCosmosData{}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
|
|
@ -121,8 +119,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
|
|||
var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
options)
|
||||
|
|
@ -156,14 +154,14 @@ func (s *threepidStatements) insertThreePID(
|
|||
ThreePID: threepid,
|
||||
}
|
||||
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
|
||||
id := fmt.Sprintf("%s_%s", threepid, medium)
|
||||
docId := fmt.Sprintf("%s_%s", threepid, medium)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
var dbData = ThreePIDCosmosData{
|
||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
ThreePID: result,
|
||||
}
|
||||
|
|
@ -171,8 +169,8 @@ func (s *threepidStatements) insertThreePID(
|
|||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
|
|
@ -186,16 +184,16 @@ func (s *threepidStatements) deleteThreePID(
|
|||
ctx context.Context, threepid string, medium string) (err error) {
|
||||
|
||||
// "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
var config = cosmosdbapi.DefaultConfig()
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||
id := fmt.Sprintf("%s_%s", threepid, medium)
|
||||
pk := cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
||||
docId := fmt.Sprintf("%s_%s", threepid, medium)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
|
||||
_, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||
ctx,
|
||||
config.DatabaseName,
|
||||
config.TenantName,
|
||||
id,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
cosmosDocId,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -16,127 +16,145 @@ package cosmosdb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const devicesSchema = `
|
||||
-- This sequence is used for automatic allocation of session_id.
|
||||
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
|
||||
// const devicesSchema = `
|
||||
// -- This sequence is used for automatic allocation of session_id.
|
||||
// -- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
|
||||
|
||||
-- Stores data about devices.
|
||||
CREATE TABLE IF NOT EXISTS device_devices (
|
||||
access_token TEXT PRIMARY KEY,
|
||||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
last_seen_ts BIGINT,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
// -- Stores data about devices.
|
||||
// CREATE TABLE IF NOT EXISTS device_devices (
|
||||
// access_token TEXT PRIMARY KEY,
|
||||
// session_id INTEGER,
|
||||
// device_id TEXT ,
|
||||
// localpart TEXT ,
|
||||
// created_ts BIGINT,
|
||||
// display_name TEXT,
|
||||
// last_seen_ts BIGINT,
|
||||
// ip TEXT,
|
||||
// user_agent TEXT,
|
||||
|
||||
UNIQUE (localpart, device_id)
|
||||
);
|
||||
`
|
||||
// UNIQUE (localpart, device_id)
|
||||
// );
|
||||
// `
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
type DeviceCosmos struct {
|
||||
ID string `json:"device_id"`
|
||||
UserID string `json:"user_id"`
|
||||
// The access_token granted to this device.
|
||||
// This uniquely identifies the device from all other devices and clients.
|
||||
AccessToken string `json:"access_token"`
|
||||
// The unique ID of the session identified by the access token.
|
||||
// Can be used as a secure substitution in places where data needs to be
|
||||
// associated with access tokens.
|
||||
SessionID int64 `json:"session_id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
LastSeenTS int64 `json:"last_seen_ts"`
|
||||
LastSeenIP string `json:"last_seen_ip"`
|
||||
Localpart string `json:"local_part"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
// If the device is for an appservice user,
|
||||
// this is the appservice ID.
|
||||
AppserviceID string `json:"app_service_id"`
|
||||
}
|
||||
|
||||
const selectDevicesCountSQL = "" +
|
||||
"SELECT COUNT(access_token) FROM device_devices"
|
||||
type DeviceCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
Device DeviceCosmos `json:"mx_userapi_device"`
|
||||
}
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||
|
||||
const selectDeviceByIDSQL = "" +
|
||||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
|
||||
const deleteDeviceSQL = "" +
|
||||
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||
|
||||
const deleteDevicesByLocalpartSQL = "" +
|
||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
|
||||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
||||
type DeviceCosmosSessionCount struct {
|
||||
SessionCount int64 `json:"sessioncount"`
|
||||
}
|
||||
|
||||
type devicesStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertDeviceStmt *sql.Stmt
|
||||
selectDevicesCountStmt *sql.Stmt
|
||||
selectDeviceByTokenStmt *sql.Stmt
|
||||
selectDeviceByIDStmt *sql.Stmt
|
||||
selectDevicesByIDStmt *sql.Stmt
|
||||
selectDevicesByLocalpartStmt *sql.Stmt
|
||||
updateDeviceNameStmt *sql.Stmt
|
||||
updateDeviceLastSeenStmt *sql.Stmt
|
||||
deleteDeviceStmt *sql.Stmt
|
||||
deleteDevicesByLocalpartStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
db *Database
|
||||
selectDevicesCountStmt string
|
||||
selectDeviceByTokenStmt string
|
||||
// selectDeviceByIDStmt *sql.Stmt
|
||||
selectDevicesByIDStmt string
|
||||
selectDevicesByLocalpartStmt string
|
||||
selectDevicesByLocalpartExceptIDStmt string
|
||||
serverName gomatrixserverlib.ServerName
|
||||
tableName string
|
||||
}
|
||||
|
||||
func (s *devicesStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(devicesSchema)
|
||||
return err
|
||||
func mapFromDevice(db DeviceCosmos) api.Device {
|
||||
return api.Device{
|
||||
AccessToken: db.AccessToken,
|
||||
AppserviceID: db.AppserviceID,
|
||||
ID: db.ID,
|
||||
LastSeenIP: db.LastSeenIP,
|
||||
LastSeenTS: db.LastSeenTS,
|
||||
SessionID: db.SessionID,
|
||||
UserAgent: db.UserAgent,
|
||||
UserID: db.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
|
||||
func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos {
|
||||
localPart, _ := userutil.ParseUsernameParam(api.UserID, &s.serverName)
|
||||
return DeviceCosmos{
|
||||
AccessToken: api.AccessToken,
|
||||
AppserviceID: api.AppserviceID,
|
||||
ID: api.ID,
|
||||
LastSeenIP: api.LastSeenIP,
|
||||
LastSeenTS: api.LastSeenTS,
|
||||
Localpart: localPart,
|
||||
SessionID: api.SessionID,
|
||||
UserAgent: api.UserAgent,
|
||||
UserID: api.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
|
||||
response := DeviceCosmosData{}
|
||||
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
docId,
|
||||
optionsGet,
|
||||
&response)
|
||||
return &response, ex
|
||||
}
|
||||
|
||||
func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) {
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, device.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
device.Id,
|
||||
&device,
|
||||
optionsReplace)
|
||||
return &device, ex
|
||||
}
|
||||
|
||||
func (s *devicesStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
s.writer = writer
|
||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
|
||||
return
|
||||
}
|
||||
s.selectDevicesCountStmt = "select count(c._ts) as sessioncount from c where c._cn = @x1"
|
||||
s.selectDevicesByLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and ARRAY_CONTAINS(@x3, c.mx_userapi_device.device_id)"
|
||||
s.selectDevicesByLocalpartExceptIDStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and c.mx_userapi_device.device_id != @x3"
|
||||
s.selectDeviceByTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.access_token = @x2"
|
||||
s.selectDevicesByIDStmt = "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_userapi_device.device_id)"
|
||||
s.serverName = server
|
||||
s.tableName = "device_devices"
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -144,85 +162,219 @@ func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go
|
|||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) insertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
ctx context.Context, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
||||
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
|
||||
// "SELECT COUNT(access_token) FROM device_devices"
|
||||
// HACK: Do we need a Cosmos Table for the sequence?
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosSessionCount
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
}
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDevicesCountStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID = response[0].SessionCount
|
||||
sessionID++
|
||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &api.Device{
|
||||
// "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
// " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
|
||||
data := DeviceCosmos{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
AccessToken: accessToken,
|
||||
SessionID: sessionID,
|
||||
LastSeenTS: createdTimeMS,
|
||||
LastSeenIP: ipAddr,
|
||||
Localpart: localpart,
|
||||
UserAgent: userAgent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// access_token TEXT PRIMARY KEY,
|
||||
// UNIQUE (localpart, device_id)
|
||||
// HACK: check for duplicate PK as we are using the UNIQUE key for the DocId
|
||||
docId := fmt.Sprintf("%s_%s", localpart, id)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
var dbData = DeviceCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Device: data,
|
||||
}
|
||||
|
||||
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||
var _, _, errCreate = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
optionsCreate)
|
||||
|
||||
if errCreate != nil {
|
||||
return nil, errCreate
|
||||
}
|
||||
|
||||
var result = mapFromDevice(dbData.Device)
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) deleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
ctx context.Context, id, localpart string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||
// "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
docId := fmt.Sprintf("%s_%s", localpart, id)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
cosmosDocId,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) deleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
ctx context.Context, localpart string, devices []string,
|
||||
) error {
|
||||
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
||||
prep, err := s.db.Prepare(orig)
|
||||
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
"@x3": devices,
|
||||
}
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, prep)
|
||||
params := make([]interface{}, len(devices)+1)
|
||||
params[0] = localpart
|
||||
for i, v := range devices {
|
||||
params[i+1] = v
|
||||
for _, item := range response {
|
||||
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, localpart, exceptDeviceID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
||||
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
exceptDevices := []string{
|
||||
exceptDeviceID,
|
||||
}
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
"@x3": exceptDevices,
|
||||
}
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, item := range response {
|
||||
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||
return err
|
||||
// "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
||||
response.Device.DisplayName = *displayName
|
||||
|
||||
var _, exReplace = setDevice(s, ctx, pk, *response)
|
||||
if exReplace != nil {
|
||||
return exReplace
|
||||
}
|
||||
return exReplace
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDeviceByToken(
|
||||
ctx context.Context, accessToken string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
stmt := s.selectDeviceByTokenStmt
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||
if err == nil {
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.AccessToken = accessToken
|
||||
// "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": accessToken,
|
||||
}
|
||||
return &dev, err
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDeviceByTokenStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(response) == 0 {
|
||||
return nil, errors.New(fmt.Sprintf("No Devices found with AccessToken %s", accessToken))
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
result := mapFromDevice(response[0].Device)
|
||||
return &result, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
|
|
@ -230,54 +382,48 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||
func (s *devicesStatements) selectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var displayName sql.NullString
|
||||
stmt := s.selectDeviceByIDStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
|
||||
if err == nil {
|
||||
dev.ID = deviceID
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
// "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
||||
if exGet != nil {
|
||||
return nil, exGet
|
||||
}
|
||||
return &dev, err
|
||||
result := mapFromDevice(response.Device)
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, localpart, exceptDeviceID string,
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
||||
|
||||
if err != nil {
|
||||
return devices, err
|
||||
// "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": localpart,
|
||||
"@x3": exceptDeviceID,
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var lastseents sql.NullInt64
|
||||
var id, displayname, ip, useragent sql.NullString
|
||||
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
||||
if err != nil {
|
||||
return devices, err
|
||||
}
|
||||
if id.Valid {
|
||||
dev.ID = id.String
|
||||
}
|
||||
if displayname.Valid {
|
||||
dev.DisplayName = displayname.String
|
||||
}
|
||||
if lastseents.Valid {
|
||||
dev.LastSeenTS = lastseents.Int64
|
||||
}
|
||||
if ip.Valid {
|
||||
dev.LastSeenIP = ip.String
|
||||
}
|
||||
if useragent.Valid {
|
||||
dev.UserAgent = useragent.String
|
||||
}
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartExceptIDStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, item := range response {
|
||||
dev := mapFromDevice(item.Device)
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
|
@ -286,37 +432,53 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
|
||||
iDeviceIDs := make([]interface{}, len(deviceIDs))
|
||||
for i := range deviceIDs {
|
||||
iDeviceIDs[i] = deviceIDs[i]
|
||||
// "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
||||
var devices []api.Device
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceCosmosData
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": deviceIDs,
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(s.selectDevicesByIDStmt, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||
var devices []api.Device
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
for _, item := range response {
|
||||
dev := mapFromDevice(item.Device)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
return devices, rows.Err()
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
|
||||
return err
|
||||
|
||||
// "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
||||
if exGet != nil {
|
||||
return exGet
|
||||
}
|
||||
|
||||
response.Device.LastSeenTS = lastSeenTs
|
||||
|
||||
var _, exReplace = setDevice(s, ctx, pk, *response)
|
||||
if exReplace != nil {
|
||||
return exReplace
|
||||
}
|
||||
return exReplace
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,16 +15,18 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
|
@ -35,35 +37,32 @@ var deviceIDByteLength = 6
|
|||
|
||||
// Database represents a device database.
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
devices devicesStatements
|
||||
writer sqlutil.Writer
|
||||
devices devicesStatements
|
||||
connection cosmosdbapi.CosmosConnection
|
||||
databaseName string
|
||||
cosmosConfig cosmosdbapi.CosmosConfig
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
// NewDatabase creates a new device database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer := sqlutil.NewExclusiveWriter()
|
||||
d := devicesStatements{}
|
||||
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
|
||||
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
|
||||
devices := devicesStatements{}
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
if err = d.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
d := &Database{
|
||||
databaseName: "userapi",
|
||||
devices: devices,
|
||||
serverName: serverName,
|
||||
connection: conn,
|
||||
cosmosConfig: config,
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadLastSeenTSIP(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.prepare(db, writer, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Database{db, writer, d}, nil
|
||||
err := d.devices.prepare(d, serverName)
|
||||
|
||||
return d, err
|
||||
}
|
||||
|
||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||
|
|
@ -86,7 +85,7 @@ func (d *Database) GetDeviceByID(
|
|||
func (d *Database) GetDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) ([]api.Device, error) {
|
||||
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||
return d.devices.selectDevicesByLocalpart(ctx, localpart, "")
|
||||
}
|
||||
|
||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
|
|
@ -104,16 +103,14 @@ func (d *Database) CreateDevice(
|
|||
displayName *string, ipAddr, userAgent string,
|
||||
) (dev *api.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
// Revoke existing tokens for this device
|
||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||
return err
|
||||
}
|
||||
var err error
|
||||
// Revoke existing tokens for this device
|
||||
if err = d.devices.deleteDevice(ctx, *deviceID, localpart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
dev, err = d.devices.insertDevice(ctx, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return dev, err
|
||||
} else {
|
||||
// We generate device IDs in a loop in case its already taken.
|
||||
// We cap this at going round 5 times to ensure we don't spin forever
|
||||
|
|
@ -124,11 +121,9 @@ func (d *Database) CreateDevice(
|
|||
return
|
||||
}
|
||||
|
||||
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
var err error
|
||||
dev, err = d.devices.insertDevice(ctx, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return dev, err
|
||||
if returnErr == nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -154,9 +149,7 @@ func generateDeviceID() (string, error) {
|
|||
func (d *Database) UpdateDevice(
|
||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||
})
|
||||
return d.devices.updateDeviceName(ctx, localpart, deviceID, displayName)
|
||||
}
|
||||
|
||||
// RemoveDevice revokes a device by deleting the entry in the database
|
||||
|
|
@ -166,12 +159,10 @@ func (d *Database) UpdateDevice(
|
|||
func (d *Database) RemoveDevice(
|
||||
ctx context.Context, deviceID, localpart string,
|
||||
) error {
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := d.devices.deleteDevice(ctx, deviceID, localpart); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||
|
|
@ -181,12 +172,10 @@ func (d *Database) RemoveDevice(
|
|||
func (d *Database) RemoveDevices(
|
||||
ctx context.Context, localpart string, devices []string,
|
||||
) error {
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := d.devices.deleteDevices(ctx, localpart, devices); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAllDevices revokes devices by deleting the entry in the
|
||||
|
|
@ -195,22 +184,17 @@ func (d *Database) RemoveDevices(
|
|||
func (d *Database) RemoveAllDevices(
|
||||
ctx context.Context, localpart, exceptDeviceID string,
|
||||
) (devices []api.Device, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, localpart, exceptDeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.devices.deleteDevicesByLocalpart(ctx, localpart, exceptDeviceID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
||||
})
|
||||
return d.devices.updateDeviceLastSeen(ctx, localpart, deviceID, ipAddr)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue