mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-28 09:13:09 -06:00
- Implement Cosmos for the devices_table (#4)
- Use the ConnectionString in the YAML to include the Tenant - Revert all other non implemented tables back to use SQLLite3
This commit is contained in:
parent
dfd5d445ac
commit
b696923333
|
|
@ -16,7 +16,6 @@
|
||||||
package cosmosdb
|
package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
|
@ -38,7 +37,6 @@ type Database struct {
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
|
||||||
var result Database
|
var result Database
|
||||||
var err error
|
var err error
|
||||||
if result.db, err = sqlutil.Open(dbProperties); err != nil {
|
if result.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -354,12 +354,12 @@ user_api:
|
||||||
listen: http://localhost:7781
|
listen: http://localhost:7781
|
||||||
connect: http://localhost:7781
|
connect: http://localhost:7781
|
||||||
account_database:
|
account_database:
|
||||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
|
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
device_database:
|
device_database:
|
||||||
connection_string: file:userapi_devices.db
|
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=criticalarc.com;"
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@
|
||||||
package cosmosdb
|
package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
|
@ -38,7 +37,6 @@ type Database struct {
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) {
|
||||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
|
||||||
var d Database
|
var d Database
|
||||||
var err error
|
var err error
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
|
|
|
||||||
6
internal/cosmosdbapi/cosmosconfig.go
Normal file
6
internal/cosmosdbapi/cosmosconfig.go
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
package cosmosdbapi
|
||||||
|
|
||||||
|
type CosmosConfig struct {
|
||||||
|
DatabaseName string
|
||||||
|
ContainerName string
|
||||||
|
}
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
package cosmosdbapi
|
|
||||||
|
|
||||||
type Tenant struct {
|
|
||||||
DatabaseName string
|
|
||||||
TenantName string
|
|
||||||
}
|
|
||||||
|
|
||||||
//TODO: Move into Config or the JWT
|
|
||||||
func DefaultConfig() Tenant {
|
|
||||||
return Tenant{
|
|
||||||
DatabaseName: "safezone_local",
|
|
||||||
TenantName: "criticalarc.com",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,22 +1,50 @@
|
||||||
package cosmosdbutil
|
package cosmosdbutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetConnectionString(d *config.DataSource) config.DataSource {
|
const accountEndpointName = "AccountEndpoint"
|
||||||
|
const accountKeyName = "AccountKey"
|
||||||
|
const databaseName = "DatabaseName"
|
||||||
|
const containerName = "ContainerName"
|
||||||
|
|
||||||
|
func getConnectionString(d *config.DataSource) config.DataSource {
|
||||||
var connString string
|
var connString string
|
||||||
connString = string(*d)
|
connString = string(*d)
|
||||||
return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1))
|
return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetConnectionProperties(connectionString string) map[string]string {
|
func getConnectionProperties(connectionString string) map[string]string {
|
||||||
connectionItemsRaw := strings.Split(connectionString, ";")
|
connectionItemsRaw := strings.Split(connectionString, ";")
|
||||||
connectionItems := map[string]string{}
|
connectionItems := map[string]string{}
|
||||||
for _, item := range connectionItemsRaw {
|
for _, item := range connectionItemsRaw {
|
||||||
itemSplit := strings.SplitN(item, "=", 2)
|
if len(item) > 0 {
|
||||||
connectionItems[itemSplit[0]] = itemSplit[1]
|
itemSplit := strings.SplitN(item, "=", 2)
|
||||||
|
connectionItems[itemSplit[0]] = itemSplit[1]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return connectionItems
|
return connectionItems
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetCosmosConnection(d *config.DataSource) cosmosdbapi.CosmosConnection {
|
||||||
|
connString := getConnectionString(d)
|
||||||
|
connMap := getConnectionProperties(string(connString))
|
||||||
|
accountEndpoint := connMap[accountEndpointName]
|
||||||
|
accountKey := connMap[accountKeyName]
|
||||||
|
return cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCosmosConfig(d *config.DataSource) cosmosdbapi.CosmosConfig {
|
||||||
|
connString := getConnectionString(d)
|
||||||
|
connMap := getConnectionProperties(string(connString))
|
||||||
|
database := connMap[databaseName]
|
||||||
|
container := connMap[containerName]
|
||||||
|
return cosmosdbapi.CosmosConfig{
|
||||||
|
DatabaseName: database,
|
||||||
|
ContainerName: container,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,14 +15,12 @@
|
||||||
package cosmosdb
|
package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
||||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
|
||||||
db, err := sqlutil.Open(dbProperties)
|
db, err := sqlutil.Open(dbProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@
|
||||||
package cosmosdb
|
package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
|
@ -37,7 +36,6 @@ type Database struct {
|
||||||
|
|
||||||
// Open opens a postgres database.
|
// Open opens a postgres database.
|
||||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
|
||||||
d := Database{
|
d := Database{
|
||||||
writer: sqlutil.NewExclusiveWriter(),
|
writer: sqlutil.NewExclusiveWriter(),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@
|
||||||
package cosmosdb
|
package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
|
@ -38,7 +37,6 @@ type Database struct {
|
||||||
|
|
||||||
// Open a sqlite database.
|
// Open a sqlite database.
|
||||||
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
|
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
|
||||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
|
||||||
var d Database
|
var d Database
|
||||||
var db *sql.DB
|
var db *sql.DB
|
||||||
var err error
|
var err error
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
package kafka
|
package kafka
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/naffka"
|
"github.com/matrix-org/naffka"
|
||||||
|
|
@ -47,8 +46,9 @@ func setupNaffka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) {
|
||||||
if naffkaInstance != nil {
|
if naffkaInstance != nil {
|
||||||
return naffkaInstance, naffkaInstance
|
return naffkaInstance, naffkaInstance
|
||||||
}
|
}
|
||||||
if(cfg.Database.ConnectionString.IsCosmosDB()) {
|
if cfg.Database.ConnectionString.IsCosmosDB() {
|
||||||
cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString)
|
//TODO: What do we do for Nafka
|
||||||
|
// cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString)
|
||||||
}
|
}
|
||||||
|
|
||||||
naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString))
|
naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString))
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@
|
||||||
package cosmosdb
|
package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
|
|
@ -45,7 +44,6 @@ func NewDatabase(
|
||||||
serverKey ed25519.PublicKey,
|
serverKey ed25519.PublicKey,
|
||||||
serverKeyID gomatrixserverlib.KeyID,
|
serverKeyID gomatrixserverlib.KeyID,
|
||||||
) (*Database, error) {
|
) (*Database, error) {
|
||||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
|
||||||
db, err := sqlutil.Open(dbProperties)
|
db, err := sqlutil.Open(dbProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@
|
||||||
package cosmosdb
|
package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
// Import the sqlite3 package
|
// Import the sqlite3 package
|
||||||
|
|
@ -41,7 +40,6 @@ type SyncServerDatasource struct {
|
||||||
// NewDatabase creates a new sync server database
|
// NewDatabase creates a new sync server database
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
|
||||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
|
||||||
var d SyncServerDatasource
|
var d SyncServerDatasource
|
||||||
var err error
|
var err error
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,6 @@ func (s *accountDataStatements) insertAccountData(
|
||||||
Content: content,
|
Content: content,
|
||||||
}
|
}
|
||||||
|
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||||
id := ""
|
id := ""
|
||||||
if roomID == "" {
|
if roomID == "" {
|
||||||
|
|
@ -94,9 +93,9 @@ func (s *accountDataStatements) insertAccountData(
|
||||||
}
|
}
|
||||||
|
|
||||||
var dbData = AccountDataCosmosData{
|
var dbData = AccountDataCosmosData{
|
||||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
|
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, id),
|
||||||
Cn: dbCollectionName,
|
Cn: dbCollectionName,
|
||||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||||
Timestamp: time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
AccountData: result,
|
AccountData: result,
|
||||||
}
|
}
|
||||||
|
|
@ -104,8 +103,8 @@ func (s *accountDataStatements) insertAccountData(
|
||||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
dbData,
|
dbData,
|
||||||
options)
|
options)
|
||||||
|
|
||||||
|
|
@ -120,9 +119,8 @@ func (s *accountDataStatements) selectAccountData(
|
||||||
error,
|
error,
|
||||||
) {
|
) {
|
||||||
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
response := []AccountDataCosmosData{}
|
response := []AccountDataCosmosData{}
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
|
|
@ -132,8 +130,8 @@ func (s *accountDataStatements) selectAccountData(
|
||||||
var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params)
|
var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
query,
|
query,
|
||||||
&response,
|
&response,
|
||||||
options)
|
options)
|
||||||
|
|
@ -167,9 +165,8 @@ func (s *accountDataStatements) selectAccountDataByType(
|
||||||
var bytes []byte
|
var bytes []byte
|
||||||
|
|
||||||
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
response := []AccountDataCosmosData{}
|
response := []AccountDataCosmosData{}
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
|
|
@ -181,8 +178,8 @@ func (s *accountDataStatements) selectAccountDataByType(
|
||||||
var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params)
|
var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
query,
|
query,
|
||||||
&response,
|
&response,
|
||||||
options)
|
options)
|
||||||
|
|
|
||||||
|
|
@ -87,26 +87,26 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*AccountCosmosData, error) {
|
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) {
|
||||||
response := AccountCosmosData{}
|
response := AccountCosmosData{}
|
||||||
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
docId,
|
docId,
|
||||||
optionsGet,
|
optionsGet,
|
||||||
&response)
|
&response)
|
||||||
return &response, ex
|
return &response, ex
|
||||||
}
|
}
|
||||||
|
|
||||||
func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
|
func setAccount(s *accountsStatements, ctx context.Context, pk string, account AccountCosmosData) (*AccountCosmosData, error) {
|
||||||
response := AccountCosmosData{}
|
response := AccountCosmosData{}
|
||||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag)
|
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag)
|
||||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
account.Id,
|
account.Id,
|
||||||
&account,
|
&account,
|
||||||
optionsReplace)
|
optionsReplace)
|
||||||
|
|
@ -153,13 +153,12 @@ func (s *accountsStatements) insertAccount(
|
||||||
data.PasswordHash = hash
|
data.PasswordHash = hash
|
||||||
data.IsDeactivated = false
|
data.IsDeactivated = false
|
||||||
|
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
|
|
||||||
var dbData = AccountCosmosData{
|
var dbData = AccountCosmosData{
|
||||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
|
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
|
||||||
Cn: dbCollectionName,
|
Cn: dbCollectionName,
|
||||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||||
Timestamp: time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
Account: data,
|
Account: data,
|
||||||
}
|
}
|
||||||
|
|
@ -167,8 +166,8 @@ func (s *accountsStatements) insertAccount(
|
||||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
dbData,
|
dbData,
|
||||||
options)
|
options)
|
||||||
|
|
||||||
|
|
@ -184,19 +183,18 @@ func (s *accountsStatements) updatePassword(
|
||||||
) (err error) {
|
) (err error) {
|
||||||
|
|
||||||
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
|
||||||
var response, exGet = getAccount(s, ctx, config, pk, docId)
|
var response, exGet = getAccount(s, ctx, pk, docId)
|
||||||
if exGet != nil {
|
if exGet != nil {
|
||||||
return exGet
|
return exGet
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Account.PasswordHash = passwordHash
|
response.Account.PasswordHash = passwordHash
|
||||||
|
|
||||||
var _, exReplace = setAccount(s, ctx, config, pk, *response)
|
var _, exReplace = setAccount(s, ctx, pk, *response)
|
||||||
if exReplace != nil {
|
if exReplace != nil {
|
||||||
return exReplace
|
return exReplace
|
||||||
}
|
}
|
||||||
|
|
@ -208,19 +206,18 @@ func (s *accountsStatements) deactivateAccount(
|
||||||
) (err error) {
|
) (err error) {
|
||||||
|
|
||||||
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
|
||||||
var response, exGet = getAccount(s, ctx, config, pk, docId)
|
var response, exGet = getAccount(s, ctx, pk, docId)
|
||||||
if exGet != nil {
|
if exGet != nil {
|
||||||
return exGet
|
return exGet
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Account.IsDeactivated = true
|
response.Account.IsDeactivated = true
|
||||||
|
|
||||||
var _, exReplace = setAccount(s, ctx, config, pk, *response)
|
var _, exReplace = setAccount(s, ctx, pk, *response)
|
||||||
if exReplace != nil {
|
if exReplace != nil {
|
||||||
return exReplace
|
return exReplace
|
||||||
}
|
}
|
||||||
|
|
@ -232,9 +229,8 @@ func (s *accountsStatements) selectPasswordHash(
|
||||||
) (hash string, err error) {
|
) (hash string, err error) {
|
||||||
|
|
||||||
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
response := []AccountCosmosData{}
|
response := []AccountCosmosData{}
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
|
|
@ -244,8 +240,8 @@ func (s *accountsStatements) selectPasswordHash(
|
||||||
var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params)
|
var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
query,
|
query,
|
||||||
&response,
|
&response,
|
||||||
options)
|
options)
|
||||||
|
|
@ -271,9 +267,8 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
var acc api.Account
|
var acc api.Account
|
||||||
|
|
||||||
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
response := []AccountCosmosData{}
|
response := []AccountCosmosData{}
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
|
|
@ -283,8 +278,8 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params)
|
var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
query,
|
query,
|
||||||
&response,
|
&response,
|
||||||
options)
|
options)
|
||||||
|
|
@ -309,9 +304,8 @@ func (s *accountsStatements) selectNewNumericLocalpart(
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
|
|
||||||
// "SELECT COUNT(localpart) FROM account_accounts"
|
// "SELECT COUNT(localpart) FROM account_accounts"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
var response []AccountCosmosUserCount
|
var response []AccountCosmosUserCount
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
|
|
@ -320,8 +314,8 @@ func (s *accountsStatements) selectNewNumericLocalpart(
|
||||||
var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params)
|
var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
query,
|
query,
|
||||||
&response,
|
&response,
|
||||||
options)
|
options)
|
||||||
|
|
|
||||||
|
|
@ -85,13 +85,12 @@ func (s *tokenStatements) insertToken(
|
||||||
ExpiresAtMS: expiresAtMS,
|
ExpiresAtMS: expiresAtMS,
|
||||||
}
|
}
|
||||||
|
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||||
|
|
||||||
var dbData = OpenIdTokenCosmosData{
|
var dbData = OpenIdTokenCosmosData{
|
||||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token),
|
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Token),
|
||||||
Cn: dbCollectionName,
|
Cn: dbCollectionName,
|
||||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||||
Timestamp: time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
OpenIdToken: mapToToken(*result),
|
OpenIdToken: mapToToken(*result),
|
||||||
}
|
}
|
||||||
|
|
@ -99,8 +98,8 @@ func (s *tokenStatements) insertToken(
|
||||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
dbData,
|
dbData,
|
||||||
options)
|
options)
|
||||||
|
|
||||||
|
|
@ -120,9 +119,8 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
||||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||||
|
|
||||||
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
response := []OpenIdTokenCosmosData{}
|
response := []OpenIdTokenCosmosData{}
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
|
|
@ -132,8 +130,8 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
||||||
var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params)
|
var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
query,
|
query,
|
||||||
&response,
|
&response,
|
||||||
options)
|
options)
|
||||||
|
|
|
||||||
|
|
@ -87,25 +87,25 @@ func (s *profilesStatements) prepare(db *Database) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*ProfileCosmosData, error) {
|
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) {
|
||||||
response := ProfileCosmosData{}
|
response := ProfileCosmosData{}
|
||||||
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
docId,
|
docId,
|
||||||
optionsGet,
|
optionsGet,
|
||||||
&response)
|
&response)
|
||||||
return &response, ex
|
return &response, ex
|
||||||
}
|
}
|
||||||
|
|
||||||
func setProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
|
func setProfile(s *profilesStatements, ctx context.Context, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) {
|
||||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag)
|
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag)
|
||||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
profile.Id,
|
profile.Id,
|
||||||
&profile,
|
&profile,
|
||||||
optionsReplace)
|
optionsReplace)
|
||||||
|
|
@ -121,13 +121,12 @@ func (s *profilesStatements) insertProfile(
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
}
|
}
|
||||||
|
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||||
|
|
||||||
var dbData = ProfileCosmosData{
|
var dbData = ProfileCosmosData{
|
||||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart),
|
Id: cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, result.Localpart),
|
||||||
Cn: dbCollectionName,
|
Cn: dbCollectionName,
|
||||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||||
Timestamp: time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
Profile: mapToProfile(*result),
|
Profile: mapToProfile(*result),
|
||||||
}
|
}
|
||||||
|
|
@ -135,8 +134,8 @@ func (s *profilesStatements) insertProfile(
|
||||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
dbData,
|
dbData,
|
||||||
options)
|
options)
|
||||||
|
|
||||||
|
|
@ -148,9 +147,8 @@ func (s *profilesStatements) selectProfileByLocalpart(
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
|
|
||||||
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
response := []ProfileCosmosData{}
|
response := []ProfileCosmosData{}
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
|
|
@ -160,8 +158,8 @@ func (s *profilesStatements) selectProfileByLocalpart(
|
||||||
var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params)
|
var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
query,
|
query,
|
||||||
&response,
|
&response,
|
||||||
options)
|
options)
|
||||||
|
|
@ -187,19 +185,18 @@ func (s *profilesStatements) setAvatarURL(
|
||||||
) (err error) {
|
) (err error) {
|
||||||
|
|
||||||
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||||
|
|
||||||
var response, exGet = getProfile(s, ctx, config, pk, docId)
|
var response, exGet = getProfile(s, ctx, pk, docId)
|
||||||
if exGet != nil {
|
if exGet != nil {
|
||||||
return exGet
|
return exGet
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Profile.AvatarURL = avatarURL
|
response.Profile.AvatarURL = avatarURL
|
||||||
|
|
||||||
var _, exReplace = setProfile(s, ctx, config, pk, *response)
|
var _, exReplace = setProfile(s, ctx, pk, *response)
|
||||||
if exReplace != nil {
|
if exReplace != nil {
|
||||||
return exReplace
|
return exReplace
|
||||||
}
|
}
|
||||||
|
|
@ -211,18 +208,17 @@ func (s *profilesStatements) setDisplayName(
|
||||||
) (err error) {
|
) (err error) {
|
||||||
|
|
||||||
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart)
|
var docId = cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, localpart)
|
||||||
var response, exGet = getProfile(s, ctx, config, pk, docId)
|
var response, exGet = getProfile(s, ctx, pk, docId)
|
||||||
if exGet != nil {
|
if exGet != nil {
|
||||||
return exGet
|
return exGet
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Profile.DisplayName = displayName
|
response.Profile.DisplayName = displayName
|
||||||
|
|
||||||
var _, exReplace = setProfile(s, ctx, config, pk, *response)
|
var _, exReplace = setProfile(s, ctx, pk, *response)
|
||||||
if exReplace != nil {
|
if exReplace != nil {
|
||||||
return exReplace
|
return exReplace
|
||||||
}
|
}
|
||||||
|
|
@ -235,9 +231,8 @@ func (s *profilesStatements) selectProfilesBySearch(
|
||||||
var profiles []authtypes.Profile
|
var profiles []authtypes.Profile
|
||||||
|
|
||||||
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
response := []ProfileCosmosData{}
|
response := []ProfileCosmosData{}
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
|
|
@ -248,8 +243,8 @@ func (s *profilesStatements) selectProfilesBySearch(
|
||||||
var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params)
|
var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
query,
|
query,
|
||||||
&response,
|
&response,
|
||||||
options)
|
options)
|
||||||
|
|
|
||||||
|
|
@ -48,20 +48,19 @@ type Database struct {
|
||||||
|
|
||||||
databaseName string
|
databaseName string
|
||||||
connection cosmosdbapi.CosmosConnection
|
connection cosmosdbapi.CosmosConnection
|
||||||
|
cosmosConfig cosmosdbapi.CosmosConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
||||||
connString := cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
|
||||||
connMap := cosmosdbutil.GetConnectionProperties(string(connString))
|
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
|
||||||
accountEndpoint := connMap["AccountEndpoint"]
|
|
||||||
accountKey := connMap["AccountKey"]
|
|
||||||
conn := cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey)
|
|
||||||
|
|
||||||
d := &Database{
|
d := &Database{
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
databaseName: "userapi",
|
databaseName: "userapi",
|
||||||
connection: conn,
|
connection: conn,
|
||||||
|
cosmosConfig: config,
|
||||||
// db: db,
|
// db: db,
|
||||||
// writer: sqlutil.NewExclusiveWriter(),
|
// writer: sqlutil.NewExclusiveWriter(),
|
||||||
// bcryptCost: bcryptCost,
|
// bcryptCost: bcryptCost,
|
||||||
|
|
|
||||||
|
|
@ -74,9 +74,8 @@ func (s *threepidStatements) selectLocalpartForThreePID(
|
||||||
) (localpart string, err error) {
|
) (localpart string, err error) {
|
||||||
|
|
||||||
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
response := []ThreePIDCosmosData{}
|
response := []ThreePIDCosmosData{}
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
|
|
@ -87,8 +86,8 @@ func (s *threepidStatements) selectLocalpartForThreePID(
|
||||||
var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params)
|
var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
query,
|
query,
|
||||||
&response,
|
&response,
|
||||||
options)
|
options)
|
||||||
|
|
@ -109,9 +108,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
|
|
||||||
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
|
||||||
var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
response := []ThreePIDCosmosData{}
|
response := []ThreePIDCosmosData{}
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
|
|
@ -121,8 +119,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
|
||||||
var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params)
|
var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params)
|
||||||
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
query,
|
query,
|
||||||
&response,
|
&response,
|
||||||
options)
|
options)
|
||||||
|
|
@ -156,14 +154,14 @@ func (s *threepidStatements) insertThreePID(
|
||||||
ThreePID: threepid,
|
ThreePID: threepid,
|
||||||
}
|
}
|
||||||
|
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
|
|
||||||
id := fmt.Sprintf("%s_%s", threepid, medium)
|
docId := fmt.Sprintf("%s_%s", threepid, medium)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
var dbData = ThreePIDCosmosData{
|
var dbData = ThreePIDCosmosData{
|
||||||
Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id),
|
Id: cosmosDocId,
|
||||||
Cn: dbCollectionName,
|
Cn: dbCollectionName,
|
||||||
Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName),
|
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||||
Timestamp: time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
ThreePID: result,
|
ThreePID: result,
|
||||||
}
|
}
|
||||||
|
|
@ -171,8 +169,8 @@ func (s *threepidStatements) insertThreePID(
|
||||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
dbData,
|
dbData,
|
||||||
options)
|
options)
|
||||||
|
|
||||||
|
|
@ -186,16 +184,16 @@ func (s *threepidStatements) deleteThreePID(
|
||||||
ctx context.Context, threepid string, medium string) (err error) {
|
ctx context.Context, threepid string, medium string) (err error) {
|
||||||
|
|
||||||
// "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
// "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||||
var config = cosmosdbapi.DefaultConfig()
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
|
||||||
id := fmt.Sprintf("%s_%s", threepid, medium)
|
docId := fmt.Sprintf("%s_%s", threepid, medium)
|
||||||
pk := cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName)
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
|
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
|
||||||
_, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
_, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||||
ctx,
|
ctx,
|
||||||
config.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
config.TenantName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
id,
|
cosmosDocId,
|
||||||
options)
|
options)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -16,127 +16,145 @@ package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"errors"
|
||||||
"strings"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const devicesSchema = `
|
// const devicesSchema = `
|
||||||
-- This sequence is used for automatic allocation of session_id.
|
// -- This sequence is used for automatic allocation of session_id.
|
||||||
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
|
// -- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
|
||||||
|
|
||||||
-- Stores data about devices.
|
// -- Stores data about devices.
|
||||||
CREATE TABLE IF NOT EXISTS device_devices (
|
// CREATE TABLE IF NOT EXISTS device_devices (
|
||||||
access_token TEXT PRIMARY KEY,
|
// access_token TEXT PRIMARY KEY,
|
||||||
session_id INTEGER,
|
// session_id INTEGER,
|
||||||
device_id TEXT ,
|
// device_id TEXT ,
|
||||||
localpart TEXT ,
|
// localpart TEXT ,
|
||||||
created_ts BIGINT,
|
// created_ts BIGINT,
|
||||||
display_name TEXT,
|
// display_name TEXT,
|
||||||
last_seen_ts BIGINT,
|
// last_seen_ts BIGINT,
|
||||||
ip TEXT,
|
// ip TEXT,
|
||||||
user_agent TEXT,
|
// user_agent TEXT,
|
||||||
|
|
||||||
UNIQUE (localpart, device_id)
|
// UNIQUE (localpart, device_id)
|
||||||
);
|
// );
|
||||||
`
|
// `
|
||||||
|
|
||||||
const insertDeviceSQL = "" +
|
type DeviceCosmos struct {
|
||||||
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
ID string `json:"device_id"`
|
||||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
UserID string `json:"user_id"`
|
||||||
|
// The access_token granted to this device.
|
||||||
|
// This uniquely identifies the device from all other devices and clients.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
// The unique ID of the session identified by the access token.
|
||||||
|
// Can be used as a secure substitution in places where data needs to be
|
||||||
|
// associated with access tokens.
|
||||||
|
SessionID int64 `json:"session_id"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
LastSeenTS int64 `json:"last_seen_ts"`
|
||||||
|
LastSeenIP string `json:"last_seen_ip"`
|
||||||
|
Localpart string `json:"local_part"`
|
||||||
|
UserAgent string `json:"user_agent"`
|
||||||
|
// If the device is for an appservice user,
|
||||||
|
// this is the appservice ID.
|
||||||
|
AppserviceID string `json:"app_service_id"`
|
||||||
|
}
|
||||||
|
|
||||||
const selectDevicesCountSQL = "" +
|
type DeviceCosmosData struct {
|
||||||
"SELECT COUNT(access_token) FROM device_devices"
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
Device DeviceCosmos `json:"mx_userapi_device"`
|
||||||
|
}
|
||||||
|
|
||||||
const selectDeviceByTokenSQL = "" +
|
type DeviceCosmosSessionCount struct {
|
||||||
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
SessionCount int64 `json:"sessioncount"`
|
||||||
|
}
|
||||||
const selectDeviceByIDSQL = "" +
|
|
||||||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
|
||||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
|
||||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
|
||||||
|
|
||||||
const deleteDeviceSQL = "" +
|
|
||||||
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
|
||||||
|
|
||||||
const deleteDevicesByLocalpartSQL = "" +
|
|
||||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
|
||||||
|
|
||||||
const deleteDevicesSQL = "" +
|
|
||||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
|
||||||
|
|
||||||
const selectDevicesByIDSQL = "" +
|
|
||||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
|
||||||
|
|
||||||
const updateDeviceLastSeen = "" +
|
|
||||||
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
db *sql.DB
|
db *Database
|
||||||
writer sqlutil.Writer
|
selectDevicesCountStmt string
|
||||||
insertDeviceStmt *sql.Stmt
|
selectDeviceByTokenStmt string
|
||||||
selectDevicesCountStmt *sql.Stmt
|
// selectDeviceByIDStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDevicesByIDStmt string
|
||||||
selectDeviceByIDStmt *sql.Stmt
|
selectDevicesByLocalpartStmt string
|
||||||
selectDevicesByIDStmt *sql.Stmt
|
selectDevicesByLocalpartExceptIDStmt string
|
||||||
selectDevicesByLocalpartStmt *sql.Stmt
|
serverName gomatrixserverlib.ServerName
|
||||||
updateDeviceNameStmt *sql.Stmt
|
tableName string
|
||||||
updateDeviceLastSeenStmt *sql.Stmt
|
|
||||||
deleteDeviceStmt *sql.Stmt
|
|
||||||
deleteDevicesByLocalpartStmt *sql.Stmt
|
|
||||||
serverName gomatrixserverlib.ServerName
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) execSchema(db *sql.DB) error {
|
func mapFromDevice(db DeviceCosmos) api.Device {
|
||||||
_, err := db.Exec(devicesSchema)
|
return api.Device{
|
||||||
return err
|
AccessToken: db.AccessToken,
|
||||||
|
AppserviceID: db.AppserviceID,
|
||||||
|
ID: db.ID,
|
||||||
|
LastSeenIP: db.LastSeenIP,
|
||||||
|
LastSeenTS: db.LastSeenTS,
|
||||||
|
SessionID: db.SessionID,
|
||||||
|
UserAgent: db.UserAgent,
|
||||||
|
UserID: db.UserID,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
|
func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos {
|
||||||
|
localPart, _ := userutil.ParseUsernameParam(api.UserID, &s.serverName)
|
||||||
|
return DeviceCosmos{
|
||||||
|
AccessToken: api.AccessToken,
|
||||||
|
AppserviceID: api.AppserviceID,
|
||||||
|
ID: api.ID,
|
||||||
|
LastSeenIP: api.LastSeenIP,
|
||||||
|
LastSeenTS: api.LastSeenTS,
|
||||||
|
Localpart: localPart,
|
||||||
|
SessionID: api.SessionID,
|
||||||
|
UserAgent: api.UserAgent,
|
||||||
|
UserID: api.UserID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
|
||||||
|
response := DeviceCosmosData{}
|
||||||
|
var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk)
|
||||||
|
var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
docId,
|
||||||
|
optionsGet,
|
||||||
|
&response)
|
||||||
|
return &response, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) {
|
||||||
|
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, device.ETag)
|
||||||
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
device.Id,
|
||||||
|
&device,
|
||||||
|
optionsReplace)
|
||||||
|
return &device, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = writer
|
s.selectDevicesCountStmt = "select count(c._ts) as sessioncount from c where c._cn = @x1"
|
||||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
s.selectDevicesByLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and ARRAY_CONTAINS(@x3, c.mx_userapi_device.device_id)"
|
||||||
return
|
s.selectDevicesByLocalpartExceptIDStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and c.mx_userapi_device.device_id != @x3"
|
||||||
}
|
s.selectDeviceByTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.access_token = @x2"
|
||||||
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
|
s.selectDevicesByIDStmt = "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_userapi_device.device_id)"
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.serverName = server
|
s.serverName = server
|
||||||
|
s.tableName = "device_devices"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -144,85 +162,219 @@ func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go
|
||||||
// Returns an error if the user already has a device with the given device ID.
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (s *devicesStatements) insertDevice(
|
func (s *devicesStatements) insertDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
ctx context.Context, id, localpart, accessToken string,
|
||||||
displayName *string, ipAddr, userAgent string,
|
displayName *string, ipAddr, userAgent string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
var sessionID int64
|
var sessionID int64
|
||||||
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
|
// "SELECT COUNT(access_token) FROM device_devices"
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
// HACK: Do we need a Cosmos Table for the sequence?
|
||||||
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
var response []DeviceCosmosSessionCount
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
}
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectDevicesCountStmt, params)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
sessionID = response[0].SessionCount
|
||||||
sessionID++
|
sessionID++
|
||||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
// "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||||
return nil, err
|
// " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||||
}
|
|
||||||
return &api.Device{
|
data := DeviceCosmos{
|
||||||
ID: id,
|
ID: id,
|
||||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
LastSeenTS: createdTimeMS,
|
LastSeenTS: createdTimeMS,
|
||||||
LastSeenIP: ipAddr,
|
LastSeenIP: ipAddr,
|
||||||
|
Localpart: localpart,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
// access_token TEXT PRIMARY KEY,
|
||||||
|
// UNIQUE (localpart, device_id)
|
||||||
|
// HACK: check for duplicate PK as we are using the UNIQUE key for the DocId
|
||||||
|
docId := fmt.Sprintf("%s_%s", localpart, id)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
|
var dbData = DeviceCosmosData{
|
||||||
|
Id: cosmosDocId,
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName),
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
Device: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
|
var _, _, errCreate = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData,
|
||||||
|
optionsCreate)
|
||||||
|
|
||||||
|
if errCreate != nil {
|
||||||
|
return nil, errCreate
|
||||||
|
}
|
||||||
|
|
||||||
|
var result = mapFromDevice(dbData.Device)
|
||||||
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevice(
|
func (s *devicesStatements) deleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
ctx context.Context, id, localpart string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
// "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||||
|
docId := fmt.Sprintf("%s_%s", localpart, id)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
cosmosDocId,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevices(
|
func (s *devicesStatements) deleteDevices(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
ctx context.Context, localpart string, devices []string,
|
||||||
) error {
|
) error {
|
||||||
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||||
prep, err := s.db.Prepare(orig)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
var response []DeviceCosmosData
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": localpart,
|
||||||
|
"@x3": devices,
|
||||||
|
}
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, prep)
|
for _, item := range response {
|
||||||
params := make([]interface{}, len(devices)+1)
|
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
|
||||||
params[0] = localpart
|
|
||||||
for i, v := range devices {
|
|
||||||
params[i+1] = v
|
|
||||||
}
|
}
|
||||||
_, err = stmt.ExecContext(ctx, params...)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, localpart, exceptDeviceID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
var response []DeviceCosmosData
|
||||||
|
exceptDevices := []string{
|
||||||
|
exceptDeviceID,
|
||||||
|
}
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": localpart,
|
||||||
|
"@x3": exceptDevices,
|
||||||
|
}
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartStmt, params)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, item := range response {
|
||||||
|
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) updateDeviceName(
|
func (s *devicesStatements) updateDeviceName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
// "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||||
return err
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
|
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
||||||
|
if exGet != nil {
|
||||||
|
return exGet
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Device.DisplayName = *displayName
|
||||||
|
|
||||||
|
var _, exReplace = setDevice(s, ctx, pk, *response)
|
||||||
|
if exReplace != nil {
|
||||||
|
return exReplace
|
||||||
|
}
|
||||||
|
return exReplace
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDeviceByToken(
|
func (s *devicesStatements) selectDeviceByToken(
|
||||||
ctx context.Context, accessToken string,
|
ctx context.Context, accessToken string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
// "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||||
var localpart string
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||||
stmt := s.selectDeviceByTokenStmt
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
var response []DeviceCosmosData
|
||||||
if err == nil {
|
params := map[string]interface{}{
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
"@x1": dbCollectionName,
|
||||||
dev.AccessToken = accessToken
|
"@x2": accessToken,
|
||||||
}
|
}
|
||||||
return &dev, err
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectDeviceByTokenStmt, params)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(response) == 0 {
|
||||||
|
return nil, errors.New(fmt.Sprintf("No Devices found with AccessToken %s", accessToken))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
result := mapFromDevice(response[0].Device)
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectDeviceByID retrieves a device from the database with the given user
|
// selectDeviceByID retrieves a device from the database with the given user
|
||||||
|
|
@ -230,54 +382,48 @@ func (s *devicesStatements) selectDeviceByToken(
|
||||||
func (s *devicesStatements) selectDeviceByID(
|
func (s *devicesStatements) selectDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context, localpart, deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
// "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||||
var displayName sql.NullString
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||||
stmt := s.selectDeviceByIDStmt
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
|
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
||||||
if err == nil {
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
dev.ID = deviceID
|
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
if exGet != nil {
|
||||||
if displayName.Valid {
|
return nil, exGet
|
||||||
dev.DisplayName = displayName.String
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return &dev, err
|
result := mapFromDevice(response.Device)
|
||||||
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, localpart, exceptDeviceID string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
devices := []api.Device{}
|
devices := []api.Device{}
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
// "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||||
if err != nil {
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
return devices, err
|
var response []DeviceCosmosData
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": localpart,
|
||||||
|
"@x3": exceptDeviceID,
|
||||||
}
|
}
|
||||||
|
|
||||||
for rows.Next() {
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
var dev api.Device
|
var query = cosmosdbapi.GetQuery(s.selectDevicesByLocalpartExceptIDStmt, params)
|
||||||
var lastseents sql.NullInt64
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
var id, displayname, ip, useragent sql.NullString
|
ctx,
|
||||||
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
s.db.cosmosConfig.DatabaseName,
|
||||||
if err != nil {
|
s.db.cosmosConfig.ContainerName,
|
||||||
return devices, err
|
query,
|
||||||
}
|
&response,
|
||||||
if id.Valid {
|
optionsQry)
|
||||||
dev.ID = id.String
|
if err != nil {
|
||||||
}
|
return nil, err
|
||||||
if displayname.Valid {
|
}
|
||||||
dev.DisplayName = displayname.String
|
|
||||||
}
|
|
||||||
if lastseents.Valid {
|
|
||||||
dev.LastSeenTS = lastseents.Int64
|
|
||||||
}
|
|
||||||
if ip.Valid {
|
|
||||||
dev.LastSeenIP = ip.String
|
|
||||||
}
|
|
||||||
if useragent.Valid {
|
|
||||||
dev.UserAgent = useragent.String
|
|
||||||
}
|
|
||||||
|
|
||||||
|
for _, item := range response {
|
||||||
|
dev := mapFromDevice(item.Device)
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
|
|
@ -286,37 +432,53 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
|
// "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
||||||
iDeviceIDs := make([]interface{}, len(deviceIDs))
|
var devices []api.Device
|
||||||
for i := range deviceIDs {
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||||
iDeviceIDs[i] = deviceIDs[i]
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
var response []DeviceCosmosData
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": deviceIDs,
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(s.selectDevicesByIDStmt, params)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
for _, item := range response {
|
||||||
var devices []api.Device
|
dev := mapFromDevice(item.Device)
|
||||||
for rows.Next() {
|
|
||||||
var dev api.Device
|
|
||||||
var localpart string
|
|
||||||
var displayName sql.NullString
|
|
||||||
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if displayName.Valid {
|
|
||||||
dev.DisplayName = displayName.String
|
|
||||||
}
|
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
return devices, rows.Err()
|
return devices, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
||||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
|
||||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
|
// "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
||||||
return err
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
|
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
||||||
|
if exGet != nil {
|
||||||
|
return exGet
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Device.LastSeenTS = lastSeenTs
|
||||||
|
|
||||||
|
var _, exReplace = setDevice(s, ctx, pk, *response)
|
||||||
|
if exReplace != nil {
|
||||||
|
return exReplace
|
||||||
|
}
|
||||||
|
return exReplace
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,16 +15,18 @@
|
||||||
package cosmosdb
|
package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
|
@ -35,35 +37,32 @@ var deviceIDByteLength = 6
|
||||||
|
|
||||||
// Database represents a device database.
|
// Database represents a device database.
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db *sql.DB
|
writer sqlutil.Writer
|
||||||
writer sqlutil.Writer
|
devices devicesStatements
|
||||||
devices devicesStatements
|
connection cosmosdbapi.CosmosConnection
|
||||||
|
databaseName string
|
||||||
|
cosmosConfig cosmosdbapi.CosmosConfig
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new device database
|
// NewDatabase creates a new device database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
||||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
|
||||||
db, err := sqlutil.Open(dbProperties)
|
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
|
||||||
if err != nil {
|
devices := devicesStatements{}
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
writer := sqlutil.NewExclusiveWriter()
|
|
||||||
d := devicesStatements{}
|
|
||||||
|
|
||||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||||
if err = d.execSchema(db); err != nil {
|
d := &Database{
|
||||||
return nil, err
|
databaseName: "userapi",
|
||||||
|
devices: devices,
|
||||||
|
serverName: serverName,
|
||||||
|
connection: conn,
|
||||||
|
cosmosConfig: config,
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrations()
|
err := d.devices.prepare(d, serverName)
|
||||||
deltas.LoadLastSeenTSIP(m)
|
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
return d, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.prepare(db, writer, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Database{db, writer, d}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||||
|
|
@ -86,7 +85,7 @@ func (d *Database) GetDeviceByID(
|
||||||
func (d *Database) GetDevicesByLocalpart(
|
func (d *Database) GetDevicesByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
return d.devices.selectDevicesByLocalpart(ctx, localpart, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
|
@ -104,16 +103,14 @@ func (d *Database) CreateDevice(
|
||||||
displayName *string, ipAddr, userAgent string,
|
displayName *string, ipAddr, userAgent string,
|
||||||
) (dev *api.Device, returnErr error) {
|
) (dev *api.Device, returnErr error) {
|
||||||
if deviceID != nil {
|
if deviceID != nil {
|
||||||
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
var err error
|
||||||
var err error
|
// Revoke existing tokens for this device
|
||||||
// Revoke existing tokens for this device
|
if err = d.devices.deleteDevice(ctx, *deviceID, localpart); err != nil {
|
||||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
return nil, err
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
dev, err = d.devices.insertDevice(ctx, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||||
return err
|
return dev, err
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
// We generate device IDs in a loop in case its already taken.
|
// We generate device IDs in a loop in case its already taken.
|
||||||
// We cap this at going round 5 times to ensure we don't spin forever
|
// We cap this at going round 5 times to ensure we don't spin forever
|
||||||
|
|
@ -124,11 +121,9 @@ func (d *Database) CreateDevice(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
var err error
|
||||||
var err error
|
dev, err = d.devices.insertDevice(ctx, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
return dev, err
|
||||||
return err
|
|
||||||
})
|
|
||||||
if returnErr == nil {
|
if returnErr == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -154,9 +149,7 @@ func generateDeviceID() (string, error) {
|
||||||
func (d *Database) UpdateDevice(
|
func (d *Database) UpdateDevice(
|
||||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
return d.devices.updateDeviceName(ctx, localpart, deviceID, displayName)
|
||||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveDevice revokes a device by deleting the entry in the database
|
// RemoveDevice revokes a device by deleting the entry in the database
|
||||||
|
|
@ -166,12 +159,10 @@ func (d *Database) UpdateDevice(
|
||||||
func (d *Database) RemoveDevice(
|
func (d *Database) RemoveDevice(
|
||||||
ctx context.Context, deviceID, localpart string,
|
ctx context.Context, deviceID, localpart string,
|
||||||
) error {
|
) error {
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
if err := d.devices.deleteDevice(ctx, deviceID, localpart); err != nil {
|
||||||
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
return nil
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||||
|
|
@ -181,12 +172,10 @@ func (d *Database) RemoveDevice(
|
||||||
func (d *Database) RemoveDevices(
|
func (d *Database) RemoveDevices(
|
||||||
ctx context.Context, localpart string, devices []string,
|
ctx context.Context, localpart string, devices []string,
|
||||||
) error {
|
) error {
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
if err := d.devices.deleteDevices(ctx, localpart, devices); err != nil {
|
||||||
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
return nil
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveAllDevices revokes devices by deleting the entry in the
|
// RemoveAllDevices revokes devices by deleting the entry in the
|
||||||
|
|
@ -195,22 +184,17 @@ func (d *Database) RemoveDevices(
|
||||||
func (d *Database) RemoveAllDevices(
|
func (d *Database) RemoveAllDevices(
|
||||||
ctx context.Context, localpart, exceptDeviceID string,
|
ctx context.Context, localpart, exceptDeviceID string,
|
||||||
) (devices []api.Device, err error) {
|
) (devices []api.Device, err error) {
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
devices, err = d.devices.selectDevicesByLocalpart(ctx, localpart, exceptDeviceID)
|
||||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
if err != nil {
|
||||||
if err != nil {
|
return nil, err
|
||||||
return err
|
}
|
||||||
}
|
if err := d.devices.deleteDevicesByLocalpart(ctx, localpart, exceptDeviceID); err != nil {
|
||||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
return nil, err
|
||||||
return err
|
}
|
||||||
}
|
return devices, nil
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
||||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
return d.devices.updateDeviceLastSeen(ctx, localpart, deviceID, ipAddr)
|
||||||
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue