diff --git a/userapi/storage/accounts/cosmosdb/account_data_table.go b/userapi/storage/accounts/cosmosdb/account_data_table.go index c73a98c93..67a406435 100644 --- a/userapi/storage/accounts/cosmosdb/account_data_table.go +++ b/userapi/storage/accounts/cosmosdb/account_data_table.go @@ -39,16 +39,16 @@ import ( // ); // ` -type AccountCosmosAccountData struct { - Id string `json:"id"` - Pk string `json:"_pk"` - Cn string `json:"_cn"` - ETag string `json:"_etag"` - Timestamp int64 `json:"_ts"` - Object AccountData `json:"_object"` +type AccountDataCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + AccountData AccountDataCosmos `json:"mx_userapi_accountdata"` } -type AccountData struct { +type AccountDataCosmos struct { LocalPart string `json:"local_part"` RoomId string `json:"room_id"` Type string `json:"type"` @@ -56,12 +56,17 @@ type AccountData struct { } type accountDataStatements struct { - db *Database - tableName string + db *Database + // insertAccountDataStmt *sql.Stmt + selectAccountDataStmt string + selectAccountDataByTypeStmt string + tableName string } func (s *accountDataStatements) prepare(db *Database) (err error) { s.db = db + s.selectAccountDataStmt = "select * from c where c._cn = @x1 and c.mx_userapi_accountdata.local_part = @x2" + s.selectAccountDataByTypeStmt = "select * from c where c._cn = @x1 and c.mx_userapi_accountdata.local_part = @x2 and c.mx_userapi_accountdata.room_id = @x3 and c.mx_userapi_accountdata.type = @x4" s.tableName = "account_data" return } @@ -72,7 +77,7 @@ func (s *accountDataStatements) insertAccountData( // INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) // ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 - var result = AccountData{ + var result = AccountDataCosmos{ LocalPart: localpart, RoomId: roomID, Type: dataType, @@ -88,12 +93,12 @@ func (s *accountDataStatements) insertAccountData( id = fmt.Sprintf("%s_%s_%s", result.LocalPart, result.RoomId, result.Type) } - var dbData = AccountCosmosAccountData{ - Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id), - Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), - Timestamp: time.Now().Unix(), - Object: result, + var dbData = AccountDataCosmosData{ + Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, id), + Cn: dbCollectionName, + Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Timestamp: time.Now().Unix(), + AccountData: result, } var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) @@ -118,14 +123,13 @@ func (s *accountDataStatements) selectAccountData( var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) - response := []AccountCosmosAccountData{} - var selectAccountDataCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2" + response := []AccountDataCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, } var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectAccountDataCosmos, params) + var query = cosmosdbapi.GetQuery(s.selectAccountDataStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, config.DatabaseName, @@ -143,14 +147,14 @@ func (s *accountDataStatements) selectAccountData( for i := 0; i < len(response); i++ { var row = response[i] - var roomID = row.Object.RoomId + var roomID = row.AccountData.RoomId if roomID != "" { - if _, ok := rooms[row.Object.RoomId]; !ok { + if _, ok := rooms[row.AccountData.RoomId]; !ok { rooms[roomID] = map[string]json.RawMessage{} } - rooms[roomID][row.Object.Type] = row.Object.Content + rooms[roomID][row.AccountData.Type] = row.AccountData.Content } else { - global[row.Object.Type] = row.Object.Content + global[row.AccountData.Type] = row.AccountData.Content } } @@ -166,8 +170,7 @@ func (s *accountDataStatements) selectAccountDataByType( var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) - response := []AccountCosmosAccountData{} - var selectAccountDataCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2 and c._object.room_id = @x3 and c._object.type = @x4" + response := []AccountDataCosmosData{} params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, @@ -175,7 +178,7 @@ func (s *accountDataStatements) selectAccountDataByType( "@x4": dataType, } var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectAccountDataCosmos, params) + var query = cosmosdbapi.GetQuery(s.selectAccountDataByTypeStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, config.DatabaseName, @@ -192,7 +195,7 @@ func (s *accountDataStatements) selectAccountDataByType( return data, nil } - bytes = response[0].Object.Content + bytes = response[0].AccountData.Content data = json.RawMessage(bytes) return diff --git a/userapi/storage/accounts/cosmosdb/accounts_table.go b/userapi/storage/accounts/cosmosdb/accounts_table.go index 076068d1f..1f8dc866f 100644 --- a/userapi/storage/accounts/cosmosdb/accounts_table.go +++ b/userapi/storage/accounts/cosmosdb/accounts_table.go @@ -45,20 +45,23 @@ import ( // ); // ` -type AccountExtended struct { - IsDeactivated bool `json:"is_deactivated"` - PasswordHash string `json:"password_hash"` - Created int64 `json:"created_ts"` +type AccountCosmos struct { + UserID string `json:"user_id"` + Localpart string `json:"local_part"` + ServerName gomatrixserverlib.ServerName `json:"server_name"` + AppServiceID string `json:"app_service_id"` + IsDeactivated bool `json:"is_deactivated"` + PasswordHash string `json:"password_hash"` + Created int64 `json:"created_ts"` } type AccountCosmosData struct { - Id string `json:"id"` - Pk string `json:"_pk"` - Cn string `json:"_cn"` - ETag string `json:"_etag"` - Timestamp int64 `json:"_ts"` - Object api.Account `json:"_object"` - ObjectExtended AccountExtended `json:"_object_extended"` + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Account AccountCosmos `json:"mx_userapi_account"` } type AccountCosmosUserCount struct { @@ -66,13 +69,19 @@ type AccountCosmosUserCount struct { } type accountsStatements struct { - db *Database - tableName string - serverName gomatrixserverlib.ServerName + db *Database + selectAccountByLocalpartStmt string + selectPasswordHashStmt string + selectNewNumericLocalpartStmt string + tableName string + serverName gomatrixserverlib.ServerName } func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db + s.selectPasswordHashStmt = "select * from c where c._cn = @x1 and c.mx_userapi_account.local_part = @x2 and c.mx_userapi_account.is_deactivated = false" + s.selectAccountByLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_account.local_part = @x2" + s.selectNewNumericLocalpartStmt = "select count(c._ts) as usercount from c where c._cn = @x1" s.tableName = "account_accounts" s.serverName = server return @@ -104,6 +113,24 @@ func setAccount(s *accountsStatements, ctx context.Context, config cosmosdbapi.T return &response, ex } +func mapFromAccount(db AccountCosmos) api.Account { + return api.Account{ + AppServiceID: db.AppServiceID, + Localpart: db.Localpart, + ServerName: db.ServerName, + UserID: db.UserID, + } +} + +func mapToAccount(api api.Account) AccountCosmos { + return AccountCosmos{ + AppServiceID: api.AppServiceID, + Localpart: api.Localpart, + ServerName: api.ServerName, + UserID: api.UserID, + } +} + // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. @@ -120,22 +147,21 @@ func (s *accountsStatements) insertAccount( AppServiceID: appserviceID, } - var extended = AccountExtended{ - IsDeactivated: false, - PasswordHash: hash, - Created: createdTimeMS, - } + //Add the extra properties not on the API + var data = mapToAccount(result) + data.Created = createdTimeMS + data.PasswordHash = hash + data.IsDeactivated = false var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var dbData = AccountCosmosData{ - Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart), - Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), - Timestamp: time.Now().Unix(), - Object: result, - ObjectExtended: extended, + Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Localpart), + Cn: dbCollectionName, + Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Timestamp: time.Now().Unix(), + Account: data, } var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) @@ -168,7 +194,7 @@ func (s *accountsStatements) updatePassword( return exGet } - response.ObjectExtended.PasswordHash = passwordHash + response.Account.PasswordHash = passwordHash var _, exReplace = setAccount(s, ctx, config, pk, *response) if exReplace != nil { @@ -192,7 +218,7 @@ func (s *accountsStatements) deactivateAccount( return exGet } - response.ObjectExtended.IsDeactivated = true + response.Account.IsDeactivated = true var _, exReplace = setAccount(s, ctx, config, pk, *response) if exReplace != nil { @@ -210,13 +236,12 @@ func (s *accountsStatements) selectPasswordHash( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) response := []AccountCosmosData{} - var selectPasswordHashCosmos = "select * from c where c._cn = @x1 and c._object.Localpart = @x2 and c._object_extended.is_deactivated = false" params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, } var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectPasswordHashCosmos, params) + var query = cosmosdbapi.GetQuery(s.selectPasswordHashStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, config.DatabaseName, @@ -237,7 +262,7 @@ func (s *accountsStatements) selectPasswordHash( return "", errors.New(fmt.Sprintf("Localpart %s has multiple entries", localpart)) } - return response[0].ObjectExtended.PasswordHash, nil + return response[0].Account.PasswordHash, nil } func (s *accountsStatements) selectAccountByLocalpart( @@ -250,13 +275,12 @@ func (s *accountsStatements) selectAccountByLocalpart( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) response := []AccountCosmosData{} - var selectPasswordHashCosmos = "select * from c where c._cn = @x1 and c._object.Localpart = @x2" params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, } var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectPasswordHashCosmos, params) + var query = cosmosdbapi.GetQuery(s.selectAccountByLocalpartStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, config.DatabaseName, @@ -273,7 +297,7 @@ func (s *accountsStatements) selectAccountByLocalpart( return nil, nil } - acc = response[0].Object + acc = mapFromAccount(response[0].Account) acc.UserID = userutil.MakeUserID(localpart, s.serverName) acc.ServerName = s.serverName @@ -289,12 +313,11 @@ func (s *accountsStatements) selectNewNumericLocalpart( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) var response []AccountCosmosUserCount - var selectCountCosmos = "select count(c._ts) as usercount from c where c._cn = @x1" params := map[string]interface{}{ "@x1": dbCollectionName, } var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectCountCosmos, params) + var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, config.DatabaseName, diff --git a/userapi/storage/accounts/cosmosdb/openid_table.go b/userapi/storage/accounts/cosmosdb/openid_table.go index 9224dd134..3b62d244b 100644 --- a/userapi/storage/accounts/cosmosdb/openid_table.go +++ b/userapi/storage/accounts/cosmosdb/openid_table.go @@ -1,42 +1,70 @@ package cosmosdb import ( - "time" - "github.com/matrix-org/dendrite/internal/cosmosdbapi" "context" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) -const openIDTokenSchema = ` --- Stores data about accounts. -CREATE TABLE IF NOT EXISTS open_id_tokens ( - -- The value of the token issued to a user - token TEXT NOT NULL PRIMARY KEY, - -- The Matrix user ID for this account - localpart TEXT NOT NULL, - -- When the token expires, as a unix timestamp (ms resolution). - token_expires_at_ms BIGINT NOT NULL -); -` +// const openIDTokenSchema = ` +// -- Stores data about accounts. +// CREATE TABLE IF NOT EXISTS open_id_tokens ( +// -- The value of the token issued to a user +// token TEXT NOT NULL PRIMARY KEY, +// -- The Matrix user ID for this account +// localpart TEXT NOT NULL, +// -- When the token expires, as a unix timestamp (ms resolution). +// token_expires_at_ms BIGINT NOT NULL +// ); +// ` + +// OpenIDToken represents an OpenID token +type OpenIDTokenCosmos struct { + Token string `json:"token"` + UserID string `json:"user_id"` + ExpiresAtMS int64 `json:"expires_at"` +} + type OpenIdTokenCosmosData struct { - Id string `json:"id"` - Pk string `json:"_pk"` - Cn string `json:"_cn"` - ETag string `json:"_etag"` - Timestamp int64 `json:"_ts"` - Object *api.OpenIDToken `json:"_object"` + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + OpenIdToken OpenIDTokenCosmos `json:"mx_userapi_openidtoken"` } type tokenStatements struct { - db *Database - tableName string - serverName gomatrixserverlib.ServerName + db *Database + // insertTokenStmt *sql.Stmt + selectTokenStmt string + tableName string + serverName gomatrixserverlib.ServerName +} + +func mapFromToken(db OpenIDTokenCosmos) api.OpenIDToken { + return api.OpenIDToken{ + ExpiresAtMS: db.ExpiresAtMS, + Token: db.Token, + UserID: db.UserID, + } +} + +func mapToToken(api api.OpenIDToken) OpenIDTokenCosmos { + return OpenIDTokenCosmos{ + ExpiresAtMS: api.ExpiresAtMS, + Token: api.Token, + UserID: api.UserID, + } } func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db + s.selectTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_openidtoken.token = @x2" s.tableName = "open_id_tokens" s.serverName = server return @@ -52,29 +80,29 @@ func (s *tokenStatements) insertToken( // "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" var result = &api.OpenIDToken{ - UserID: localpart, - Token: token, - ExpiresAtMS: expiresAtMS, + UserID: localpart, + Token: token, + ExpiresAtMS: expiresAtMS, } var config = cosmosdbapi.DefaultConfig() var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) var dbData = OpenIdTokenCosmosData{ - Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token), - Cn: dbCollectionName, - Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), - Timestamp: time.Now().Unix(), - Object: result, + Id: cosmosdbapi.GetDocumentId(config.TenantName, dbCollectionName, result.Token), + Cn: dbCollectionName, + Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), + Timestamp: time.Now().Unix(), + OpenIdToken: mapToToken(*result), } var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).CreateDocument( - ctx, - config.DatabaseName, - config.TenantName, - dbData, - options) + ctx, + config.DatabaseName, + config.TenantName, + dbData, + options) if ex != nil { return ex @@ -96,32 +124,31 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) response := []OpenIdTokenCosmosData{} - var selectOpenIdTokenCosmos = "select * from c where c._cn = @x1 and c._object.Token = @x2" params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": token, } var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectOpenIdTokenCosmos, params) + var query = cosmosdbapi.GetQuery(s.selectTokenStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - config.DatabaseName, - config.TenantName, - query, - &response, - options) + ctx, + config.DatabaseName, + config.TenantName, + query, + &response, + options) if ex != nil { return nil, ex } - if(len(response) == 0) { + if len(response) == 0 { return nil, nil } - var openIdToken = response[0].Object + var openIdToken = response[0].OpenIdToken openIDTokenAttrs = api.OpenIDTokenAttributes{ - UserID: openIdToken.UserID, + UserID: openIdToken.UserID, ExpiresAtMS: openIdToken.ExpiresAtMS, } return &openIDTokenAttrs, nil diff --git a/userapi/storage/accounts/cosmosdb/profile_table.go b/userapi/storage/accounts/cosmosdb/profile_table.go index 78a815cad..9b47acc09 100644 --- a/userapi/storage/accounts/cosmosdb/profile_table.go +++ b/userapi/storage/accounts/cosmosdb/profile_table.go @@ -37,22 +37,52 @@ import ( // ); // ` +// Profile represents the profile for a Matrix account. +type ProfileCosmos struct { + Localpart string `json:"local_part"` + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` +} + type ProfileCosmosData struct { - Id string `json:"id"` - Pk string `json:"_pk"` - Cn string `json:"_cn"` - ETag string `json:"_etag"` - Timestamp int64 `json:"_ts"` - Object authtypes.Profile `json:"_object"` + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Profile ProfileCosmos `json:"mx_userapi_profile"` } type profilesStatements struct { - db *Database - tableName string + db *Database + // insertProfileStmt *sql.Stmt + selectProfileByLocalpartStmt string + // setAvatarURLStmt *sql.Stmt + // setDisplayNameStmt *sql.Stmt + selectProfilesBySearchStmt string + tableName string +} + +func mapFromProfile(db ProfileCosmos) authtypes.Profile { + return authtypes.Profile{ + AvatarURL: db.AvatarURL, + DisplayName: db.DisplayName, + Localpart: db.Localpart, + } +} + +func mapToProfile(api authtypes.Profile) ProfileCosmos { + return ProfileCosmos{ + AvatarURL: api.AvatarURL, + DisplayName: api.DisplayName, + Localpart: api.Localpart, + } } func (s *profilesStatements) prepare(db *Database) (err error) { s.db = db + s.selectProfileByLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_profile.local_part = @x2" + s.selectProfilesBySearchStmt = "select top @x3 * from c where c._cn = @x1 and contains(c.mx_userapi_profile.local_part, @x2)" s.tableName = "account_profiles" return } @@ -99,7 +129,7 @@ func (s *profilesStatements) insertProfile( Cn: dbCollectionName, Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), Timestamp: time.Now().Unix(), - Object: *result, + Profile: mapToProfile(*result), } var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) @@ -122,13 +152,12 @@ func (s *profilesStatements) selectProfileByLocalpart( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) response := []ProfileCosmosData{} - var selectProfileByLocalpartCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2" params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, } var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectProfileByLocalpartCosmos, params) + var query = cosmosdbapi.GetQuery(s.selectProfileByLocalpartStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, config.DatabaseName, @@ -149,7 +178,8 @@ func (s *profilesStatements) selectProfileByLocalpart( return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(response))) } - return &response[0].Object, nil + var result = mapFromProfile(response[0].Profile) + return &result, nil } func (s *profilesStatements) setAvatarURL( @@ -167,7 +197,7 @@ func (s *profilesStatements) setAvatarURL( return exGet } - response.Object.AvatarURL = avatarURL + response.Profile.AvatarURL = avatarURL var _, exReplace = setProfile(s, ctx, config, pk, *response) if exReplace != nil { @@ -190,7 +220,7 @@ func (s *profilesStatements) setDisplayName( return exGet } - response.Object.DisplayName = displayName + response.Profile.DisplayName = displayName var _, exReplace = setProfile(s, ctx, config, pk, *response) if exReplace != nil { @@ -209,14 +239,13 @@ func (s *profilesStatements) selectProfilesBySearch( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) response := []ProfileCosmosData{} - var selectProfileByLocalpartCosmos = "select top @x3 * from c where c._cn = @x1 and contains(c._object.local_part, @x2)" params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": searchString, "@x3": limit, } var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectProfileByLocalpartCosmos, params) + var query = cosmosdbapi.GetQuery(s.selectProfilesBySearchStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, config.DatabaseName, @@ -231,7 +260,7 @@ func (s *profilesStatements) selectProfilesBySearch( for i := 0; i < len(response); i++ { var responseData = response[i] - profiles = append(profiles, responseData.Object) + profiles = append(profiles, mapFromProfile(responseData.Profile)) } return profiles, nil diff --git a/userapi/storage/accounts/cosmosdb/threepid_table.go b/userapi/storage/accounts/cosmosdb/threepid_table.go index 2fb941fe0..94bd67ce3 100644 --- a/userapi/storage/accounts/cosmosdb/threepid_table.go +++ b/userapi/storage/accounts/cosmosdb/threepid_table.go @@ -37,7 +37,7 @@ import ( // PRIMARY KEY(threepid, medium) // ); -type ThreePIDObject struct { +type ThreePIDCosmos struct { Localpart string `json:"local_part"` ThreePID string `json:"three_pid"` Medium string `json:"medium"` @@ -49,16 +49,22 @@ type ThreePIDCosmosData struct { Cn string `json:"_cn"` ETag string `json:"_etag"` Timestamp int64 `json:"_ts"` - Object ThreePIDObject `json:"_object"` + ThreePID ThreePIDCosmos `json:"mx_userapi_threepid"` } type threepidStatements struct { - db *Database + db *Database + selectLocalpartForThreePIDStmt string + selectThreePIDsForLocalpartStmt string + // insertThreePIDStmt *sql.Stmt + // deleteThreePIDStmt *sql.Stmt tableName string } func (s *threepidStatements) prepare(db *Database) (err error) { s.db = db + s.selectLocalpartForThreePIDStmt = "select * from c where c._cn = @x1 and c.mx_userapi_threepid.three_pid = @x2 and c.mx_userapi_threepid.medium = @x3" + s.selectThreePIDsForLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_threepid.local_part = @x2" s.tableName = "account_threepid" return } @@ -72,14 +78,13 @@ func (s *threepidStatements) selectLocalpartForThreePID( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) response := []ThreePIDCosmosData{} - var selectLocalPartThreePIDCosmos = "select * from c where c._cn = @x1 and c._object.three_pid = @x2 and c._object.medium = @x3" params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": threepid, "@x3": medium, } var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectLocalPartThreePIDCosmos, params) + var query = cosmosdbapi.GetQuery(s.selectLocalpartForThreePIDStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, config.DatabaseName, @@ -96,7 +101,7 @@ func (s *threepidStatements) selectLocalpartForThreePID( return "", nil } - return response[0].Object.Localpart, nil + return response[0].ThreePID.Localpart, nil } func (s *threepidStatements) selectThreePIDsForLocalpart( @@ -108,13 +113,12 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) var pk = cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName) response := []ThreePIDCosmosData{} - var selectThreePIDLocalPartCosmos = "select * from c where c._cn = @x1 and c._object.local_part = @x2" params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": localpart, } var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectThreePIDLocalPartCosmos, params) + var query = cosmosdbapi.GetQuery(s.selectThreePIDsForLocalpartStmt, params) var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, config.DatabaseName, @@ -134,8 +138,8 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( threepids = []authtypes.ThreePID{} for _, item := range response { threepids = append(threepids, authtypes.ThreePID{ - Address: item.Object.ThreePID, - Medium: item.Object.Medium, + Address: item.ThreePID.ThreePID, + Medium: item.ThreePID.Medium, }) } return threepids, nil @@ -146,7 +150,7 @@ func (s *threepidStatements) insertThreePID( ) (err error) { // "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" - var result = ThreePIDObject{ + var result = ThreePIDCosmos{ Localpart: localpart, Medium: medium, ThreePID: threepid, @@ -161,7 +165,7 @@ func (s *threepidStatements) insertThreePID( Cn: dbCollectionName, Pk: cosmosdbapi.GetPartitionKey(config.TenantName, dbCollectionName), Timestamp: time.Now().Unix(), - Object: result, + ThreePID: result, } var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)