diff --git a/.vscode/launch.json b/.vscode/launch.json index 09db3f073..fcb46daa2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -28,6 +28,19 @@ "${workspaceFolder}\\bin\\dendrite.yaml", "clientapi", ] - } + }, + { + "name": "Launch Package Monolith - CosmosDB", + "type": "go", + "request": "launch", + "mode": "debug", + "program": "${workspaceFolder}\\cmd\\dendrite-monolith-server", + "args": [ + "-config", + "${workspaceFolder}\\dendrite-config-cosmosdb.yaml", + //Uncomment below to expose internal api's + // "--api", + // "true" + ]} ] } \ No newline at end of file diff --git a/dendrite-config-cosmosdb.yaml b/dendrite-config-cosmosdb.yaml index 4f61b5362..f3be7dcad 100644 --- a/dendrite-config-cosmosdb.yaml +++ b/dendrite-config-cosmosdb.yaml @@ -90,7 +90,7 @@ global: # Naffka database options. Not required when using Kafka. naffka_database: - connection_string: cosmosdb:naffka.db + connection_string: file:naffka.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 @@ -122,7 +122,7 @@ app_service_api: listen: http://localhost:7777 connect: http://localhost:7777 database: - connection_string: cosmosdb:appservice.db + connection_string: file:appservice.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 @@ -202,7 +202,7 @@ federation_sender: listen: http://localhost:7775 connect: http://localhost:7775 database: - connection_string: cosmosdb:federationsender.db + connection_string: file:federationsender.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 @@ -228,7 +228,7 @@ key_server: listen: http://localhost:7779 connect: http://localhost:7779 database: - connection_string: cosmosdb:keyserver.db + connection_string: file:keyserver.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 @@ -241,7 +241,7 @@ media_api: external_api: listen: http://[::]:8074 database: - connection_string: cosmosdb:mediaapi.db + connection_string: file:mediaapi.db max_open_conns: 5 max_idle_conns: 2 conn_max_lifetime: -1 @@ -280,7 +280,7 @@ mscs: # - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946) mscs: [] database: - connection_string: cosmosdb:mscs.db + connection_string: file:mscs.db max_open_conns: 5 max_idle_conns: 2 conn_max_lifetime: -1 @@ -291,7 +291,7 @@ room_server: listen: http://localhost:7770 connect: http://localhost:7770 database: - connection_string: cosmosdb:roomserver.db + connection_string: file:roomserver.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 @@ -302,7 +302,7 @@ signing_key_server: listen: http://localhost:7780 connect: http://localhost:7780 database: - connection_string: cosmosdb:signingkeyserver.db + connection_string: file:signingkeyserver.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 @@ -331,7 +331,7 @@ sync_api: external_api: listen: http://[::]:8073 database: - connection_string: cosmosdb:syncapi.db + connection_string: file:syncapi.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 @@ -354,12 +354,12 @@ user_api: listen: http://localhost:7781 connect: http://localhost:7781 account_database: - connection_string: cosmosdb:userapi_accounts.db + connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 device_database: - connection_string: cosmosdb:userapi_devices.db + connection_string: file:userapi_devices.db max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 diff --git a/go.mod b/go.mod index a3d80f1b1..9b8603963 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/tidwall/sjson v1.1.5 github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-lib v2.4.0+incompatible + github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d // indirect github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20210218094457-e77ca8019daa go.uber.org/atomic v1.7.0 golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 diff --git a/go.sum b/go.sum index 90b5527c8..cefb07724 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,7 @@ github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/ github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -176,6 +177,7 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/gofrs/uuid v3.1.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -987,6 +989,8 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d h1:MZRYOouO0snrQyBAf4Wljc3qqaispjzMOhFRQgWfKMo= +github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d/go.mod h1:ldPlejlc7ZyiP0QQWGwL9CoZLvEjhD9yzpz0ct7+sXo= github.com/vishvananda/netlink v1.0.0/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= github.com/vishvananda/netns v0.0.0-20190625233234-7109fa855b0f/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k= diff --git a/internal/cosmosdbapi/client.go b/internal/cosmosdbapi/client.go new file mode 100644 index 000000000..eb99795b5 --- /dev/null +++ b/internal/cosmosdbapi/client.go @@ -0,0 +1,24 @@ +package cosmosdbapi + +import ( + cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi" +) + +type CosmosConnection struct { + Url string + Key string +} + +func GetCosmosConnection(accountEndpoint string, accountKey string) CosmosConnection { + return CosmosConnection{ + Url: accountEndpoint, + Key: accountKey, + } +} + +func GetClient(conn CosmosConnection) *cosmosapi.Client { + cfg := cosmosapi.Config{ + MasterKey: conn.Key, + } + return cosmosapi.New(conn.Url, cfg, nil, nil) +} diff --git a/internal/cosmosdbapi/collection.go b/internal/cosmosdbapi/collection.go new file mode 100644 index 000000000..5c8562602 --- /dev/null +++ b/internal/cosmosdbapi/collection.go @@ -0,0 +1,10 @@ +package cosmosdbapi + +import ( + "fmt" + +) + +func GetCollectionName(databaseName string, tableName string) string { + return fmt.Sprintf("matrix_%s_%s", databaseName, tableName) +} \ No newline at end of file diff --git a/internal/cosmosdbapi/document.go b/internal/cosmosdbapi/document.go new file mode 100644 index 000000000..9e419fc52 --- /dev/null +++ b/internal/cosmosdbapi/document.go @@ -0,0 +1,14 @@ +package cosmosdbapi + +import ( + "fmt" + +) + +func GetDocumentId(tenantName string, collectionName string, id string) string { + return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, id) +} + +func GetPartitionKey(tenantName string, collectionName string) string { + return fmt.Sprintf("%s,%s", collectionName, tenantName) +} \ No newline at end of file diff --git a/internal/cosmosdbapi/documentoperations.go b/internal/cosmosdbapi/documentoperations.go new file mode 100644 index 000000000..37e8ea883 --- /dev/null +++ b/internal/cosmosdbapi/documentoperations.go @@ -0,0 +1,46 @@ +package cosmosdbapi + +import ( + cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi" +) + +func GetCreateDocumentOptions(pk string) cosmosapi.CreateDocumentOptions { + return cosmosapi.CreateDocumentOptions{ + IsUpsert: false, + PartitionKeyValue: pk, + } +} + +func GetUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions { + return cosmosapi.CreateDocumentOptions{ + IsUpsert: true, + PartitionKeyValue: pk, + } +} + +func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions { + return cosmosapi.QueryDocumentsOptions{ + PartitionKeyValue: pk, + IsQuery: true, + ContentType: cosmosapi.QUERY_CONTENT_TYPE, + } +} + +func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions { + return cosmosapi.GetDocumentOptions{ + PartitionKeyValue: pk, + } +} + +func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions { + return cosmosapi.ReplaceDocumentOptions{ + PartitionKeyValue: pk, + IfMatch: etag, + } +} + +func GetDeleteDocumentOptions(pk string) cosmosapi.DeleteDocumentOptions { + return cosmosapi.DeleteDocumentOptions{ + PartitionKeyValue: pk, + } +} diff --git a/internal/cosmosdbapi/query.go b/internal/cosmosdbapi/query.go new file mode 100644 index 000000000..29e46be23 --- /dev/null +++ b/internal/cosmosdbapi/query.go @@ -0,0 +1,20 @@ +package cosmosdbapi + +import ( + cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi" +) + +func GetQuery(qry string, params map[string]interface{}) cosmosapi.Query { + qryParams := []cosmosapi.QueryParam{} + for key, value := range params { + qryParam := cosmosapi.QueryParam { + Name: key, + Value: value, + } + qryParams = append(qryParams, qryParam) + } + return cosmosapi.Query { + Query: qry, + Params: qryParams, + } +} \ No newline at end of file diff --git a/internal/cosmosdbapi/tenant.go b/internal/cosmosdbapi/tenant.go new file mode 100644 index 000000000..d9cb825f2 --- /dev/null +++ b/internal/cosmosdbapi/tenant.go @@ -0,0 +1,14 @@ +package cosmosdbapi + +type Tenant struct { + DatabaseName string + TenantName string +} + +//TODO: Move into Config or the JWT +func DefaultConfig() Tenant { + return Tenant{ + DatabaseName: "safezone_local", + TenantName: "criticalarc.com", + } +} diff --git a/internal/cosmosdbutil/connection.go b/internal/cosmosdbutil/connection.go index 6e96cb9d7..d8b6527a9 100644 --- a/internal/cosmosdbutil/connection.go +++ b/internal/cosmosdbutil/connection.go @@ -8,5 +8,15 @@ import ( func GetConnectionString(d *config.DataSource) config.DataSource { var connString string connString = string(*d) - return config.DataSource(strings.Replace(connString, "cosmosdb:", "file:", 1)) + return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1)) +} + +func GetConnectionProperties(connectionString string) map[string]string { + connectionItemsRaw := strings.Split(connectionString, ";") + connectionItems := map[string]string{} + for _, item := range connectionItemsRaw { + itemSplit := strings.SplitN(item, "=", 2) + connectionItems[itemSplit[0]] = itemSplit[1] + } + return connectionItems } \ No newline at end of file diff --git a/userapi/storage/accounts/cosmosdb/account_data_table.go b/userapi/storage/accounts/cosmosdb/account_data_table.go index 916d28735..c73a98c93 100644 --- a/userapi/storage/accounts/cosmosdb/account_data_table.go +++ b/userapi/storage/accounts/cosmosdb/account_data_table.go @@ -16,68 +16,94 @@ package cosmosdb import ( "context" - "database/sql" "encoding/json" + "fmt" + "time" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" ) -const accountDataSchema = ` --- Stores data about accounts data. -CREATE TABLE IF NOT EXISTS account_data ( - -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL, - -- The room ID for this data (empty string if not specific to a room) - room_id TEXT, - -- The account data type - type TEXT NOT NULL, - -- The account data content - content TEXT NOT NULL, +// const accountDataSchema = ` +// -- Stores data about accounts data. +// CREATE TABLE IF NOT EXISTS account_data ( +// -- The Matrix user ID localpart for this account +// localpart TEXT NOT NULL, +// -- The room ID for this data (empty string if not specific to a room) +// room_id TEXT, +// -- The account data type +// type TEXT NOT NULL, +// -- The account data content +// content TEXT NOT NULL, - PRIMARY KEY(localpart, room_id, type) -); -` +// PRIMARY KEY(localpart, room_id, type) +// ); +// ` -const insertAccountDataSQL = ` - INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) - ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 -` - -const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM account_data WHERE localpart = $1" - -const selectAccountDataByTypeSQL = "" + - "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" - -type accountDataStatements struct { - db *sql.DB - insertAccountDataStmt *sql.Stmt - selectAccountDataStmt *sql.Stmt - selectAccountDataByTypeStmt *sql.Stmt +type AccountCosmosAccountData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Object AccountData `json:"_object"` } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { +type AccountData struct { + LocalPart string `json:"local_part"` + RoomId string `json:"room_id"` + Type string `json:"type"` + Content []byte `json:"content"` +} + +type accountDataStatements struct { + db *Database + tableName string +} + +func (s *accountDataStatements) prepare(db *Database) (err error) { s.db = db - _, err = db.Exec(accountDataSchema) - if err != nil { - return - } - if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return - } - if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { - return - } - if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { - return - } + s.tableName = "account_data" return } func (s *accountDataStatements) insertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { - _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + + // INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + // ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 + var result = AccountData{ + LocalPart: localpart, + RoomId: roomID, + Type: dataType, + Content: content, + } + + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) + id := "" + if roomID == "" { + id = fmt.Sprintf("%s_%s", result.LocalPart, result.Type) + } else { + id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type) + } + + var dbData = AccountCosmosAccountData{ + Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id), + Cn: dbCollectionName, + Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Timestamp: time.Now().Unix(), + Object: result, + } + + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + config.DatabaseName, + config.TenantName, + dbData, + options) + return err } @@ -88,30 +114,43 @@ func (s *accountDataStatements) selectAccountData( /* rooms */ map[string]map[string]json.RawMessage, error, ) { - rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) - if err != nil { - return nil, nil, err + // "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + response := []AccountCosmosAccountData{} + var selectAccountDataCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2" + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": localpart, + } + var options = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectAccountDataCosmos, params) + var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) + + if ex != nil { + return nil, nil, ex } global := map[string]json.RawMessage{} rooms := map[string]map[string]json.RawMessage{} - for rows.Next() { - var roomID string - var dataType string - var content []byte - - if err = rows.Scan(&roomID, &dataType, &content); err != nil { - return nil, nil, err - } - + for i := 0; i < len(response); i++ { + var row = response[i] + var roomID = row.Object.RoomId if roomID != "" { - if _, ok := rooms[roomID]; !ok { + if _, ok := rooms[row.Object.RoomId]; !ok { rooms[roomID] = map[string]json.RawMessage{} } - rooms[roomID][dataType] = content + rooms[roomID][row.Object.Type] = row.Object.Content } else { - global[dataType] = content + global[row.Object.Type] = row.Object.Content } } @@ -122,13 +161,39 @@ func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte - stmt := s.selectAccountDataByTypeStmt - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return + + // "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + response := []AccountCosmosAccountData{} + var selectAccountDataCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2 and c._object.room_id = @x3 and c._object.type = @x4" + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": localpart, + "@x3": roomID, + "@x4": dataType, } + var options = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectAccountDataCosmos, params) + var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) + + if ex != nil { + return nil, ex + } + + if len(response) == 0 { + return data, nil + } + + bytes = response[0].Object.Content + data = json.RawMessage(bytes) return } diff --git a/userapi/storage/accounts/cosmosdb/accounts_table.go b/userapi/storage/accounts/cosmosdb/accounts_table.go index f871c1830..076068d1f 100644 --- a/userapi/storage/accounts/cosmosdb/accounts_table.go +++ b/userapi/storage/accounts/cosmosdb/accounts_table.go @@ -16,159 +16,264 @@ package cosmosdb import ( "context" - "database/sql" + "errors" + "fmt" "time" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" - - log "github.com/sirupsen/logrus" ) -const accountsSchema = ` --- Stores data about accounts. -CREATE TABLE IF NOT EXISTS account_accounts ( - -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, - -- When this account was first created, as a unix timestamp (ms resolution). - created_ts BIGINT NOT NULL, - -- The password hash for this account. Can be NULL if this is a passwordless account. - password_hash TEXT, - -- Identifies which application service this account belongs to, if any. - appservice_id TEXT, - -- If the account is currently active - is_deactivated BOOLEAN DEFAULT 0 - -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? -); -` +// const accountsSchema = ` +// -- Stores data about accounts. +// CREATE TABLE IF NOT EXISTS account_accounts ( +// -- The Matrix user ID localpart for this account +// localpart TEXT NOT NULL PRIMARY KEY, +// -- When this account was first created, as a unix timestamp (ms resolution). +// created_ts BIGINT NOT NULL, +// -- The password hash for this account. Can be NULL if this is a passwordless account. +// password_hash TEXT, +// -- Identifies which application service this account belongs to, if any. +// appservice_id TEXT, +// -- If the account is currently active +// is_deactivated BOOLEAN DEFAULT 0 +// -- TODO: +// -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? +// ); +// ` -const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" +type AccountExtended struct { + IsDeactivated bool `json:"is_deactivated"` + PasswordHash string `json:"password_hash"` + Created int64 `json:"created_ts"` +} -const updatePasswordSQL = "" + - "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" +type AccountCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Object api.Account `json:"_object"` + ObjectExtended AccountExtended `json:"_object_extended"` +} -const deactivateAccountSQL = "" + - "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" - -const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" - -const selectPasswordHashSQL = "" + - "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" - -const selectNewNumericLocalpartSQL = "" + - "SELECT COUNT(localpart) FROM account_accounts" +type AccountCosmosUserCount struct { + UserCount int64 `json:"usercount"` +} type accountsStatements struct { - db *sql.DB - insertAccountStmt *sql.Stmt - updatePasswordStmt *sql.Stmt - deactivateAccountStmt *sql.Stmt - selectAccountByLocalpartStmt *sql.Stmt - selectPasswordHashStmt *sql.Stmt - selectNewNumericLocalpartStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + db *Database + tableName string + serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(accountsSchema) - return err -} - -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { +func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db - if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { - return - } - if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { - return - } - if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil { - return - } - if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { - return - } - if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { - return - } - if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { - return - } + s.tableName = "account_accounts" s.serverName = server return } +func getAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*AccountCosmosData, error) { + response := AccountCosmosData{} + var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) + var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( + ctx, + config.DatabaseName, + config.TenantName, + docId, + optionsGet, + &response) + return &response, ex +} + +func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, account AccountCosmosData) (*AccountCosmosData, error) { + response := AccountCosmosData{} + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, account.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + config.DatabaseName, + config.TenantName, + account.Id, + &account, + optionsReplace) + return &response, ex +} + // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) insertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, + ctx context.Context, localpart, hash, appserviceID string, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 - stmt := s.insertAccountStmt + // stmt := s.insertAccountStmt - var err error - if appserviceID == "" { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) - } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) - } - if err != nil { - return nil, err - } - - return &api.Account{ + var result = api.Account{ Localpart: localpart, UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, AppServiceID: appserviceID, - }, nil + } + + var extended = AccountExtended{ + IsDeactivated: false, + PasswordHash: hash, + Created: createdTimeMS, + } + + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) + + var dbData = AccountCosmosData{ + Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart), + Cn: dbCollectionName, + Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Timestamp: time.Now().Unix(), + Object: result, + ObjectExtended: extended, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + config.DatabaseName, + config.TenantName, + dbData, + options) + + if err != nil { + return nil, err + } + + return &result, nil } func (s *accountsStatements) updatePassword( ctx context.Context, localpart, passwordHash string, ) (err error) { - _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + + // "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) + var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + + var response, exGet = getAccount(s, ctx, config, pk, docId) + if exGet != nil { + return exGet + } + + response.ObjectExtended.PasswordHash = passwordHash + + var _, exReplace = setAccount(s, ctx, config, pk, *response) + if exReplace != nil { + return exReplace + } return } func (s *accountsStatements) deactivateAccount( ctx context.Context, localpart string, ) (err error) { - _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + + // "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) + var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + + var response, exGet = getAccount(s, ctx, config, pk, docId) + if exGet != nil { + return exGet + } + + response.ObjectExtended.IsDeactivated = true + + var _, exReplace = setAccount(s, ctx, config, pk, *response) + if exReplace != nil { + return exReplace + } return } func (s *accountsStatements) selectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) - return + + // "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + response := []AccountCosmosData{} + var selectPasswordHashCosmos = "select * from c where c._cn = @x1 and c._object.Localpart = @x2 and c._object_extended.is_deactivated = false" + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": localpart, + } + var options = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectPasswordHashCosmos, params) + var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) + + if ex != nil { + return "", ex + } + + if len(response) == 0 { + return "", errors.New(fmt.Sprintf("Localpart %s not found", localpart)) + } + + if len(response) != 1 { + return "", errors.New(fmt.Sprintf("Localpart %s has multiple entries", localpart)) + } + + return response[0].ObjectExtended.PasswordHash, nil } func (s *accountsStatements) selectAccountByLocalpart( ctx context.Context, localpart string, ) (*api.Account, error) { - var appserviceIDPtr sql.NullString var acc api.Account - stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) - if err != nil { - if err != sql.ErrNoRows { - log.WithError(err).Error("Unable to retrieve user from the db") - } - return nil, err + // "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + response := []AccountCosmosData{} + var selectPasswordHashCosmos = "select * from c where c._cn = @x1 and c._object.Localpart = @x2" + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": localpart, } - if appserviceIDPtr.Valid { - acc.AppServiceID = appserviceIDPtr.String + var options = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectPasswordHashCosmos, params) + var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) + + if ex != nil { + return nil, ex } + if len(response) == 0 { + return nil, nil + } + + acc = response[0].Object acc.UserID = userutil.MakeUserID(localpart, s.serverName) acc.ServerName = s.serverName @@ -176,12 +281,31 @@ func (s *accountsStatements) selectAccountByLocalpart( } func (s *accountsStatements) selectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, + ctx context.Context, ) (id int64, err error) { - stmt := s.selectNewNumericLocalpartStmt - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) + + // "SELECT COUNT(localpart) FROM account_accounts" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var response []AccountCosmosUserCount + var selectCountCosmos = "select count(c._ts) as usercount from c where c._cn = @x1" + params := map[string]interface{}{ + "@x1": dbCollectionName, } - err = stmt.QueryRowContext(ctx).Scan(&id) - return + var options = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectCountCosmos, params) + var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) + + if ex != nil { + return -1, ex + } + + return int64(response[0].UserCount), nil } diff --git a/userapi/storage/accounts/cosmosdb/openid_table.go b/userapi/storage/accounts/cosmosdb/openid_table.go index c5bab0308..9224dd134 100644 --- a/userapi/storage/accounts/cosmosdb/openid_table.go +++ b/userapi/storage/accounts/cosmosdb/openid_table.go @@ -1,13 +1,12 @@ package cosmosdb import ( + "time" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "context" - "database/sql" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" ) const openIDTokenSchema = ` @@ -21,32 +20,24 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( token_expires_at_ms BIGINT NOT NULL ); ` - -const insertTokenSQL = "" + - "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" - -const selectTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" - -type tokenStatements struct { - db *sql.DB - insertTokenStmt *sql.Stmt - selectTokenStmt *sql.Stmt - serverName gomatrixserverlib.ServerName +type OpenIdTokenCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Object *api.OpenIDToken `json:"_object"` } -func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { +type tokenStatements struct { + db *Database + tableName string + serverName gomatrixserverlib.ServerName +} + +func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db - _, err = db.Exec(openIDTokenSchema) - if err != nil { - return err - } - if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { - return - } - if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { - return - } + s.tableName = "open_id_tokens" s.serverName = server return } @@ -55,12 +46,40 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam // Returns new token, otherwise returns error if the token already exists. func (s *tokenStatements) insertToken( ctx context.Context, - txn *sql.Tx, token, localpart string, expiresAtMS int64, ) (err error) { - stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + + // "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + var result = &api.OpenIDToken{ + UserID: localpart, + Token: token, + ExpiresAtMS: expiresAtMS, + } + + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) + + var dbData = OpenIdTokenCosmosData{ + Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token), + Cn: dbCollectionName, + Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Timestamp: time.Now().Unix(), + Object: result, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + config.DatabaseName, + config.TenantName, + dbData, + options) + + if ex != nil { + return ex + } + return } @@ -71,16 +90,39 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes( token string, ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes - err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( - &openIDTokenAttrs.UserID, - &openIDTokenAttrs.ExpiresAtMS, - ) - if err != nil { - if err != sql.ErrNoRows { - log.WithError(err).Error("Unable to retrieve token from the db") - } - return nil, err + + // "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + response := []OpenIdTokenCosmosData{} + var selectOpenIdTokenCosmos = "select * from c where c._cn = @x1 and c._object.Token = @x2" + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": token, + } + var options = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectOpenIdTokenCosmos, params) + var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) + + if ex != nil { + return nil, ex } + if(len(response) == 0) { + return nil, nil + } + + var openIdToken = response[0].Object + openIDTokenAttrs = api.OpenIDTokenAttributes{ + UserID: openIdToken.UserID, + ExpiresAtMS: openIdToken.ExpiresAtMS, + } return &openIDTokenAttrs, nil } diff --git a/userapi/storage/accounts/cosmosdb/profile_table.go b/userapi/storage/accounts/cosmosdb/profile_table.go index 73ffec031..78a815cad 100644 --- a/userapi/storage/accounts/cosmosdb/profile_table.go +++ b/userapi/storage/accounts/cosmosdb/profile_table.go @@ -16,107 +16,186 @@ package cosmosdb import ( "context" - "database/sql" + "errors" "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" ) -const profilesSchema = ` --- Stores data about accounts profiles. -CREATE TABLE IF NOT EXISTS account_profiles ( - -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, - -- The display name for this account - display_name TEXT, - -- The URL of the avatar for this account - avatar_url TEXT -); -` +// const profilesSchema = ` +// -- Stores data about accounts profiles. +// CREATE TABLE IF NOT EXISTS account_profiles ( +// -- The Matrix user ID localpart for this account +// localpart TEXT NOT NULL PRIMARY KEY, +// -- The display name for this account +// display_name TEXT, +// -- The URL of the avatar for this account +// avatar_url TEXT +// ); +// ` -const insertProfileSQL = "" + - "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" - -const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" - -const setAvatarURLSQL = "" + - "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" - -const setDisplayNameSQL = "" + - "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" - -const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" - -type profilesStatements struct { - db *sql.DB - insertProfileStmt *sql.Stmt - selectProfileByLocalpartStmt *sql.Stmt - setAvatarURLStmt *sql.Stmt - setDisplayNameStmt *sql.Stmt - selectProfilesBySearchStmt *sql.Stmt +type ProfileCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Object authtypes.Profile `json:"_object"` } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { +type profilesStatements struct { + db *Database + tableName string +} + +func (s *profilesStatements) prepare(db *Database) (err error) { s.db = db - _, err = db.Exec(profilesSchema) - if err != nil { - return - } - if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { - return - } - if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { - return - } - if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { - return - } - if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { - return - } - if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { - return - } + s.tableName = "account_profiles" return } +func getProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, docId string) (*ProfileCosmosData, error) { + response := ProfileCosmosData{} + var optionsGet = cosmosdbapi.GetGetDocumentOptions(pk) + var _, ex = cosmosdbapi.GetClient(s.db.connection).GetDocument( + ctx, + config.DatabaseName, + config.TenantName, + docId, + optionsGet, + &response) + return &response, ex +} + +func setProfile(s *profilesStatements, ctx context.Context, config cosmosdbapi.Tenant, pk string, profile ProfileCosmosData) (*ProfileCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, profile.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + config.DatabaseName, + config.TenantName, + profile.Id, + &profile, + optionsReplace) + return &profile, ex +} + func (s *profilesStatements) insertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, localpart string, ) error { - _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + + // "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + var result = &authtypes.Profile{ + Localpart: localpart, + } + + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) + + var dbData = ProfileCosmosData{ + Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart), + Cn: dbCollectionName, + Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Timestamp: time.Now().Unix(), + Object: *result, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + config.DatabaseName, + config.TenantName, + dbData, + options) + return err } func (s *profilesStatements) selectProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { - var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( - &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, - ) - if err != nil { - return nil, err + + // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + response := []ProfileCosmosData{} + var selectProfileByLocalpartCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2" + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": localpart, } - return &profile, nil + var options = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectProfileByLocalpartCosmos, params) + var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) + + if ex != nil { + return nil, ex + } + + if len(response) == 0 { + return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(response))) + } + + if len(response) != 1 { + return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(response))) + } + + return &response[0].Object, nil } func (s *profilesStatements) setAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, localpart string, avatarURL string, ) (err error) { - stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - _, err = stmt.ExecContext(ctx, avatarURL, localpart) + + // "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) + + var response, exGet = getProfile(s, ctx, config, pk, docId) + if exGet != nil { + return exGet + } + + response.Object.AvatarURL = avatarURL + + var _, exReplace = setProfile(s, ctx, config, pk, *response) + if exReplace != nil { + return exReplace + } return } func (s *profilesStatements) setDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, localpart string, displayName string, ) (err error) { - stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - _, err = stmt.ExecContext(ctx, displayName, localpart) + + // "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var docId = cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, localpart) + var response, exGet = getProfile(s, ctx, config, pk, docId) + if exGet != nil { + return exGet + } + + response.Object.DisplayName = displayName + + var _, exReplace = setProfile(s, ctx, config, pk, *response) + if exReplace != nil { + return exReplace + } return } @@ -124,20 +203,36 @@ func (s *profilesStatements) selectProfilesBySearch( ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { var profiles []authtypes.Profile - // The fmt.Sprintf directive below is building a parameter for the - // "LIKE" condition in the SQL query. %% escapes the % char, so the - // statement in the end will look like "LIKE %searchString%". - rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) - if err != nil { - return nil, err + + // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + response := []ProfileCosmosData{} + var selectProfileByLocalpartCosmos = "select top @x3 * from c where c._cn = @x1 and contains(c._object.local_part, @x2)" + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": searchString, + "@x3": limit, } - defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") - for rows.Next() { - var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { - return nil, err - } - profiles = append(profiles, profile) + var options = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectProfileByLocalpartCosmos, params) + var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) + + if ex != nil { + return nil, ex } + + for i := 0; i < len(response); i++ { + var responseData = response[i] + profiles = append(profiles, responseData.Object) + } + return profiles, nil } diff --git a/userapi/storage/accounts/cosmosdb/storage.go b/userapi/storage/accounts/cosmosdb/storage.go index 2e9f2888d..3fc2e7d88 100644 --- a/userapi/storage/accounts/cosmosdb/storage.go +++ b/userapi/storage/accounts/cosmosdb/storage.go @@ -15,29 +15,27 @@ package cosmosdb import ( - "github.com/matrix-org/dendrite/internal/cosmosdbutil" "context" - "database/sql" "encoding/json" "errors" "strconv" - "sync" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + // "sync" "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" ) // Database represents an account database type Database struct { - db *sql.DB - writer sqlutil.Writer - sqlutil.PartitionOffsetStatements accounts accountsStatements profiles profilesStatements @@ -48,55 +46,57 @@ type Database struct { bcryptCost int openIDTokenLifetimeMS int64 - accountsMu sync.Mutex - profilesMu sync.Mutex - accountDatasMu sync.Mutex - threepidsMu sync.Mutex + databaseName string + connection cosmosdbapi.CosmosConnection } // NewDatabase creates a new accounts and profiles database func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { - dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } + connString := cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString) + connMap := cosmosdbutil.GetConnectionProperties(string(connString)) + accountEndpoint := connMap["AccountEndpoint"] + accountKey := connMap["AccountKey"] + conn := cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey) + d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewExclusiveWriter(), - bcryptCost: bcryptCost, - openIDTokenLifetimeMS: openIDTokenLifetimeMS, + serverName: serverName, + databaseName: "userapi", + connection: conn, + // db: db, + // writer: sqlutil.NewExclusiveWriter(), + // bcryptCost: bcryptCost, + // openIDTokenLifetimeMS: openIDTokenLifetimeMS, } // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.accounts.execSchema(db); err != nil { - return nil, err - } - m := sqlutil.NewMigrations() - deltas.LoadIsActive(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } + // if err = d.accounts.execSchema(db); err != nil { + // return nil, err + // } + // m := sqlutil.NewMigrations() + // deltas.LoadIsActive(m) + // if err = m.RunDeltas(db, dbProperties); err != nil { + // return nil, err + // } - partitions := sqlutil.PartitionOffsetStatements{} - if err = partitions.Prepare(db, d.writer, "account"); err != nil { + // partitions := sqlutil.PartitionOffsetStatements{} + // if err = partitions.Prepare(db, d.writer, "account"); err != nil { + // return nil, err + // } + var err error + if err = d.accounts.prepare(d, serverName); err != nil { return nil, err } - if err = d.accounts.prepare(db, serverName); err != nil { + if err = d.profiles.prepare(d); err != nil { return nil, err } - if err = d.profiles.prepare(db); err != nil { + if err = d.accountDatas.prepare(d); err != nil { return nil, err } - if err = d.accountDatas.prepare(db); err != nil { + if err = d.threepids.prepare(d); err != nil { return nil, err } - if err = d.threepids.prepare(db); err != nil { - return nil, err - } - if err = d.openIDTokens.prepare(db, serverName); err != nil { + if err = d.openIDTokens.prepare(d, serverName); err != nil { return nil, err } @@ -131,11 +131,11 @@ func (d *Database) GetProfileByLocalpart( func (d *Database) SetAvatarURL( ctx context.Context, localpart string, avatarURL string, ) error { - d.profilesMu.Lock() - defer d.profilesMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL) - }) + // d.profilesMu.Lock() + // defer d.profilesMu.Unlock() + // return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + // }) + return d.profiles.setAvatarURL(ctx, localpart, avatarURL) } // SetDisplayName updates the display name of the profile associated with the given @@ -143,11 +143,12 @@ func (d *Database) SetAvatarURL( func (d *Database) SetDisplayName( ctx context.Context, localpart string, displayName string, ) error { - d.profilesMu.Lock() - defer d.profilesMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.profiles.setDisplayName(ctx, txn, localpart, displayName) - }) + // d.profilesMu.Lock() + // defer d.profilesMu.Unlock() + // return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + // return d.profiles.setDisplayName(ctx, txn, localpart, displayName) + // }) + return d.profiles.setDisplayName(ctx, localpart, displayName) } // SetPassword sets the account password to the given hash. @@ -170,22 +171,23 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er // when the first txn upgrades to a write txn. We also need to lock the account creation else we can // race with CreateAccount // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. - d.profilesMu.Lock() - d.accountDatasMu.Lock() - d.accountsMu.Lock() - defer d.profilesMu.Unlock() - defer d.accountDatasMu.Unlock() - defer d.accountsMu.Unlock() - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart := strconv.FormatInt(numLocalpart, 10) - acc, err = d.createAccount(ctx, txn, localpart, "", "") - return err - }) + + // d.profilesMu.Lock() + // d.accountDatasMu.Lock() + // d.accountsMu.Lock() + // defer d.profilesMu.Unlock() + // defer d.accountDatasMu.Unlock() + // defer d.accountsMu.Unlock() + // err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + // }) + + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx) + if err != nil { + return nil, err + } + localpart := strconv.FormatInt(numLocalpart, 10) + acc, err = d.createAccount(ctx, localpart, "", "") return acc, err } @@ -196,23 +198,25 @@ func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (acc *api.Account, err error) { // Create one account at a time else we can get 'database is locked'. - d.profilesMu.Lock() - d.accountDatasMu.Lock() - d.accountsMu.Lock() - defer d.profilesMu.Unlock() - defer d.accountDatasMu.Unlock() - defer d.accountsMu.Unlock() - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) - return err - }) - return + // d.profilesMu.Lock() + // d.accountDatasMu.Lock() + // d.accountsMu.Lock() + // defer d.profilesMu.Unlock() + // defer d.accountDatasMu.Unlock() + // defer d.accountsMu.Unlock() + // err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + // acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + // return err + // }) + + acc, err = d.createAccount(ctx, localpart, plaintextPassword, appserviceID) + return acc, err } // WARNING! This function assumes that the relevant mutexes have already // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, + ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (*api.Account, error) { var err error var account *api.Account @@ -224,13 +228,13 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { + if account, err = d.accounts.insertAccount(ctx, localpart, hash, appserviceID); err != nil { return nil, sqlutil.ErrUserExists } - if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { + if err = d.profiles.insertProfile(ctx, localpart); err != nil { return nil, err } - if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ + if err = d.accountDatas.insertAccountData(ctx, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -252,11 +256,11 @@ func (d *Database) createAccount( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { - d.accountDatasMu.Lock() - defer d.accountDatasMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) - }) + // d.accountDatasMu.Lock() + // defer d.accountDatasMu.Unlock() + // return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + // }) + return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) } // GetAccountData returns account data related to a given localpart @@ -286,7 +290,7 @@ func (d *Database) GetAccountDataByType( func (d *Database) GetNewNumericLocalpart( ctx context.Context, ) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx, nil) + return d.accounts.selectNewNumericLocalpart(ctx) } func (d *Database) hashPassword(plaintext string) (hash string, err error) { @@ -305,22 +309,23 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use") func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid, localpart, medium string, ) (err error) { - d.threepidsMu.Lock() - defer d.threepidsMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID( - ctx, txn, threepid, medium, - ) - if err != nil { - return err - } + // d.threepidsMu.Lock() + // defer d.threepidsMu.Unlock() + // return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + // }) - if len(user) > 0 { - return Err3PIDInUse - } + user, err := d.threepids.selectLocalpartForThreePID( + ctx, threepid, medium, + ) + if err != nil { + return err + } - return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) - }) + if len(user) > 0 { + return Err3PIDInUse + } + + return d.threepids.insertThreePID(ctx, threepid, medium, localpart) } // RemoveThreePIDAssociation removes the association involving a given third-party @@ -330,11 +335,11 @@ func (d *Database) SaveThreePIDAssociation( func (d *Database) RemoveThreePIDAssociation( ctx context.Context, threepid string, medium string, ) (err error) { - d.threepidsMu.Lock() - defer d.threepidsMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.threepids.deleteThreePID(ctx, txn, threepid, medium) - }) + // d.threepidsMu.Lock() + // defer d.threepidsMu.Unlock() + // return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + // }) + return d.threepids.deleteThreePID(ctx, threepid, medium) } // GetLocalpartForThreePID looks up the localpart associated with a given third-party @@ -345,7 +350,7 @@ func (d *Database) RemoveThreePIDAssociation( func (d *Database) GetLocalpartForThreePID( ctx context.Context, threepid string, medium string, ) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) + return d.threepids.selectLocalpartForThreePID(ctx, threepid, medium) } // GetThreePIDsForLocalpart looks up the third-party identifiers associated with @@ -362,11 +367,11 @@ func (d *Database) GetThreePIDsForLocalpart( // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) - if err == sql.ErrNoRows { - return true, nil - } - return false, err + response, err := d.accounts.selectAccountByLocalpart(ctx, localpart) + // if err == sql.ErrNoRows { + // return true, nil + // } + return response == nil, err } // GetAccountByLocalpart returns the account associated with the given localpart. @@ -395,9 +400,9 @@ func (d *Database) CreateOpenIDToken( token, localpart string, ) (int64, error) { expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS - err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) - }) + // err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + // }) + var err = d.openIDTokens.insertToken(ctx, token, localpart, expiresAtMS) return expiresAtMS, err } diff --git a/userapi/storage/accounts/cosmosdb/threepid_table.go b/userapi/storage/accounts/cosmosdb/threepid_table.go index 0d37dda0e..2fb941fe0 100644 --- a/userapi/storage/accounts/cosmosdb/threepid_table.go +++ b/userapi/storage/accounts/cosmosdb/threepid_table.go @@ -16,118 +16,186 @@ package cosmosdb import ( "context" - "database/sql" + "fmt" + "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) -const threepidSchema = ` --- Stores data about third party identifiers -CREATE TABLE IF NOT EXISTS account_threepid ( - -- The third party identifier - threepid TEXT NOT NULL, - -- The 3PID medium - medium TEXT NOT NULL DEFAULT 'email', - -- The localpart of the Matrix user ID associated to this 3PID - localpart TEXT NOT NULL, +// const threepidSchema = ` +// -- Stores data about third party identifiers +// CREATE TABLE IF NOT EXISTS account_threepid ( +// -- The third party identifier +// threepid TEXT NOT NULL, +// -- The 3PID medium +// medium TEXT NOT NULL DEFAULT 'email', +// -- The localpart of the Matrix user ID associated to this 3PID +// localpart TEXT NOT NULL, - PRIMARY KEY(threepid, medium) -); +// PRIMARY KEY(threepid, medium) +// ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); -` - -const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" - -const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" - -const insertThreePIDSQL = "" + - "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" - -const deleteThreePIDSQL = "" + - "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" - -type threepidStatements struct { - db *sql.DB - selectLocalpartForThreePIDStmt *sql.Stmt - selectThreePIDsForLocalpartStmt *sql.Stmt - insertThreePIDStmt *sql.Stmt - deleteThreePIDStmt *sql.Stmt +type ThreePIDObject struct { + Localpart string `json:"local_part"` + ThreePID string `json:"three_pid"` + Medium string `json:"medium"` } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(threepidSchema) - if err != nil { - return - } - if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { - return - } - if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { - return - } - if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { - return - } - if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { - return - } +type ThreePIDCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Object ThreePIDObject `json:"_object"` +} +type threepidStatements struct { + db *Database + tableName string +} + +func (s *threepidStatements) prepare(db *Database) (err error) { + s.db = db + s.tableName = "account_threepid" return } func (s *threepidStatements) selectLocalpartForThreePID( - ctx context.Context, txn *sql.Tx, threepid string, medium string, + ctx context.Context, threepid string, medium string, ) (localpart string, err error) { - stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) - err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) - if err == sql.ErrNoRows { + + // "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + response := []ThreePIDCosmosData{} + var selectLocalPartThreePIDCosmos = "select * from c where c._cn = @x1 and c._object.three_pid = @x2 and c._object.medium = @x3" + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": threepid, + "@x3": medium, + } + var options = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectLocalPartThreePIDCosmos, params) + var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) + + if ex != nil { + return "", ex + } + + if len(response) == 0 { return "", nil } - return + + return response[0].Object.Localpart, nil } func (s *threepidStatements) selectThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) - if err != nil { - return + + // "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) + var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + response := []ThreePIDCosmosData{} + var selectThreePIDLocalPartCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2" + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": localpart, + } + var options = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectThreePIDLocalPartCosmos, params) + var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) + + if ex != nil { + return threepids, ex + } + + if len(response) == 0 { + return threepids, nil } - defer internal.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed") threepids = []authtypes.ThreePID{} - for rows.Next() { - var threepid string - var medium string - if err = rows.Scan(&threepid, &medium); err != nil { - return - } + for _, item := range response { threepids = append(threepids, authtypes.ThreePID{ - Address: threepid, - Medium: medium, + Address: item.Object.ThreePID, + Medium: item.Object.Medium, }) } - return threepids, rows.Err() + return threepids, nil } func (s *threepidStatements) insertThreePID( - ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, + ctx context.Context, threepid, medium, localpart string, ) (err error) { - stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) - return err + + // "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + var result = ThreePIDObject{ + Localpart: localpart, + Medium: medium, + ThreePID: threepid, + } + + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) + + id := fmt.Sprintf("%s_%s", threepid, medium) + var dbData = ThreePIDCosmosData{ + Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id), + Cn: dbCollectionName, + Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Timestamp: time.Now().Unix(), + Object: result, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + config.DatabaseName, + config.TenantName, + dbData, + options) + + if err != nil { + return err + } + return } func (s *threepidStatements) deleteThreePID( - ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { - stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium) - return err + ctx context.Context, threepid string, medium string) (err error) { + + // "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + var config = cosmosdbapi.DefaultConfig() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) + id := fmt.Sprintf("%s_%s", threepid, medium) + pk := cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) + var options = cosmosdbapi.GetDeleteDocumentOptions(pk) + _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + config.DatabaseName, + config.TenantName, + id, + options) + + if err != nil { + return err + } + return }