mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 08:13:09 -06:00
- Update the Config to use Cosmos for the sync API - Ensure Cosmos DocId does not contain escape chars - Create a shared Cosmos PartitionOffet table and refactor to use it - Hardcode the "nafka" Connstring to use the "file:naffka.db" - Create seq documents for each of the nextXXXID methods
471 lines
15 KiB
Go
471 lines
15 KiB
Go
// Copyright 2017 Vector Creations Ltd
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package cosmosdb
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
|
|
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
|
|
|
"github.com/matrix-org/dendrite/userapi/api"
|
|
|
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
|
"github.com/matrix-org/gomatrixserverlib"
|
|
)
|
|
|
|
// const devicesSchema = `
|
|
// -- This sequence is used for automatic allocation of session_id.
|
|
// -- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
|
|
|
|
// -- Stores data about devices.
|
|
// CREATE TABLE IF NOT EXISTS device_devices (
|
|
// access_token TEXT PRIMARY KEY,
|
|
// session_id INTEGER,
|
|
// device_id TEXT ,
|
|
// localpart TEXT ,
|
|
// created_ts BIGINT,
|
|
// display_name TEXT,
|
|
// last_seen_ts BIGINT,
|
|
// ip TEXT,
|
|
// user_agent TEXT,
|
|
|
|
// UNIQUE (localpart, device_id)
|
|
// );
|
|
// `
|
|
|
|
type DeviceCosmos struct {
|
|
ID string `json:"device_id"`
|
|
UserID string `json:"user_id"`
|
|
// The access_token granted to this device.
|
|
// This uniquely identifies the device from all other devices and clients.
|
|
AccessToken string `json:"access_token"`
|
|
// The unique ID of the session identified by the access token.
|
|
// Can be used as a secure substitution in places where data needs to be
|
|
// associated with access tokens.
|
|
SessionID int64 `json:"session_id"`
|
|
DisplayName string `json:"display_name"`
|
|
LastSeenTS int64 `json:"last_seen_ts"`
|
|
LastSeenIP string `json:"last_seen_ip"`
|
|
Localpart string `json:"local_part"`
|
|
UserAgent string `json:"user_agent"`
|
|
// If the device is for an appservice user,
|
|
// this is the appservice ID.
|
|
AppserviceID string `json:"app_service_id"`
|
|
}
|
|
|
|
type DeviceCosmosData struct {
|
|
Id string `json:"id"`
|
|
Pk string `json:"_pk"`
|
|
Cn string `json:"_cn"`
|
|
ETag string `json:"_etag"`
|
|
Timestamp int64 `json:"_ts"`
|
|
Device DeviceCosmos `json:"mx_userapi_device"`
|
|
}
|
|
|
|
type DeviceCosmosSessionCount struct {
|
|
SessionCount int64 `json:"sessioncount"`
|
|
}
|
|
|
|
type devicesStatements struct {
|
|
db *Database
|
|
selectDevicesCountStmt string
|
|
selectDeviceByTokenStmt string
|
|
// selectDeviceByIDStmt *sql.Stmt
|
|
selectDevicesByIDStmt string
|
|
selectDevicesByLocalpartStmt string
|
|
selectDevicesByLocalpartExceptIDStmt string
|
|
serverName gomatrixserverlib.ServerName
|
|
tableName string
|
|
}
|
|
|
|
func mapFromDevice(db DeviceCosmos) api.Device {
|
|
return api.Device{
|
|
AccessToken: db.AccessToken,
|
|
AppserviceID: db.AppserviceID,
|
|
ID: db.ID,
|
|
LastSeenIP: db.LastSeenIP,
|
|
LastSeenTS: db.LastSeenTS,
|
|
SessionID: db.SessionID,
|
|
UserAgent: db.UserAgent,
|
|
UserID: db.UserID,
|
|
}
|
|
}
|
|
|
|
func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos {
|
|
localPart, _ := userutil.ParseUsernameParam(api.UserID, &s.serverName)
|
|
return DeviceCosmos{
|
|
AccessToken: api.AccessToken,
|
|
AppserviceID: api.AppserviceID,
|
|
ID: api.ID,
|
|
LastSeenIP: api.LastSeenIP,
|
|
LastSeenTS: api.LastSeenTS,
|
|
Localpart: localPart,
|
|
SessionID: api.SessionID,
|
|
UserAgent: api.UserAgent,
|
|
UserID: api.UserID,
|
|
}
|
|
}
|
|
|
|
func queryDevice(s *devicesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceCosmosData, error) {
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
|
var response []DeviceCosmosData
|
|
|
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
|
var query = cosmosdbapi.GetQuery(qry, params)
|
|
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
|
ctx,
|
|
s.db.cosmosConfig.DatabaseName,
|
|
s.db.cosmosConfig.ContainerName,
|
|
query,
|
|
&response,
|
|
optionsQry)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return response, nil
|
|
}
|
|
|
|
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
|
|
response := DeviceCosmosData{}
|
|
err := cosmosdbapi.GetDocumentOrNil(
|
|
s.db.connection,
|
|
s.db.cosmosConfig,
|
|
ctx,
|
|
pk,
|
|
docId,
|
|
&response)
|
|
|
|
if response.Id == "" {
|
|
return nil, cosmosdbutil.ErrNoRows
|
|
}
|
|
|
|
return &response, err
|
|
}
|
|
|
|
func setDevice(s *devicesStatements, ctx context.Context, device DeviceCosmosData) (*DeviceCosmosData, error) {
|
|
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(device.Pk, device.ETag)
|
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
|
ctx,
|
|
s.db.cosmosConfig.DatabaseName,
|
|
s.db.cosmosConfig.ContainerName,
|
|
device.Id,
|
|
&device,
|
|
optionsReplace)
|
|
return &device, ex
|
|
}
|
|
|
|
func (s *devicesStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
|
|
s.db = db
|
|
s.selectDevicesCountStmt = "select count(c._ts) as sessioncount from c where c._cn = @x1"
|
|
s.selectDevicesByLocalpartStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and ARRAY_CONTAINS(@x3, c.mx_userapi_device.device_id)"
|
|
s.selectDevicesByLocalpartExceptIDStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.local_part = @x2 and c.mx_userapi_device.device_id != @x3"
|
|
s.selectDeviceByTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_device.access_token = @x2"
|
|
s.selectDevicesByIDStmt = "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_userapi_device.device_id)"
|
|
s.serverName = server
|
|
s.tableName = "device_devices"
|
|
return
|
|
}
|
|
|
|
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
|
// Returns an error if the user already has a device with the given device ID.
|
|
// Returns the device on success.
|
|
func (s *devicesStatements) insertDevice(
|
|
ctx context.Context, id, localpart, accessToken string,
|
|
displayName *string, ipAddr, userAgent string,
|
|
) (*api.Device, error) {
|
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
|
var sessionID int64
|
|
// "SELECT COUNT(access_token) FROM device_devices"
|
|
// HACK: Do we need a Cosmos Table for the sequence?
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
|
var response []DeviceCosmosSessionCount
|
|
params := map[string]interface{}{
|
|
"@x1": dbCollectionName,
|
|
}
|
|
|
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
|
var query = cosmosdbapi.GetQuery(s.selectDevicesCountStmt, params)
|
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
|
ctx,
|
|
s.db.cosmosConfig.DatabaseName,
|
|
s.db.cosmosConfig.ContainerName,
|
|
query,
|
|
&response,
|
|
optionsQry)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sessionID = response[0].SessionCount
|
|
sessionID++
|
|
// "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
|
// " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
|
|
|
data := DeviceCosmos{
|
|
ID: id,
|
|
UserID: userutil.MakeUserID(localpart, s.serverName),
|
|
AccessToken: accessToken,
|
|
SessionID: sessionID,
|
|
LastSeenTS: createdTimeMS,
|
|
LastSeenIP: ipAddr,
|
|
Localpart: localpart,
|
|
UserAgent: userAgent,
|
|
}
|
|
|
|
// access_token TEXT PRIMARY KEY,
|
|
// UNIQUE (localpart, device_id)
|
|
// HACK: check for duplicate PK as we are using the UNIQUE key for the DocId
|
|
docId := fmt.Sprintf("%s_%s", localpart, id)
|
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
|
|
|
var dbData = DeviceCosmosData{
|
|
Id: cosmosDocId,
|
|
Cn: dbCollectionName,
|
|
Pk: pk,
|
|
Timestamp: time.Now().Unix(),
|
|
Device: data,
|
|
}
|
|
|
|
var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
|
var _, _, errCreate = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
|
ctx,
|
|
s.db.cosmosConfig.DatabaseName,
|
|
s.db.cosmosConfig.ContainerName,
|
|
dbData,
|
|
optionsCreate)
|
|
|
|
if errCreate != nil {
|
|
return nil, errCreate
|
|
}
|
|
|
|
var result = mapFromDevice(dbData.Device)
|
|
return &result, nil
|
|
}
|
|
|
|
func (s *devicesStatements) deleteDevice(
|
|
ctx context.Context, id, localpart string,
|
|
) error {
|
|
// "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
|
docId := fmt.Sprintf("%s_%s", localpart, id)
|
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
|
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
|
|
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
|
ctx,
|
|
s.db.cosmosConfig.DatabaseName,
|
|
s.db.cosmosConfig.ContainerName,
|
|
cosmosDocId,
|
|
options)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *devicesStatements) deleteDevices(
|
|
ctx context.Context, localpart string, devices []string,
|
|
) error {
|
|
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
|
var response []DeviceCosmosData
|
|
params := map[string]interface{}{
|
|
"@x1": dbCollectionName,
|
|
"@x2": localpart,
|
|
"@x3": devices,
|
|
}
|
|
|
|
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, item := range response {
|
|
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *devicesStatements) deleteDevicesByLocalpart(
|
|
ctx context.Context, localpart, exceptDeviceID string,
|
|
) error {
|
|
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
|
exceptDevices := []string{
|
|
exceptDeviceID,
|
|
}
|
|
params := map[string]interface{}{
|
|
"@x1": dbCollectionName,
|
|
"@x2": localpart,
|
|
"@x3": exceptDevices,
|
|
}
|
|
|
|
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, item := range response {
|
|
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *devicesStatements) updateDeviceName(
|
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
|
) error {
|
|
// "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
|
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
|
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
|
if exGet != nil {
|
|
return exGet
|
|
}
|
|
|
|
response.Device.DisplayName = *displayName
|
|
|
|
var _, exReplace = setDevice(s, ctx, *response)
|
|
if exReplace != nil {
|
|
return exReplace
|
|
}
|
|
return exReplace
|
|
}
|
|
|
|
func (s *devicesStatements) selectDeviceByToken(
|
|
ctx context.Context, accessToken string,
|
|
) (*api.Device, error) {
|
|
// "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
|
var response []DeviceCosmosData
|
|
params := map[string]interface{}{
|
|
"@x1": dbCollectionName,
|
|
"@x2": accessToken,
|
|
}
|
|
|
|
response, err := queryDevice(s, ctx, s.selectDeviceByTokenStmt, params)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(response) == 0 {
|
|
return nil, cosmosdbutil.ErrNoRows
|
|
}
|
|
|
|
if err == nil {
|
|
result := mapFromDevice(response[0].Device)
|
|
return &result, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// selectDeviceByID retrieves a device from the database with the given user
|
|
// localpart and deviceID
|
|
func (s *devicesStatements) selectDeviceByID(
|
|
ctx context.Context, localpart, deviceID string,
|
|
) (*api.Device, error) {
|
|
// "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
|
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
|
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
|
if exGet != nil {
|
|
return nil, exGet
|
|
}
|
|
result := mapFromDevice(response.Device)
|
|
return &result, nil
|
|
}
|
|
|
|
func (s *devicesStatements) selectDevicesByLocalpart(
|
|
ctx context.Context, localpart, exceptDeviceID string,
|
|
) ([]api.Device, error) {
|
|
devices := []api.Device{}
|
|
// "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
|
params := map[string]interface{}{
|
|
"@x1": dbCollectionName,
|
|
"@x2": localpart,
|
|
"@x3": exceptDeviceID,
|
|
}
|
|
|
|
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartExceptIDStmt, params)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, item := range response {
|
|
dev := mapFromDevice(item.Device)
|
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
|
devices = append(devices, dev)
|
|
}
|
|
|
|
return devices, nil
|
|
}
|
|
|
|
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
|
// "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
|
var devices []api.Device
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
|
var response []DeviceCosmosData
|
|
params := map[string]interface{}{
|
|
"@x1": dbCollectionName,
|
|
"@x2": deviceIDs,
|
|
}
|
|
|
|
response, err := queryDevice(s, ctx, s.selectDevicesByIDStmt, params)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, item := range response {
|
|
dev := mapFromDevice(item.Device)
|
|
devices = append(devices, dev)
|
|
}
|
|
return devices, nil
|
|
}
|
|
|
|
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
|
lastSeenTs := time.Now().UnixNano() / 1000000
|
|
|
|
// "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
|
|
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
|
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
|
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
|
|
if exGet != nil {
|
|
return exGet
|
|
}
|
|
|
|
response.Device.LastSeenTS = lastSeenTs
|
|
response.Device.LastSeenIP = ipAddr
|
|
|
|
var _, exReplace = setDevice(s, ctx, *response)
|
|
if exReplace != nil {
|
|
return exReplace
|
|
}
|
|
return exReplace
|
|
}
|