Mirror of https://github.com/matrix-org/dendrite.git
Merge pull request #2 from criticalarc/af/CA-5532-AddCosmosDB
Add CosmosDB Package and support to Dendrite
commit 234a89db5d

appservice/storage/cosmosdb/appservice_events_table.go (new file, 267 lines)
@@ -0,0 +1,267 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"encoding/json"
	"time"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/gomatrixserverlib"
	log "github.com/sirupsen/logrus"
)

const appserviceEventsSchema = `
-- Stores events to be sent to application services
CREATE TABLE IF NOT EXISTS appservice_events (
	-- An auto-incrementing id unique to each event in the table
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	-- The ID of the application service the event will be sent to
	as_id TEXT NOT NULL,
	-- JSON representation of the event
	headered_event_json TEXT NOT NULL,
	-- The ID of the transaction that this event is a part of
	txn_id INTEGER NOT NULL
);

CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
`

const selectEventsByApplicationServiceIDSQL = "" +
	"SELECT id, headered_event_json, txn_id " +
	"FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"

const countEventsByApplicationServiceIDSQL = "" +
	"SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"

const insertEventSQL = "" +
	"INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
	"VALUES ($1, $2, $3)"

const updateTxnIDForEventsSQL = "" +
	"UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"

const deleteEventsBeforeAndIncludingIDSQL = "" +
	"DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2"

const (
	// A transaction ID number that no transaction should ever have. Used for
	// checking against the default value.
	invalidTxnID = -2
)

type eventsStatements struct {
	db                                     *sql.DB
	writer                                 sqlutil.Writer
	selectEventsByApplicationServiceIDStmt *sql.Stmt
	countEventsByApplicationServiceIDStmt  *sql.Stmt
	insertEventStmt                        *sql.Stmt
	updateTxnIDForEventsStmt               *sql.Stmt
	deleteEventsBeforeAndIncludingIDStmt   *sql.Stmt
}

func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
	s.db = db
	s.writer = writer
	_, err = db.Exec(appserviceEventsSchema)
	if err != nil {
		return
	}

	if s.selectEventsByApplicationServiceIDStmt, err = db.Prepare(selectEventsByApplicationServiceIDSQL); err != nil {
		return
	}
	if s.countEventsByApplicationServiceIDStmt, err = db.Prepare(countEventsByApplicationServiceIDSQL); err != nil {
		return
	}
	if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
		return
	}
	if s.updateTxnIDForEventsStmt, err = db.Prepare(updateTxnIDForEventsSQL); err != nil {
		return
	}
	if s.deleteEventsBeforeAndIncludingIDStmt, err = db.Prepare(deleteEventsBeforeAndIncludingIDSQL); err != nil {
		return
	}

	return
}

// selectEventsByApplicationServiceID takes in an application service ID and
// returns a slice of events that need to be sent to that application service,
// as well as an int later used to remove these same events from the database
// once successfully sent to an application service.
func (s *eventsStatements) selectEventsByApplicationServiceID(
	ctx context.Context,
	applicationServiceID string,
	limit int,
) (
	txnID, maxID int,
	events []gomatrixserverlib.HeaderedEvent,
	eventsRemaining bool,
	err error,
) {
	defer func() {
		if err != nil {
			log.WithFields(log.Fields{
				"appservice": applicationServiceID,
			}).WithError(err).Fatalf("appservice unable to select new events to send")
		}
	}()
	// Retrieve events from the database. Unsuccessfully sent events first
	eventRows, err := s.selectEventsByApplicationServiceIDStmt.QueryContext(ctx, applicationServiceID)
	if err != nil {
		return
	}
	defer checkNamedErr(eventRows.Close, &err)
	events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
	if err != nil {
		return
	}

	return
}

// checkNamedErr calls fn and overwrites err if err was nil and fn returned a non-nil error
func checkNamedErr(fn func() error, err *error) {
	if e := fn(); e != nil && *err == nil {
		*err = e
	}
}

func retrieveEvents(eventRows *sql.Rows, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
	// Get current time for use in calculating event age
	nowMilli := time.Now().UnixNano() / int64(time.Millisecond)

	// Iterate through each row and store event contents
	// If txn_id changes dramatically, we've switched from collecting old events to
	// new ones. Send back those events first.
	lastTxnID := invalidTxnID
	for eventsProcessed := 0; eventRows.Next(); {
		var event gomatrixserverlib.HeaderedEvent
		var eventJSON []byte
		var id int
		err = eventRows.Scan(
			&id,
			&eventJSON,
			&txnID,
		)
		if err != nil {
			return nil, 0, 0, false, err
		}

		// Unmarshal eventJSON
		if err = json.Unmarshal(eventJSON, &event); err != nil {
			return nil, 0, 0, false, err
		}

		// If txnID has changed on this event from the previous event, then we've
		// reached the end of a transaction's events. Return only those events.
		if lastTxnID > invalidTxnID && lastTxnID != txnID {
			return events, maxID, lastTxnID, true, nil
		}
		lastTxnID = txnID

		// Limit events that aren't part of an old transaction
		if txnID == -1 {
			// Return if we've hit the limit
			if eventsProcessed++; eventsProcessed > limit {
				return events, maxID, lastTxnID, true, nil
			}
		}

		if id > maxID {
			maxID = id
		}

		// Portion of the event that is unsigned due to rapid change
		// TODO: Consider removing age as not many app services use it
		if err = event.SetUnsignedField("age", nowMilli-int64(event.OriginServerTS())); err != nil {
			return nil, 0, 0, false, err
		}

		events = append(events, event)
	}

	return
}

// countEventsByApplicationServiceID counts the number of events waiting in the
// db that are destined for the given application service.
func (s *eventsStatements) countEventsByApplicationServiceID(
	ctx context.Context,
	appServiceID string,
) (int, error) {
	var count int
	err := s.countEventsByApplicationServiceIDStmt.QueryRowContext(ctx, appServiceID).Scan(&count)
	if err != nil && err != sql.ErrNoRows {
		return 0, err
	}

	return count, nil
}

// insertEvent inserts an event mapped to its corresponding application service
// ID into the db.
func (s *eventsStatements) insertEvent(
	ctx context.Context,
	appServiceID string,
	event *gomatrixserverlib.HeaderedEvent,
) (err error) {
	// Convert event to JSON before inserting
	eventJSON, err := json.Marshal(event)
	if err != nil {
		return err
	}

	return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
		_, err := s.insertEventStmt.ExecContext(
			ctx,
			appServiceID,
			eventJSON,
			-1, // No transaction ID yet
		)
		return err
	})
}

// updateTxnIDForEvents sets the transactionID for a collection of events. Done
// before sending them to an AppService. Referenced before sending to make sure
// we aren't constructing multiple transactions with the same events.
func (s *eventsStatements) updateTxnIDForEvents(
	ctx context.Context,
	appserviceID string,
	maxID, txnID int,
) (err error) {
	return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
		_, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
		return err
	})
}

// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
	ctx context.Context,
	appserviceID string,
	eventTableID int,
) (err error) {
	return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
		_, err := s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
		return err
	})
}
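
The ordering in selectEventsByApplicationServiceIDSQL (txn_id DESC, id ASC) is what lets retrieveEvents return one transaction's worth of events at a time: rows that were already assigned a transaction ID sort before the unsent ones (txn_id = -1), and the limit only counts the unsent rows. A minimal, self-contained sketch of the same batching rule over plain values, without the Dendrite types (the row struct and firstBatch function are illustrative only, not part of this PR):

package main

import "fmt"

// row is a simplified stand-in for one appservice_events row: just the
// auto-increment id and the transaction id (-1 = not yet sent).
type row struct {
	id    int
	txnID int
}

// firstBatch mirrors the batching rule in retrieveEvents: rows are assumed to
// be ordered by txn_id DESC, id ASC, so a previously-assigned transaction is
// returned whole, and the limit only applies to rows with txn_id == -1.
func firstBatch(rows []row, limit int) (batch []row, txnID int, more bool) {
	const invalidTxnID = -2
	lastTxnID := invalidTxnID
	processed := 0
	for _, r := range rows {
		if lastTxnID > invalidTxnID && lastTxnID != r.txnID {
			return batch, lastTxnID, true // end of one transaction's events
		}
		lastTxnID = r.txnID
		if r.txnID == -1 {
			if processed++; processed > limit {
				return batch, lastTxnID, true // hit the limit for new events
			}
		}
		batch = append(batch, r)
	}
	return batch, lastTxnID, false
}

func main() {
	rows := []row{{id: 3, txnID: 7}, {id: 4, txnID: 7}, {id: 5, txnID: -1}, {id: 6, txnID: -1}}
	batch, txnID, more := firstBatch(rows, 50)
	fmt.Println(len(batch), txnID, more) // 2 7 true: the old transaction is returned first
}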

appservice/storage/cosmosdb/storage.go (new file, 121 lines)
@@ -0,0 +1,121 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"

	"github.com/matrix-org/dendrite/internal/cosmosdbutil"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/gomatrixserverlib"

	// Import SQLite database driver
	_ "github.com/mattn/go-sqlite3"
)

// Database stores events intended to be later sent to application services
type Database struct {
	sqlutil.PartitionOffsetStatements
	events eventsStatements
	txnID  txnStatements
	db     *sql.DB
	writer sqlutil.Writer
}

// NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
	dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
	var result Database
	var err error
	if result.db, err = sqlutil.Open(dbProperties); err != nil {
		return nil, err
	}
	result.writer = sqlutil.NewExclusiveWriter()
	if err = result.prepare(); err != nil {
		return nil, err
	}
	if err = result.PartitionOffsetStatements.Prepare(result.db, result.writer, "appservice"); err != nil {
		return nil, err
	}
	return &result, nil
}

func (d *Database) prepare() error {
	if err := d.events.prepare(d.db, d.writer); err != nil {
		return err
	}

	return d.txnID.prepare(d.db, d.writer)
}

// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
// for a transaction worker to pull and later send to an application service.
func (d *Database) StoreEvent(
	ctx context.Context,
	appServiceID string,
	event *gomatrixserverlib.HeaderedEvent,
) error {
	return d.events.insertEvent(ctx, appServiceID, event)
}

// GetEventsWithAppServiceID returns a slice of events and their IDs intended to
// be sent to an application service given its ID.
func (d *Database) GetEventsWithAppServiceID(
	ctx context.Context,
	appServiceID string,
	limit int,
) (int, int, []gomatrixserverlib.HeaderedEvent, bool, error) {
	return d.events.selectEventsByApplicationServiceID(ctx, appServiceID, limit)
}

// CountEventsWithAppServiceID returns the number of events destined for an
// application service given its ID.
func (d *Database) CountEventsWithAppServiceID(
	ctx context.Context,
	appServiceID string,
) (int, error) {
	return d.events.countEventsByApplicationServiceID(ctx, appServiceID)
}

// UpdateTxnIDForEvents takes in an application service ID and a transaction ID
// and assigns that transaction ID to all of the application service's queued
// events up to and including the given maximum event ID.
func (d *Database) UpdateTxnIDForEvents(
	ctx context.Context,
	appserviceID string,
	maxID, txnID int,
) error {
	return d.events.updateTxnIDForEvents(ctx, appserviceID, maxID, txnID)
}

// RemoveEventsBeforeAndIncludingID removes all events from the database that
// are less than or equal to a given maximum ID. IDs here are implemented as a
// serial, thus this should always delete events in chronological order.
func (d *Database) RemoveEventsBeforeAndIncludingID(
	ctx context.Context,
	appserviceID string,
	eventTableID int,
) error {
	return d.events.deleteEventsBeforeAndIncludingID(ctx, appserviceID, eventTableID)
}

// GetLatestTxnID returns the latest available transaction id
func (d *Database) GetLatestTxnID(
	ctx context.Context,
) (int, error) {
	return d.txnID.selectTxnID(ctx)
}
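
A sketch of how a caller might open this database directly, assuming config.DatabaseOptions carries the connection string as a string-typed field (as the code above uses it) and using the cosmosdb: connection string from the dendrite-config-cosmosdb.yaml sample later in this diff. This is illustrative only, not code from the PR:

package main

import (
	"context"
	"log"

	"github.com/matrix-org/dendrite/appservice/storage/cosmosdb"
	"github.com/matrix-org/dendrite/setup/config"
)

func main() {
	// Assumed shape: DatabaseOptions holds the connection string plus the
	// pool settings shown in the sample YAML config.
	opts := &config.DatabaseOptions{
		ConnectionString: "cosmosdb:appservice.db",
	}
	db, err := cosmosdb.NewDatabase(opts)
	if err != nil {
		log.Fatal(err)
	}
	// Read the next appservice transaction ID from the counter table.
	txnID, err := db.GetLatestTxnID(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	log.Println("next appservice transaction ID:", txnID)
}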

appservice/storage/cosmosdb/txn_id_counter_table.go (new file, 69 lines)
@@ -0,0 +1,69 @@
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"

	"github.com/matrix-org/dendrite/internal/sqlutil"
)

const txnIDSchema = `
-- Keeps a count of the current transaction ID
CREATE TABLE IF NOT EXISTS appservice_counters (
	name TEXT PRIMARY KEY NOT NULL,
	last_id INTEGER DEFAULT 1
);
INSERT OR IGNORE INTO appservice_counters (name, last_id) VALUES('txn_id', 1);
`

const selectTxnIDSQL = `
SELECT last_id FROM appservice_counters WHERE name='txn_id';
UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id';
`

type txnStatements struct {
	db              *sql.DB
	writer          sqlutil.Writer
	selectTxnIDStmt *sql.Stmt
}

func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
	s.db = db
	s.writer = writer
	_, err = db.Exec(txnIDSchema)
	if err != nil {
		return
	}

	if s.selectTxnIDStmt, err = db.Prepare(selectTxnIDSQL); err != nil {
		return
	}

	return
}

// selectTxnID selects the latest ascending transaction ID
func (s *txnStatements) selectTxnID(
	ctx context.Context,
) (txnID int, err error) {
	err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
		err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
		return err
	})
	return
}
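
The counter table keeps a single txn_id row that is read and then bumped on every call. A self-contained sketch of the same read-then-increment pattern, issuing the SELECT and UPDATE as separate statements inside one transaction (some SQLite drivers may only execute the first statement of a multi-statement prepare, which is why the two steps are kept explicit here; the table layout is copied from the schema above, the driver and helper name are illustrative):

package main

import (
	"database/sql"
	"fmt"

	_ "github.com/mattn/go-sqlite3"
)

// nextTxnID reads the current counter value and increments it in one transaction.
func nextTxnID(db *sql.DB) (id int, err error) {
	tx, err := db.Begin()
	if err != nil {
		return 0, err
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	if err = tx.QueryRow(
		"SELECT last_id FROM appservice_counters WHERE name = 'txn_id'",
	).Scan(&id); err != nil {
		return 0, err
	}
	if _, err = tx.Exec(
		"UPDATE appservice_counters SET last_id = last_id + 1 WHERE name = 'txn_id'",
	); err != nil {
		return 0, err
	}
	return id, tx.Commit()
}

func main() {
	db, _ := sql.Open("sqlite3", "file::memory:")
	_, _ = db.Exec(`CREATE TABLE appservice_counters (name TEXT PRIMARY KEY NOT NULL, last_id INTEGER DEFAULT 1);
	                INSERT INTO appservice_counters (name, last_id) VALUES ('txn_id', 1);`)
	id, err := nextTxnID(db)
	fmt.Println(id, err) // 1 <nil>
}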

@@ -17,6 +17,7 @@
package storage

import (
	"github.com/matrix-org/dendrite/appservice/storage/cosmosdb"
	"fmt"

	"github.com/matrix-org/dendrite/appservice/storage/postgres"

@@ -28,6 +29,8 @@ import (
// and sets DB connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
	switch {
	case dbProperties.ConnectionString.IsCosmosDB():
		return cosmosdb.NewDatabase(dbProperties)
	case dbProperties.ConnectionString.IsSQLite():
		return sqlite3.NewDatabase(dbProperties)
	case dbProperties.ConnectionString.IsPostgres():

@@ -23,6 +23,8 @@ import (
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
	switch {
	case dbProperties.ConnectionString.IsCosmosDB():
		return nil, fmt.Errorf("can't use CosmosDB implementation")
	case dbProperties.ConnectionString.IsSQLite():
		return sqlite3.NewDatabase(dbProperties)
	case dbProperties.ConnectionString.IsPostgres():
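
The IsCosmosDB() helper used by the switch above is not shown in this diff. A minimal sketch of what such a scheme check could look like, assuming the connection string is a string-based type like the existing IsSQLite()/IsPostgres() helpers in setup/config (the type name and method body here are assumptions, not the actual PR code):

package main

import (
	"fmt"
	"strings"
)

// DataSource mirrors the string-based connection string type that the
// IsSQLite()/IsPostgres() helpers hang off in setup/config.
type DataSource string

// IsCosmosDB reports whether the connection string uses the "cosmosdb:"
// scheme shown in dendrite-config-cosmosdb.yaml.
func (s DataSource) IsCosmosDB() bool {
	return strings.HasPrefix(string(s), "cosmosdb:")
}

func main() {
	fmt.Println(DataSource("cosmosdb:appservice.db").IsCosmosDB()) // true
	fmt.Println(DataSource("file:appservice.db").IsCosmosDB())     // false
}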

dendrite-config-cosmosdb.yaml (new file, 394 lines)
@@ -0,0 +1,394 @@
# This is the Dendrite configuration file.
#
# The configuration is split up into sections - each Dendrite component has a
# configuration section, in addition to the "global" section which applies to
# all components.
#
# At a minimum, to get started, you will need to update the settings in the
# "global" section for your deployment, and you will need to check that the
# database "connection_string" line in each component section is correct.
#
# Each component with a "database" section can accept the following formats
# for "connection_string":
#   SQLite:     file:filename.db
#               file:///path/to/filename.db
#   PostgreSQL: postgresql://user:pass@hostname/database?params=...
#   CosmosDB:   cosmosdb:filename.db
#               cosmosdb:///path/to/filename.db
#
# SQLite is embedded into Dendrite and therefore no further prerequisites are
# needed for the database when using SQLite mode. However, performance with
# PostgreSQL is significantly better and recommended for multi-user deployments.
# SQLite is typically around 20-30% slower than PostgreSQL when tested with a
# small number of users and likely will perform worse still with a higher volume
# of users.
#
# The "max_open_conns" and "max_idle_conns" settings configure the maximum
# number of open/idle database connections. The value 0 will use the database
# engine default, and a negative value will use unlimited connections. The
# "conn_max_lifetime" option controls the maximum length of time a database
# connection can be idle in seconds - a negative value is unlimited.

# The version of the configuration file.
version: 1

# Global Matrix configuration. This configuration applies to all components.
global:
  # The domain name of this homeserver.
  server_name: localhost

  # The path to the signing private key file, used to sign requests and events.
  # Note that this is NOT the same private key as used for TLS! To generate a
  # signing key, use "./bin/generate-keys --private-key matrix_key.pem".
  private_key: matrix_key.pem

  # The paths and expiry timestamps (as a UNIX timestamp in millisecond precision)
  # to old signing private keys that were formerly in use on this domain. These
  # keys will not be used for federation request or event signing, but will be
  # provided to any other homeserver that asks when trying to verify old events.
  # old_private_keys:
  # - private_key: old_matrix_key.pem
  #   expired_at: 1601024554498

  # How long a remote server can cache our server signing key before requesting it
  # again. Increasing this number will reduce the number of requests made by other
  # servers for our key but increases the period that a compromised key will be
  # considered valid by other homeservers.
  key_validity_period: 168h0m0s

  # Lists of domains that the server will trust as identity servers to verify third
  # party identifiers such as phone numbers and email addresses.
  trusted_third_party_id_servers:
    - matrix.org
    - vector.im

  # Disables federation. Dendrite will not be able to make any outbound HTTP requests
  # to other servers and the federation API will not be exposed.
  disable_federation: false

  # Configuration for Kafka/Naffka.
  kafka:
    # List of Kafka broker addresses to connect to. This is not needed if using
    # Naffka in monolith mode.
    addresses:
      - localhost:2181

    # The prefix to use for Kafka topic names for this homeserver. Change this only if
    # you are running more than one Dendrite homeserver on the same Kafka deployment.
    topic_prefix: Dendrite

    # Whether to use Naffka instead of Kafka. This is only available in monolith
    # mode, but means that you can run a single-process server without requiring
    # Kafka.
    use_naffka: true

    # The max size a Kafka message is allowed to use.
    # You only need to change this value, if you encounter issues with too large messages.
    # Must be less than/equal to "max.message.bytes" configured in Kafka.
    # Defaults to 8388608 bytes.
    # max_message_bytes: 8388608

    # Naffka database options. Not required when using Kafka.
    naffka_database:
      connection_string: cosmosdb:naffka.db
      max_open_conns: 10
      max_idle_conns: 2
      conn_max_lifetime: -1

  # Configuration for Prometheus metric collection.
  metrics:
    # Whether or not Prometheus metrics are enabled.
    enabled: false

    # HTTP basic authentication to protect access to monitoring.
    basic_auth:
      username: metrics
      password: metrics

  # DNS cache options. The DNS cache may reduce the load on DNS servers
  # if there is no local caching resolver available for use.
  dns_cache:
    # Whether or not the DNS cache is enabled.
    enabled: false

    # Maximum number of entries to hold in the DNS cache, and
    # for how long those items should be considered valid in seconds.
    cache_size: 256
    cache_lifetime: 300

# Configuration for the Appservice API.
app_service_api:
  internal_api:
    listen: http://localhost:7777
    connect: http://localhost:7777
  database:
    connection_string: cosmosdb:appservice.db
    max_open_conns: 10
    max_idle_conns: 2
    conn_max_lifetime: -1

  # Disable the validation of TLS certificates of appservices. This is
  # not recommended in production since it may allow appservice traffic
  # to be sent to an unverified endpoint.
  disable_tls_validation: false

  # Appservice configuration files to load into this homeserver.
  config_files: []

# Configuration for the Client API.
client_api:
  internal_api:
    listen: http://localhost:7771
    connect: http://localhost:7771
  external_api:
    listen: http://[::]:8071

  # Prevents new users from being able to register on this homeserver, except when
  # using the registration shared secret below.
  registration_disabled: false

  # If set, allows registration by anyone who knows the shared secret, regardless of
  # whether registration is otherwise disabled.
  registration_shared_secret: ""

  # Whether to require reCAPTCHA for registration.
  enable_registration_captcha: false

  # Settings for ReCAPTCHA.
  recaptcha_public_key: ""
  recaptcha_private_key: ""
  recaptcha_bypass_secret: ""
  recaptcha_siteverify_api: ""

  # TURN server information that this homeserver should send to clients.
  turn:
    turn_user_lifetime: ""
    turn_uris: []
    turn_shared_secret: ""
    turn_username: ""
    turn_password: ""

  # Settings for rate-limited endpoints. Rate limiting will kick in after the
  # threshold number of "slots" have been taken by requests from a specific
  # host. Each "slot" will be released after the cooloff time in milliseconds.
  rate_limiting:
    enabled: true
    threshold: 5
    cooloff_ms: 500

# Configuration for the EDU server.
edu_server:
  internal_api:
    listen: http://localhost:7778
    connect: http://localhost:7778

# Configuration for the Federation API.
federation_api:
  internal_api:
    listen: http://localhost:7772
    connect: http://localhost:7772
  external_api:
    listen: http://[::]:8072

  # List of paths to X.509 certificates to be used by the external federation listeners.
  # These certificates will be used to calculate the TLS fingerprints and other servers
  # will expect the certificate to match these fingerprints. Certificates must be in PEM
  # format.
  federation_certificates: []

# Configuration for the Federation Sender.
federation_sender:
  internal_api:
    listen: http://localhost:7775
    connect: http://localhost:7775
  database:
    connection_string: cosmosdb:federationsender.db
    max_open_conns: 10
    max_idle_conns: 2
    conn_max_lifetime: -1

  # How many times we will try to resend a failed transaction to a specific server. The
  # backoff is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds etc.
  send_max_retries: 16

  # Disable the validation of TLS certificates of remote federated homeservers. Do not
  # enable this option in production as it presents a security risk!
  disable_tls_validation: false

  # Use the following proxy server for outbound federation traffic.
  proxy_outbound:
    enabled: false
    protocol: http
    host: localhost
    port: 8080

# Configuration for the Key Server (for end-to-end encryption).
key_server:
  internal_api:
    listen: http://localhost:7779
    connect: http://localhost:7779
  database:
    connection_string: cosmosdb:keyserver.db
    max_open_conns: 10
    max_idle_conns: 2
    conn_max_lifetime: -1

# Configuration for the Media API.
media_api:
  internal_api:
    listen: http://localhost:7774
    connect: http://localhost:7774
  external_api:
    listen: http://[::]:8074
  database:
    connection_string: cosmosdb:mediaapi.db
    max_open_conns: 5
    max_idle_conns: 2
    conn_max_lifetime: -1

  # Storage path for uploaded media. May be relative or absolute.
  base_path: ./media_store

  # The maximum allowed file size (in bytes) for media uploads to this homeserver
  # (0 = unlimited). If using a reverse proxy, ensure it allows requests at
  # least this large (e.g. client_max_body_size in nginx.)
  max_file_size_bytes: 10485760

  # Whether to dynamically generate thumbnails if needed.
  dynamic_thumbnails: false

  # The maximum number of simultaneous thumbnail generators to run.
  max_thumbnail_generators: 10

  # A list of thumbnail sizes to be generated for media content.
  thumbnail_sizes:
    - width: 32
      height: 32
      method: crop
    - width: 96
      height: 96
      method: crop
    - width: 640
      height: 480
      method: scale

# Configuration for experimental MSC's
mscs:
  # A list of enabled MSC's
  # Currently valid values are:
  #   - msc2836 (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836)
  #   - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946)
  mscs: []
  database:
    connection_string: cosmosdb:mscs.db
    max_open_conns: 5
    max_idle_conns: 2
    conn_max_lifetime: -1

# Configuration for the Room Server.
room_server:
  internal_api:
    listen: http://localhost:7770
    connect: http://localhost:7770
  database:
    connection_string: cosmosdb:roomserver.db
    max_open_conns: 10
    max_idle_conns: 2
    conn_max_lifetime: -1

# Configuration for the Signing Key Server (for server signing keys).
signing_key_server:
  internal_api:
    listen: http://localhost:7780
    connect: http://localhost:7780
  database:
    connection_string: cosmosdb:signingkeyserver.db
    max_open_conns: 10
    max_idle_conns: 2
    conn_max_lifetime: -1

  # Perspective keyservers to use as a backup when direct key fetches fail. This may
  # be required to satisfy key requests for servers that are no longer online when
  # joining some rooms.
  key_perspectives:
    - server_name: matrix.org
      keys:
        - key_id: ed25519:auto
          public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw
        - key_id: ed25519:a_RXGa
          public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ

  # This option will control whether Dendrite will prefer to look up keys directly
  # or whether it should try perspective servers first, using direct fetches as a
  # last resort.
  prefer_direct_fetch: false

# Configuration for the Sync API.
sync_api:
  internal_api:
    listen: http://localhost:7773
    connect: http://localhost:7773
  external_api:
    listen: http://[::]:8073
  database:
    connection_string: cosmosdb:syncapi.db
    max_open_conns: 10
    max_idle_conns: 2
    conn_max_lifetime: -1

  # This option controls which HTTP header to inspect to find the real remote IP
  # address of the client. This is likely required if Dendrite is running behind
  # a reverse proxy server.
  # real_ip_header: X-Real-IP

# Configuration for the User API.
user_api:
  # The cost when hashing passwords on registration/login. Default: 10. Min: 4, Max: 31
  # See https://pkg.go.dev/golang.org/x/crypto/bcrypt for more information.
  # Setting this lower makes registration/login consume less CPU resources at the cost of security
  # should the database be compromised. Setting this higher makes registration/login consume more
  # CPU resources but makes it harder to brute force password hashes.
  # This value can be low if performing tests or on embedded Dendrite instances (e.g WASM builds)
  # bcrypt_cost: 10
  internal_api:
    listen: http://localhost:7781
    connect: http://localhost:7781
  account_database:
    connection_string: cosmosdb:userapi_accounts.db
    max_open_conns: 10
    max_idle_conns: 2
    conn_max_lifetime: -1
  device_database:
    connection_string: cosmosdb:userapi_devices.db
    max_open_conns: 10
    max_idle_conns: 2
    conn_max_lifetime: -1
  # The length of time that a token issued for a relying party from
  # /_matrix/client/r0/user/{userId}/openid/request_token endpoint
  # is considered to be valid in milliseconds.
  # The default lifetime is 3600000ms (60 minutes).
  # openid_token_lifetime_ms: 3600000

# Configuration for Opentracing.
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
# how this works and how to set it up.
tracing:
  enabled: false
  jaeger:
    serviceName: ""
    disabled: false
    rpc_metrics: false
    tags: []
    sampler: null
    reporter: null
    headers: null
    baggage_restrictions: null
    throttler: null

# Logging configuration, in addition to the standard logging that is sent to
# stdout by Dendrite.
logging:
  - type: file
    level: info
    params:
      path: ./logs

federationsender/storage/cosmosdb/blacklist_table.go (new file, 107 lines)
@@ -0,0 +1,107 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/gomatrixserverlib"
)

const blacklistSchema = `
CREATE TABLE IF NOT EXISTS federationsender_blacklist (
	-- The blacklisted server name
	server_name TEXT NOT NULL,
	UNIQUE (server_name)
);
`

const insertBlacklistSQL = "" +
	"INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" +
	" ON CONFLICT DO NOTHING"

const selectBlacklistSQL = "" +
	"SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"

const deleteBlacklistSQL = "" +
	"DELETE FROM federationsender_blacklist WHERE server_name = $1"

type blacklistStatements struct {
	db                  *sql.DB
	insertBlacklistStmt *sql.Stmt
	selectBlacklistStmt *sql.Stmt
	deleteBlacklistStmt *sql.Stmt
}

func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
	s = &blacklistStatements{
		db: db,
	}
	_, err = db.Exec(blacklistSchema)
	if err != nil {
		return
	}

	if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil {
		return
	}
	if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil {
		return
	}
	if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil {
		return
	}
	return
}

// InsertBlacklist adds the given server name to the blacklist, doing nothing
// if it is already present.
func (s *blacklistStatements) InsertBlacklist(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
	stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
	_, err := stmt.ExecContext(ctx, serverName)
	return err
}

// SelectBlacklist returns true if the given server name is currently
// blacklisted, and false otherwise.
func (s *blacklistStatements) SelectBlacklist(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (bool, error) {
	stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt)
	res, err := stmt.QueryContext(ctx, serverName)
	if err != nil {
		return false, err
	}
	defer res.Close() // nolint:errcheck
	// The query will return the server name if the server is blacklisted, and
	// will return no rows if not. By calling Next, we find out if a row was
	// returned or not - we don't care about the value itself.
	return res.Next(), nil
}

// DeleteBlacklist removes the given server name from the blacklist.
func (s *blacklistStatements) DeleteBlacklist(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
	stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
	_, err := stmt.ExecContext(ctx, serverName)
	return err
}
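
A sketch of how a caller might drive this table, opening a plain in-memory SQLite database for illustration rather than the cosmosdb connection-string wiring used by the real federation sender. Passing a nil *sql.Tx is assumed to make sqlutil.TxStmt fall back to the prepared statement directly, as the helper is used throughout the files above:

package main

import (
	"context"
	"database/sql"
	"fmt"
	"log"

	"github.com/matrix-org/dendrite/federationsender/storage/cosmosdb"
	_ "github.com/mattn/go-sqlite3"
)

func main() {
	// Illustrative only: in-memory SQLite instead of the real storage wiring.
	db, err := sql.Open("sqlite3", "file::memory:")
	if err != nil {
		log.Fatal(err)
	}
	table, err := cosmosdb.NewSQLiteBlacklistTable(db)
	if err != nil {
		log.Fatal(err)
	}

	ctx := context.Background()
	// nil transaction: use the prepared statements directly.
	if err = table.InsertBlacklist(ctx, nil, "badserver.example.org"); err != nil {
		log.Fatal(err)
	}
	banned, err := table.SelectBlacklist(ctx, nil, "badserver.example.org")
	fmt.Println(banned, err) // true <nil>
}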

federationsender/storage/cosmosdb/inbound_peeks_table.go (new file, 176 lines)
@@ -0,0 +1,176 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"time"

	"github.com/matrix-org/dendrite/federationsender/types"
	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/gomatrixserverlib"
)

const inboundPeeksSchema = `
CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks (
	room_id TEXT NOT NULL,
	server_name TEXT NOT NULL,
	peek_id TEXT NOT NULL,
	creation_ts INTEGER NOT NULL,
	renewed_ts INTEGER NOT NULL,
	renewal_interval INTEGER NOT NULL,
	UNIQUE (room_id, server_name, peek_id)
);
`

const insertInboundPeekSQL = "" +
	"INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"

const selectInboundPeekSQL = "" +
	"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"

const selectInboundPeeksSQL = "" +
	"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"

const renewInboundPeekSQL = "" +
	"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"

const deleteInboundPeekSQL = "" +
	"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"

const deleteInboundPeeksSQL = "" +
	"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"

type inboundPeeksStatements struct {
	db                     *sql.DB
	insertInboundPeekStmt  *sql.Stmt
	selectInboundPeekStmt  *sql.Stmt
	selectInboundPeeksStmt *sql.Stmt
	renewInboundPeekStmt   *sql.Stmt
	deleteInboundPeekStmt  *sql.Stmt
	deleteInboundPeeksStmt *sql.Stmt
}

func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err error) {
	s = &inboundPeeksStatements{
		db: db,
	}
	_, err = db.Exec(inboundPeeksSchema)
	if err != nil {
		return
	}

	if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
		return
	}
	if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
		return
	}
	if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
		return
	}
	if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
		return
	}
	if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
		return
	}
	if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
		return
	}
	return
}

func (s *inboundPeeksStatements) InsertInboundPeek(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
	nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
	stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt)
	_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
	return
}

func (s *inboundPeeksStatements) RenewInboundPeek(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
	nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
	_, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
	return
}

func (s *inboundPeeksStatements) SelectInboundPeek(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (*types.InboundPeek, error) {
	row := sqlutil.TxStmt(txn, s.selectInboundPeekStmt).QueryRowContext(ctx, roomID, serverName, peekID)
	inboundPeek := types.InboundPeek{}
	err := row.Scan(
		&inboundPeek.RoomID,
		&inboundPeek.ServerName,
		&inboundPeek.PeekID,
		&inboundPeek.CreationTimestamp,
		&inboundPeek.RenewedTimestamp,
		&inboundPeek.RenewalInterval,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &inboundPeek, nil
}

func (s *inboundPeeksStatements) SelectInboundPeeks(
	ctx context.Context, txn *sql.Tx, roomID string,
) (inboundPeeks []types.InboundPeek, err error) {
	rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
	if err != nil {
		return
	}
	defer internal.CloseAndLogIfError(ctx, rows, "SelectInboundPeeks: rows.close() failed")

	for rows.Next() {
		inboundPeek := types.InboundPeek{}
		if err = rows.Scan(
			&inboundPeek.RoomID,
			&inboundPeek.ServerName,
			&inboundPeek.PeekID,
			&inboundPeek.CreationTimestamp,
			&inboundPeek.RenewedTimestamp,
			&inboundPeek.RenewalInterval,
		); err != nil {
			return
		}
		inboundPeeks = append(inboundPeeks, inboundPeek)
	}

	return inboundPeeks, rows.Err()
}

func (s *inboundPeeksStatements) DeleteInboundPeek(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (err error) {
	_, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
	return
}

func (s *inboundPeeksStatements) DeleteInboundPeeks(
	ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
	_, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID)
	return
}

federationsender/storage/cosmosdb/joined_hosts_table.go (new file, 219 lines)
@@ -0,0 +1,219 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"strings"

	"github.com/matrix-org/dendrite/federationsender/types"
	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/gomatrixserverlib"
)

const joinedHostsSchema = `
-- The joined_hosts table stores a list of m.room.member event ids in the
-- current state for each room where the membership is "join".
-- There will be an entry for every user that is joined to the room.
CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
	-- The string ID of the room.
	room_id TEXT NOT NULL,
	-- The event ID of the m.room.member join event.
	event_id TEXT NOT NULL,
	-- The domain part of the user ID the m.room.member event is for.
	server_name TEXT NOT NULL
);

CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
	ON federationsender_joined_hosts (event_id);

CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
	ON federationsender_joined_hosts (room_id)
`

const insertJoinedHostsSQL = "" +
	"INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
	" VALUES ($1, $2, $3)"

const deleteJoinedHostsSQL = "" +
	"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"

const deleteJoinedHostsForRoomSQL = "" +
	"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"

const selectJoinedHostsSQL = "" +
	"SELECT event_id, server_name FROM federationsender_joined_hosts" +
	" WHERE room_id = $1"

const selectAllJoinedHostsSQL = "" +
	"SELECT DISTINCT server_name FROM federationsender_joined_hosts"

const selectJoinedHostsForRoomsSQL = "" +
	"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"

type joinedHostsStatements struct {
	db                           *sql.DB
	insertJoinedHostsStmt        *sql.Stmt
	deleteJoinedHostsStmt        *sql.Stmt
	deleteJoinedHostsForRoomStmt *sql.Stmt
	selectJoinedHostsStmt        *sql.Stmt
	selectAllJoinedHostsStmt     *sql.Stmt
	// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
}

func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
	s = &joinedHostsStatements{
		db: db,
	}
	_, err = db.Exec(joinedHostsSchema)
	if err != nil {
		return
	}
	if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil {
		return
	}
	if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
		return
	}
	if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
		return
	}
	if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
		return
	}
	if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
		return
	}
	return
}

func (s *joinedHostsStatements) InsertJoinedHosts(
	ctx context.Context,
	txn *sql.Tx,
	roomID, eventID string,
	serverName gomatrixserverlib.ServerName,
) error {
	stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
	_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
	return err
}

func (s *joinedHostsStatements) DeleteJoinedHosts(
	ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
	for _, eventID := range eventIDs {
		stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
		if _, err := stmt.ExecContext(ctx, eventID); err != nil {
			return err
		}
	}
	return nil
}

func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
	ctx context.Context, txn *sql.Tx, roomID string,
) error {
	stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
	_, err := stmt.ExecContext(ctx, roomID)
	return err
}

func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
	ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
	stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt)
	return joinedHostsFromStmt(ctx, stmt, roomID)
}

func (s *joinedHostsStatements) SelectJoinedHosts(
	ctx context.Context, roomID string,
) ([]types.JoinedHost, error) {
	return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
}

func (s *joinedHostsStatements) SelectAllJoinedHosts(
	ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
	rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed")

	var result []gomatrixserverlib.ServerName
	for rows.Next() {
		var serverName string
		if err = rows.Scan(&serverName); err != nil {
			return nil, err
		}
		result = append(result, gomatrixserverlib.ServerName(serverName))
	}

	return result, rows.Err()
}

func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
	ctx context.Context, roomIDs []string,
) ([]gomatrixserverlib.ServerName, error) {
	iRoomIDs := make([]interface{}, len(roomIDs))
	for i := range roomIDs {
		iRoomIDs[i] = roomIDs[i]
	}

	sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
	rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")

	var result []gomatrixserverlib.ServerName
	for rows.Next() {
		var serverName string
		if err = rows.Scan(&serverName); err != nil {
			return nil, err
		}
		result = append(result, gomatrixserverlib.ServerName(serverName))
	}

	return result, rows.Err()
}

func joinedHostsFromStmt(
	ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {
	rows, err := stmt.QueryContext(ctx, roomID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed")

	var result []types.JoinedHost
	for rows.Next() {
		var eventID, serverName string
		if err = rows.Scan(&eventID, &serverName); err != nil {
			return nil, err
		}
		result = append(result, types.JoinedHost{
			MemberEventID: eventID,
			ServerName:    gomatrixserverlib.ServerName(serverName),
		})
	}

	return result, nil
}
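
SelectJoinedHostsForRooms cannot use a fixed prepared statement because the number of room IDs varies per call, so the single "($1)" placeholder is expanded at query time and the IDs are passed as variadic arguments. A self-contained sketch of that expansion step; the queryVariadic helper below only stands in for sqlutil.QueryVariadic, whose exact output format is an assumption here:

package main

import (
	"fmt"
	"strings"
)

// queryVariadic builds a placeholder list such as "($1, $2, $3)", standing in
// for sqlutil.QueryVariadic.
func queryVariadic(n int) string {
	parts := make([]string, n)
	for i := range parts {
		parts[i] = fmt.Sprintf("$%d", i+1)
	}
	return "(" + strings.Join(parts, ", ") + ")"
}

func main() {
	const baseSQL = "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
	roomIDs := []string{"!a:example.org", "!b:example.org", "!c:example.org"}

	// Expand the single placeholder into one per room ID, then pass the IDs
	// as variadic query arguments, as SelectJoinedHostsForRooms does above.
	sql := strings.Replace(baseSQL, "($1)", queryVariadic(len(roomIDs)), 1)
	fmt.Println(sql) // ... WHERE room_id IN ($1, $2, $3)
}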

federationsender/storage/cosmosdb/outbound_peeks_table.go (new file, 176 lines)
@@ -0,0 +1,176 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"time"

	"github.com/matrix-org/dendrite/federationsender/types"
	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/gomatrixserverlib"
)

const outboundPeeksSchema = `
CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks (
	room_id TEXT NOT NULL,
	server_name TEXT NOT NULL,
	peek_id TEXT NOT NULL,
	creation_ts INTEGER NOT NULL,
	renewed_ts INTEGER NOT NULL,
	renewal_interval INTEGER NOT NULL,
	UNIQUE (room_id, server_name, peek_id)
);
`

const insertOutboundPeekSQL = "" +
	"INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"

const selectOutboundPeekSQL = "" +
	"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"

const selectOutboundPeeksSQL = "" +
	"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"

const renewOutboundPeekSQL = "" +
	"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"

const deleteOutboundPeekSQL = "" +
	"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"

const deleteOutboundPeeksSQL = "" +
	"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"

type outboundPeeksStatements struct {
	db                      *sql.DB
	insertOutboundPeekStmt  *sql.Stmt
	selectOutboundPeekStmt  *sql.Stmt
	selectOutboundPeeksStmt *sql.Stmt
	renewOutboundPeekStmt   *sql.Stmt
	deleteOutboundPeekStmt  *sql.Stmt
	deleteOutboundPeeksStmt *sql.Stmt
}

func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err error) {
	s = &outboundPeeksStatements{
		db: db,
	}
	_, err = db.Exec(outboundPeeksSchema)
	if err != nil {
		return
	}

	if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
		return
	}
	if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
		return
	}
	if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
		return
	}
	if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
		return
	}
	if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
		return
	}
	if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
		return
	}
	return
}

func (s *outboundPeeksStatements) InsertOutboundPeek(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
	nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
	stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt)
	_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
	return
}

func (s *outboundPeeksStatements) RenewOutboundPeek(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
	nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
	_, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
	return
}

func (s *outboundPeeksStatements) SelectOutboundPeek(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (*types.OutboundPeek, error) {
	row := sqlutil.TxStmt(txn, s.selectOutboundPeekStmt).QueryRowContext(ctx, roomID, serverName, peekID)
	outboundPeek := types.OutboundPeek{}
	err := row.Scan(
		&outboundPeek.RoomID,
		&outboundPeek.ServerName,
		&outboundPeek.PeekID,
		&outboundPeek.CreationTimestamp,
		&outboundPeek.RenewedTimestamp,
		&outboundPeek.RenewalInterval,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &outboundPeek, nil
}

func (s *outboundPeeksStatements) SelectOutboundPeeks(
	ctx context.Context, txn *sql.Tx, roomID string,
) (outboundPeeks []types.OutboundPeek, err error) {
	rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID)
	if err != nil {
		return
	}
	defer internal.CloseAndLogIfError(ctx, rows, "SelectOutboundPeeks: rows.close() failed")

	for rows.Next() {
		outboundPeek := types.OutboundPeek{}
		if err = rows.Scan(
			&outboundPeek.RoomID,
			&outboundPeek.ServerName,
			&outboundPeek.PeekID,
			&outboundPeek.CreationTimestamp,
			&outboundPeek.RenewedTimestamp,
			&outboundPeek.RenewalInterval,
		); err != nil {
			return
		}
		outboundPeeks = append(outboundPeeks, outboundPeek)
	}

	return outboundPeeks, rows.Err()
}

func (s *outboundPeeksStatements) DeleteOutboundPeek(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (err error) {
	_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
	return
}

func (s *outboundPeeksStatements) DeleteOutboundPeeks(
	ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
	_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID)
	return
}
|
||||
207
federationsender/storage/cosmosdb/queue_edus_table.go
Normal file

@@ -0,0 +1,207 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"fmt"
	"strings"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/gomatrixserverlib"
)

const queueEDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
	-- The type of the event (informational).
	edu_type TEXT NOT NULL,
	-- The domain part of the user ID the EDU event is for.
	server_name TEXT NOT NULL,
	-- The JSON NID from the federationsender_queue_edus_json table.
	json_nid BIGINT NOT NULL
);

CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
	ON federationsender_queue_edus (json_nid, server_name);
`

const insertQueueEDUSQL = "" +
	"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
	" VALUES ($1, $2, $3)"

const deleteQueueEDUsSQL = "" +
	"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"

const selectQueueEDUSQL = "" +
	"SELECT json_nid FROM federationsender_queue_edus" +
	" WHERE server_name = $1" +
	" LIMIT $2"

const selectQueueEDUReferenceJSONCountSQL = "" +
	"SELECT COUNT(*) FROM federationsender_queue_edus" +
	" WHERE json_nid = $1"

const selectQueueEDUCountSQL = "" +
	"SELECT COUNT(*) FROM federationsender_queue_edus" +
	" WHERE server_name = $1"

const selectQueueServerNamesSQL = "" +
	"SELECT DISTINCT server_name FROM federationsender_queue_edus"

type queueEDUsStatements struct {
	db                                   *sql.DB
	insertQueueEDUStmt                   *sql.Stmt
	selectQueueEDUStmt                   *sql.Stmt
	selectQueueEDUReferenceJSONCountStmt *sql.Stmt
	selectQueueEDUCountStmt              *sql.Stmt
	selectQueueEDUServerNamesStmt        *sql.Stmt
}

func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
	s = &queueEDUsStatements{
		db: db,
	}
	_, err = db.Exec(queueEDUsSchema)
	if err != nil {
		return
	}
	if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil {
		return
	}
	if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil {
		return
	}
	if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
		return
	}
	if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil {
		return
	}
	if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil {
		return
	}
	return
}

func (s *queueEDUsStatements) InsertQueueEDU(
	ctx context.Context,
	txn *sql.Tx,
	eduType string,
	serverName gomatrixserverlib.ServerName,
	nid int64,
) error {
	stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
	_, err := stmt.ExecContext(
		ctx,
		eduType,    // the EDU type
		serverName, // destination server name
		nid,        // JSON blob NID
	)
	return err
}

func (s *queueEDUsStatements) DeleteQueueEDUs(
	ctx context.Context, txn *sql.Tx,
	serverName gomatrixserverlib.ServerName,
	jsonNIDs []int64,
) error {
	deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
	deleteStmt, err := txn.Prepare(deleteSQL)
	if err != nil {
		return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
	}

	params := make([]interface{}, len(jsonNIDs)+1)
	params[0] = serverName
	for k, v := range jsonNIDs {
		params[k+1] = v
	}

	stmt := sqlutil.TxStmt(txn, deleteStmt)
	_, err = stmt.ExecContext(ctx, params...)
	return err
}

func (s *queueEDUsStatements) SelectQueueEDUs(
	ctx context.Context, txn *sql.Tx,
	serverName gomatrixserverlib.ServerName,
	limit int,
) ([]int64, error) {
	stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
	rows, err := stmt.QueryContext(ctx, serverName, limit)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
	var result []int64
	for rows.Next() {
		var nid int64
		if err = rows.Scan(&nid); err != nil {
			return nil, err
		}
		result = append(result, nid)
	}
	return result, nil
}

func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
	ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) {
	var count int64
	stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)
	err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
	if err == sql.ErrNoRows {
		return -1, nil
	}
	return count, err
}

func (s *queueEDUsStatements) SelectQueueEDUCount(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
	var count int64
	stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
	err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
	if err == sql.ErrNoRows {
		// It's acceptable for there to be no rows referencing a given
		// JSON NID but it's not an error condition. Just return as if
		// there's a zero count.
		return 0, nil
	}
	return count, err
}

func (s *queueEDUsStatements) SelectQueueEDUServerNames(
	ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {
	stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt)
	rows, err := stmt.QueryContext(ctx)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
	var result []gomatrixserverlib.ServerName
	for rows.Next() {
		var serverName gomatrixserverlib.ServerName
		if err = rows.Scan(&serverName); err != nil {
			return nil, err
		}
		result = append(result, serverName)
	}

	return result, rows.Err()
}
136
federationsender/storage/cosmosdb/queue_json_table.go
Normal file

@@ -0,0 +1,136 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"fmt"
	"strings"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
)

const queueJSONSchema = `
-- The queue_retry_json table contains event contents that
-- we failed to send.
CREATE TABLE IF NOT EXISTS federationsender_queue_json (
	-- The JSON NID. This allows the federationsender_queue_retry table to
	-- cross-reference to find the JSON blob.
	json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
	-- The JSON body. Text so that we preserve UTF-8.
	json_body TEXT NOT NULL
);
`

const insertJSONSQL = "" +
	"INSERT INTO federationsender_queue_json (json_body)" +
	" VALUES ($1)"

const deleteJSONSQL = "" +
	"DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"

const selectJSONSQL = "" +
	"SELECT json_nid, json_body FROM federationsender_queue_json" +
	" WHERE json_nid IN ($1)"

type queueJSONStatements struct {
	db             *sql.DB
	insertJSONStmt *sql.Stmt
	//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
	//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
}

func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
	s = &queueJSONStatements{
		db: db,
	}
	_, err = db.Exec(queueJSONSchema)
	if err != nil {
		return
	}
	if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil {
		return
	}
	return
}

func (s *queueJSONStatements) InsertQueueJSON(
	ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) {
	stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
	res, err := stmt.ExecContext(ctx, json)
	if err != nil {
		return 0, fmt.Errorf("stmt.ExecContext: %w", err)
	}
	lastid, err = res.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("res.LastInsertId: %w", err)
	}
	return
}

func (s *queueJSONStatements) DeleteQueueJSON(
	ctx context.Context, txn *sql.Tx, nids []int64,
) error {
	deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
	deleteStmt, err := txn.Prepare(deleteSQL)
	if err != nil {
		return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
	}

	iNIDs := make([]interface{}, len(nids))
	for k, v := range nids {
		iNIDs[k] = v
	}

	stmt := sqlutil.TxStmt(txn, deleteStmt)
	_, err = stmt.ExecContext(ctx, iNIDs...)
	return err
}

func (s *queueJSONStatements) SelectQueueJSON(
	ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
	selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
	selectStmt, err := txn.Prepare(selectSQL)
	if err != nil {
		return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err)
	}

	iNIDs := make([]interface{}, len(jsonNIDs))
	for k, v := range jsonNIDs {
		iNIDs[k] = v
	}

	blobs := map[int64][]byte{}
	stmt := sqlutil.TxStmt(txn, selectStmt)
	rows, err := stmt.QueryContext(ctx, iNIDs...)
	if err != nil {
		return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
	for rows.Next() {
		var nid int64
		var blob []byte
		if err = rows.Scan(&nid, &blob); err != nil {
			return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err)
		}
		blobs[nid] = blob
	}
	// Report any error from row iteration rather than the stale err value.
	return blobs, rows.Err()
}
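The delete and select statements above are prepared at query time because the IN ($1) clause has to be expanded to one placeholder per NID. A minimal, self-contained sketch of that expansion, assuming a local helper equivalent to sqlutil.QueryVariadic (the real helper lives in Dendrite's internal/sqlutil package):

	package main

	import (
		"fmt"
		"strings"
	)

	// queryVariadic mirrors what sqlutil.QueryVariadic is assumed to do:
	// build "($1, $2, ..., $n)" for n parameters.
	func queryVariadic(n int) string {
		ph := make([]string, n)
		for i := 0; i < n; i++ {
			ph[i] = fmt.Sprintf("$%d", i+1)
		}
		return "(" + strings.Join(ph, ", ") + ")"
	}

	func main() {
		const deleteJSONSQL = "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"
		nids := []int64{10, 11, 12}
		// Expand "($1)" into one placeholder per NID before preparing the statement.
		sql := strings.Replace(deleteJSONSQL, "($1)", queryVariadic(len(nids)), 1)
		fmt.Println(sql)
		// DELETE FROM federationsender_queue_json WHERE json_nid IN ($1, $2, $3)
	}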
235
federationsender/storage/cosmosdb/queue_pdus_table.go
Normal file

@@ -0,0 +1,235 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"fmt"
	"strings"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/gomatrixserverlib"
)

const queuePDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
	-- The transaction ID that was generated before persisting the event.
	transaction_id TEXT NOT NULL,
	-- The domain part of the user ID the m.room.member event is for.
	server_name TEXT NOT NULL,
	-- The JSON NID from the federationsender_queue_pdus_json table.
	json_nid BIGINT NOT NULL
);

CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
	ON federationsender_queue_pdus (json_nid, server_name);
`

const insertQueuePDUSQL = "" +
	"INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
	" VALUES ($1, $2, $3)"

const deleteQueuePDUsSQL = "" +
	"DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)"

const selectQueueNextTransactionIDSQL = "" +
	"SELECT transaction_id FROM federationsender_queue_pdus" +
	" WHERE server_name = $1" +
	" ORDER BY transaction_id ASC" +
	" LIMIT 1"

const selectQueuePDUsSQL = "" +
	"SELECT json_nid FROM federationsender_queue_pdus" +
	" WHERE server_name = $1" +
	" LIMIT $2"

const selectQueuePDUsReferenceJSONCountSQL = "" +
	"SELECT COUNT(*) FROM federationsender_queue_pdus" +
	" WHERE json_nid = $1"

const selectQueuePDUsCountSQL = "" +
	"SELECT COUNT(*) FROM federationsender_queue_pdus" +
	" WHERE server_name = $1"

const selectQueuePDUsServerNamesSQL = "" +
	"SELECT DISTINCT server_name FROM federationsender_queue_pdus"

type queuePDUsStatements struct {
	db                                *sql.DB
	insertQueuePDUStmt                *sql.Stmt
	selectQueueNextTransactionIDStmt  *sql.Stmt
	selectQueuePDUsStmt               *sql.Stmt
	selectQueueReferenceJSONCountStmt *sql.Stmt
	selectQueuePDUsCountStmt          *sql.Stmt
	selectQueueServerNamesStmt        *sql.Stmt
	// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
}

func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
	s = &queuePDUsStatements{
		db: db,
	}
	_, err = db.Exec(queuePDUsSchema)
	if err != nil {
		return
	}
	if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil {
		return
	}
	//if s.deleteQueuePDUsStmt, err = db.Prepare(deleteQueuePDUsSQL); err != nil {
	//	return
	//}
	if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil {
		return
	}
	if s.selectQueuePDUsStmt, err = db.Prepare(selectQueuePDUsSQL); err != nil {
		return
	}
	if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
		return
	}
	if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
		return
	}
	if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
		return
	}
	return
}

func (s *queuePDUsStatements) InsertQueuePDU(
	ctx context.Context,
	txn *sql.Tx,
	transactionID gomatrixserverlib.TransactionID,
	serverName gomatrixserverlib.ServerName,
	nid int64,
) error {
	stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
	_, err := stmt.ExecContext(
		ctx,
		transactionID, // the transaction ID that we initially attempted
		serverName,    // destination server name
		nid,           // JSON blob NID
	)
	return err
}

func (s *queuePDUsStatements) DeleteQueuePDUs(
	ctx context.Context, txn *sql.Tx,
	serverName gomatrixserverlib.ServerName,
	jsonNIDs []int64,
) error {
	deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
	deleteStmt, err := txn.Prepare(deleteSQL)
	if err != nil {
		return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
	}

	params := make([]interface{}, len(jsonNIDs)+1)
	params[0] = serverName
	for k, v := range jsonNIDs {
		params[k+1] = v
	}

	stmt := sqlutil.TxStmt(txn, deleteStmt)
	_, err = stmt.ExecContext(ctx, params...)
	return err
}

func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (gomatrixserverlib.TransactionID, error) {
	var transactionID gomatrixserverlib.TransactionID
	stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt)
	err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID)
	if err == sql.ErrNoRows {
		return "", nil
	}
	return transactionID, err
}

func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
	ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) {
	var count int64
	stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
	err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
	if err == sql.ErrNoRows {
		return -1, nil
	}
	return count, err
}

func (s *queuePDUsStatements) SelectQueuePDUCount(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
	var count int64
	stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
	err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
	if err == sql.ErrNoRows {
		// It's acceptable for there to be no rows referencing a given
		// JSON NID but it's not an error condition. Just return as if
		// there's a zero count.
		return 0, nil
	}
	return count, err
}

func (s *queuePDUsStatements) SelectQueuePDUs(
	ctx context.Context, txn *sql.Tx,
	serverName gomatrixserverlib.ServerName,
	limit int,
) ([]int64, error) {
	stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
	rows, err := stmt.QueryContext(ctx, serverName, limit)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
	var result []int64
	for rows.Next() {
		var nid int64
		if err = rows.Scan(&nid); err != nil {
			return nil, err
		}
		result = append(result, nid)
	}

	return result, rows.Err()
}

func (s *queuePDUsStatements) SelectQueuePDUServerNames(
	ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {
	stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt)
	rows, err := stmt.QueryContext(ctx)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
	var result []gomatrixserverlib.ServerName
	for rows.Next() {
		var serverName gomatrixserverlib.ServerName
		if err = rows.Scan(&serverName); err != nil {
			return nil, err
		}
		result = append(result, serverName)
	}

	return result, rows.Err()
}
97
federationsender/storage/cosmosdb/storage.go
Normal file

@@ -0,0 +1,97 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"database/sql"

	"github.com/matrix-org/dendrite/internal/cosmosdbutil"

	_ "github.com/mattn/go-sqlite3"

	"github.com/matrix-org/dendrite/federationsender/storage/shared"
	"github.com/matrix-org/dendrite/federationsender/storage/sqlite3/deltas"
	"github.com/matrix-org/dendrite/internal/caching"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/setup/config"
)

// Database stores information needed by the federation sender
type Database struct {
	shared.Database
	sqlutil.PartitionOffsetStatements
	db     *sql.DB
	writer sqlutil.Writer
}

// NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) {
	dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
	var d Database
	var err error
	if d.db, err = sqlutil.Open(dbProperties); err != nil {
		return nil, err
	}
	d.writer = sqlutil.NewExclusiveWriter()
	joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
	if err != nil {
		return nil, err
	}
	queuePDUs, err := NewSQLiteQueuePDUsTable(d.db)
	if err != nil {
		return nil, err
	}
	queueEDUs, err := NewSQLiteQueueEDUsTable(d.db)
	if err != nil {
		return nil, err
	}
	queueJSON, err := NewSQLiteQueueJSONTable(d.db)
	if err != nil {
		return nil, err
	}
	blacklist, err := NewSQLiteBlacklistTable(d.db)
	if err != nil {
		return nil, err
	}
	outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
	if err != nil {
		return nil, err
	}
	inboundPeeks, err := NewSQLiteInboundPeeksTable(d.db)
	if err != nil {
		return nil, err
	}
	m := sqlutil.NewMigrations()
	deltas.LoadRemoveRoomsTable(m)
	if err = m.RunDeltas(d.db, dbProperties); err != nil {
		return nil, err
	}
	d.Database = shared.Database{
		DB:                            d.db,
		Cache:                         cache,
		Writer:                        d.writer,
		FederationSenderJoinedHosts:   joinedHosts,
		FederationSenderQueuePDUs:     queuePDUs,
		FederationSenderQueueEDUs:     queueEDUs,
		FederationSenderQueueJSON:     queueJSON,
		FederationSenderBlacklist:     blacklist,
		FederationSenderOutboundPeeks: outboundPeeks,
		FederationSenderInboundPeeks:  inboundPeeks,
	}
	if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
		return nil, err
	}
	return &d, nil
}
@@ -21,6 +21,7 @@ import (
 
 	"github.com/matrix-org/dendrite/federationsender/storage/postgres"
 	"github.com/matrix-org/dendrite/federationsender/storage/sqlite3"
+	"github.com/matrix-org/dendrite/federationsender/storage/cosmosdb"
 	"github.com/matrix-org/dendrite/internal/caching"
 	"github.com/matrix-org/dendrite/setup/config"
 )

@@ -28,6 +29,8 @@ import (
 // NewDatabase opens a new database
 func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (Database, error) {
 	switch {
+	case dbProperties.ConnectionString.IsCosmosDB():
+		return cosmosdb.NewDatabase(dbProperties, cache)
 	case dbProperties.ConnectionString.IsSQLite():
 		return sqlite3.NewDatabase(dbProperties, cache)
 	case dbProperties.ConnectionString.IsPostgres():

@@ -25,6 +25,8 @@ import (
 // NewDatabase opens a new database
 func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (Database, error) {
 	switch {
+	case dbProperties.ConnectionString.IsCosmosDB():
+		return nil, fmt.Errorf("can't use CosmosDB implementation")
 	case dbProperties.ConnectionString.IsSQLite():
 		return sqlite3.NewDatabase(dbProperties, cache)
 	case dbProperties.ConnectionString.IsPostgres():
12
internal/cosmosdbutil/connection.go
Normal file

@@ -0,0 +1,12 @@
package cosmosdbutil

import (
	"strings"

	"github.com/matrix-org/dendrite/setup/config"
)

func GetConnectionString(d *config.DataSource) config.DataSource {
	connString := string(*d)
	return config.DataSource(strings.Replace(connString, "cosmosdb:", "file:", 1))
}
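GetConnectionString simply rewrites the cosmosdb: scheme to file: so the existing SQLite open path can consume the string. A small usage sketch, with a local DataSource type standing in for config.DataSource (which in Dendrite is a string alias):

	package main

	import (
		"fmt"
		"strings"
	)

	// DataSource stands in for config.DataSource here.
	type DataSource string

	// getConnectionString mirrors cosmosdbutil.GetConnectionString above.
	func getConnectionString(d *DataSource) DataSource {
		connString := string(*d)
		return DataSource(strings.Replace(connString, "cosmosdb:", "file:", 1))
	}

	func main() {
		ds := DataSource("cosmosdb:federationsender.db")
		fmt.Println(getConnectionString(&ds)) // file:federationsender.db
	}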
@@ -48,10 +48,15 @@ func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error
 	if err != nil {
 		return fmt.Errorf("RunDeltas: Failed to collect migrations: %w", err)
 	}
-	if props.ConnectionString.IsPostgres() {
+	if props.ConnectionString.IsCosmosDB() {
 		if err = goose.SetDialect("postgres"); err != nil {
 			return err
 		}
+	} else if props.ConnectionString.IsPostgres() {
+		//HACK: Not supported
+		if err = goose.SetDialect("cosmosdb"); err != nil {
+			return err
+		}
 	} else if props.ConnectionString.IsSQLite() {
 		if err = goose.SetDialect("sqlite3"); err != nil {
 			return err
@@ -103,6 +103,11 @@ func Open(dbProperties *config.DatabaseOptions) (*sql.DB, error) {
 	var err error
 	var driverName, dsn string
 	switch {
+	case dbProperties.ConnectionString.IsCosmosDB():
+		//HACK: Not supported
+		driverName = "cosmosdb"
+		dsn = string(dbProperties.ConnectionString)
+		return nil, fmt.Errorf("CosmosDB %q", dbProperties.ConnectionString)
 	case dbProperties.ConnectionString.IsSQLite():
 		driverName = SQLiteDriverName()
 		dsn, err = ParseFileURI(dbProperties.ConnectionString)
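The dispatch above depends on an IsCosmosDB helper on the connection string, which is added elsewhere in this change set and is not shown here. It is assumed to be a simple scheme-prefix check alongside the existing IsSQLite and IsPostgres methods, roughly like this sketch:

	package main

	import (
		"fmt"
		"strings"
	)

	// DataSource stands in for config.DataSource.
	type DataSource string

	// These methods are assumptions about the shape of the real checks;
	// the actual prefixes accepted by Dendrite's config may differ.
	func (d DataSource) IsCosmosDB() bool { return strings.HasPrefix(string(d), "cosmosdb:") }
	func (d DataSource) IsSQLite() bool   { return strings.HasPrefix(string(d), "file:") }
	func (d DataSource) IsPostgres() bool { return strings.HasPrefix(string(d), "postgres") }

	func main() {
		fmt.Println(DataSource("cosmosdb:keyserver.db").IsCosmosDB()) // true
	}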
199
keyserver/storage/cosmosdb/device_keys_table.go
Normal file

@@ -0,0 +1,199 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"strings"
	"time"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/keyserver/api"
	"github.com/matrix-org/dendrite/keyserver/storage/tables"
)

var deviceKeysSchema = `
-- Stores device keys for users
CREATE TABLE IF NOT EXISTS keyserver_device_keys (
	user_id TEXT NOT NULL,
	device_id TEXT NOT NULL,
	ts_added_secs BIGINT NOT NULL,
	key_json TEXT NOT NULL,
	stream_id BIGINT NOT NULL,
	display_name TEXT,
	-- Clobber based on tuple of user/device.
	UNIQUE (user_id, device_id)
);
`

const upsertDeviceKeysSQL = "" +
	"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
	" VALUES ($1, $2, $3, $4, $5, $6)" +
	" ON CONFLICT (user_id, device_id)" +
	" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"

const selectDeviceKeysSQL = "" +
	"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"

const selectBatchDeviceKeysSQL = "" +
	"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"

const selectMaxStreamForUserSQL = "" +
	"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"

const countStreamIDsForUserSQL = "" +
	"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"

const deleteAllDeviceKeysSQL = "" +
	"DELETE FROM keyserver_device_keys WHERE user_id=$1"

type deviceKeysStatements struct {
	db                         *sql.DB
	upsertDeviceKeysStmt       *sql.Stmt
	selectDeviceKeysStmt       *sql.Stmt
	selectBatchDeviceKeysStmt  *sql.Stmt
	selectMaxStreamForUserStmt *sql.Stmt
	deleteAllDeviceKeysStmt    *sql.Stmt
}

func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
	s := &deviceKeysStatements{
		db: db,
	}
	_, err := db.Exec(deviceKeysSchema)
	if err != nil {
		return nil, err
	}
	if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil {
		return nil, err
	}
	if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
		return nil, err
	}
	if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
		return nil, err
	}
	if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
		return nil, err
	}
	if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
		return nil, err
	}
	return s, nil
}

func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
	_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
	return err
}

func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
	deviceIDMap := make(map[string]bool)
	for _, d := range deviceIDs {
		deviceIDMap[d] = true
	}
	rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
	var result []api.DeviceMessage
	for rows.Next() {
		var dk api.DeviceMessage
		dk.UserID = userID
		var keyJSON string
		var streamID int
		var displayName sql.NullString
		if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
			return nil, err
		}
		dk.KeyJSON = []byte(keyJSON)
		dk.StreamID = streamID
		if displayName.Valid {
			dk.DisplayName = displayName.String
		}
		// include the key if we want all keys (no device) or it was asked
		if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
			result = append(result, dk)
		}
	}
	return result, rows.Err()
}

func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
	for i, key := range keys {
		var keyJSONStr string
		var streamID int
		var displayName sql.NullString
		err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
		if err != nil && err != sql.ErrNoRows {
			return err
		}
		// this will be '' when there is no device
		keys[i].KeyJSON = []byte(keyJSONStr)
		keys[i].StreamID = streamID
		if displayName.Valid {
			keys[i].DisplayName = displayName.String
		}
	}
	return nil
}

func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
	// nullable if there are no results
	var nullStream sql.NullInt32
	err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
	if err == sql.ErrNoRows {
		err = nil
	}
	if nullStream.Valid {
		streamID = nullStream.Int32
	}
	return
}

func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
	iStreamIDs := make([]interface{}, len(streamIDs)+1)
	iStreamIDs[0] = userID
	for i := range streamIDs {
		iStreamIDs[i+1] = streamIDs[i]
	}
	query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
	// nullable if there are no results
	var count sql.NullInt32
	err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
	if err != nil {
		return 0, err
	}
	if count.Valid {
		return int(count.Int32), nil
	}
	return 0, nil
}

func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
	for _, key := range keys {
		now := time.Now().Unix()
		_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
			ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
		)
		if err != nil {
			return err
		}
	}
	return nil
}
104
keyserver/storage/cosmosdb/key_changes_table.go
Normal file

@@ -0,0 +1,104 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"math"

	"github.com/Shopify/sarama"
	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/keyserver/storage/tables"
)

var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
	partition BIGINT NOT NULL,
	offset BIGINT NOT NULL,
	-- The key owner
	user_id TEXT NOT NULL,
	UNIQUE (partition, offset)
);
`

// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
const upsertKeyChangeSQL = "" +
	"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
	" VALUES ($1, $2, $3)" +
	" ON CONFLICT (partition, offset)" +
	" DO UPDATE SET user_id = $3"

// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
// take the max offset value as the latest offset.
const selectKeyChangesSQL = "" +
	"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"

type keyChangesStatements struct {
	db                   *sql.DB
	upsertKeyChangeStmt  *sql.Stmt
	selectKeyChangesStmt *sql.Stmt
}

func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
	s := &keyChangesStatements{
		db: db,
	}
	_, err := db.Exec(keyChangesSchema)
	if err != nil {
		return nil, err
	}
	if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
		return nil, err
	}
	if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
		return nil, err
	}
	return s, nil
}

func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
	_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
	return err
}

func (s *keyChangesStatements) SelectKeyChanges(
	ctx context.Context, partition int32, fromOffset, toOffset int64,
) (userIDs []string, latestOffset int64, err error) {
	if toOffset == sarama.OffsetNewest {
		toOffset = math.MaxInt64
	}
	latestOffset = fromOffset
	rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
	if err != nil {
		return nil, 0, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
	for rows.Next() {
		var userID string
		var offset int64
		if err := rows.Scan(&userID, &offset); err != nil {
			return nil, 0, err
		}
		if offset > latestOffset {
			latestOffset = offset
		}
		userIDs = append(userIDs, userID)
	}
	return
}
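A short usage sketch of the key-changes queries: a consumer records an update at a Kafka partition/offset and then reads back every user changed since a sync position. The table is hidden behind a hypothetical local interface so the example stays self-contained; the concrete statements type above satisfies it.

	package main

	import (
		"context"
		"fmt"

		"github.com/Shopify/sarama"
	)

	// keyChanges captures just the two methods exercised here.
	type keyChanges interface {
		InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
		SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) ([]string, int64, error)
	}

	// syncKeyChanges is a hypothetical helper: record one change, then ask for
	// everything newer than fromOffset on the same partition.
	func syncKeyChanges(ctx context.Context, t keyChanges, partition int32, fromOffset int64) error {
		if err := t.InsertKeyChange(ctx, partition, fromOffset+1, "@alice:example.com"); err != nil {
			return err
		}
		// sarama.OffsetNewest is translated to math.MaxInt64 by SelectKeyChanges,
		// so this returns every change after fromOffset.
		userIDs, latest, err := t.SelectKeyChanges(ctx, partition, fromOffset, sarama.OffsetNewest)
		if err != nil {
			return err
		}
		fmt.Println(userIDs, latest)
		return nil
	}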
203
keyserver/storage/cosmosdb/one_time_keys_table.go
Normal file

@@ -0,0 +1,203 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"encoding/json"
	"time"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/keyserver/api"
	"github.com/matrix-org/dendrite/keyserver/storage/tables"
)

var oneTimeKeysSchema = `
-- Stores one-time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
	user_id TEXT NOT NULL,
	device_id TEXT NOT NULL,
	key_id TEXT NOT NULL,
	algorithm TEXT NOT NULL,
	ts_added_secs BIGINT NOT NULL,
	key_json TEXT NOT NULL,
	-- Clobber based on 4-uple of user/device/key/algorithm.
	UNIQUE (user_id, device_id, key_id, algorithm)
);
`

const upsertKeysSQL = "" +
	"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
	" VALUES ($1, $2, $3, $4, $5, $6)" +
	" ON CONFLICT (user_id, device_id, key_id, algorithm)" +
	" DO UPDATE SET key_json = $6"

const selectKeysSQL = "" +
	"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"

const selectKeysCountSQL = "" +
	"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"

const deleteOneTimeKeySQL = "" +
	"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"

const selectKeyByAlgorithmSQL = "" +
	"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"

type oneTimeKeysStatements struct {
	db                       *sql.DB
	upsertKeysStmt           *sql.Stmt
	selectKeysStmt           *sql.Stmt
	selectKeysCountStmt      *sql.Stmt
	selectKeyByAlgorithmStmt *sql.Stmt
	deleteOneTimeKeyStmt     *sql.Stmt
}

func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
	s := &oneTimeKeysStatements{
		db: db,
	}
	_, err := db.Exec(oneTimeKeysSchema)
	if err != nil {
		return nil, err
	}
	if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil {
		return nil, err
	}
	if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
		return nil, err
	}
	if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
		return nil, err
	}
	if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
		return nil, err
	}
	if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
		return nil, err
	}
	return s, nil
}

func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
	rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")

	wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
	for _, ka := range keyIDsWithAlgorithms {
		wantSet[ka] = true
	}

	result := make(map[string]json.RawMessage)
	for rows.Next() {
		var keyID string
		var algorithm string
		var keyJSONStr string
		if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
			return nil, err
		}
		keyIDWithAlgo := algorithm + ":" + keyID
		if wantSet[keyIDWithAlgo] {
			result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
		}
	}
	return result, rows.Err()
}

func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
	counts := &api.OneTimeKeysCount{
		DeviceID: deviceID,
		UserID:   userID,
		KeyCount: make(map[string]int),
	}
	rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
	for rows.Next() {
		var algorithm string
		var count int
		if err = rows.Scan(&algorithm, &count); err != nil {
			return nil, err
		}
		counts.KeyCount[algorithm] = count
	}
	return counts, nil
}

func (s *oneTimeKeysStatements) InsertOneTimeKeys(
	ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys,
) (*api.OneTimeKeysCount, error) {
	now := time.Now().Unix()
	counts := &api.OneTimeKeysCount{
		DeviceID: keys.DeviceID,
		UserID:   keys.UserID,
		KeyCount: make(map[string]int),
	}
	for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
		algo, keyID := keys.Split(keyIDWithAlgo)
		_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
			ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
		)
		if err != nil {
			return nil, err
		}
	}
	rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
	for rows.Next() {
		var algorithm string
		var count int
		if err = rows.Scan(&algorithm, &count); err != nil {
			return nil, err
		}
		counts.KeyCount[algorithm] = count
	}

	return counts, rows.Err()
}

func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
	ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
	var keyID string
	var keyJSON string
	err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, nil
		}
		return nil, err
	}
	_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
	if err != nil {
		return nil, err
	}
	if keyJSON == "" {
		return nil, nil
	}
	return map[string]json.RawMessage{
		algorithm + ":" + keyID: json.RawMessage(keyJSON),
	}, err
}
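SelectAndDeleteOneTimeKey implements the claim semantics: read one key for the requested algorithm, delete it in the same transaction so it can never be handed out twice, and return it keyed as algorithm:key_id. A minimal sketch of a caller, with the table hidden behind a hypothetical local interface so the example stays self-contained:

	package main

	import (
		"context"
		"database/sql"
		"encoding/json"
		"fmt"
	)

	// oneTimeKeyClaimer captures the single method used here; the statements
	// type above satisfies it.
	type oneTimeKeyClaimer interface {
		SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
	}

	// claimKey runs the claim inside a transaction so the read and the delete
	// either both happen or neither does.
	func claimKey(ctx context.Context, db *sql.DB, t oneTimeKeyClaimer, userID, deviceID string) (map[string]json.RawMessage, error) {
		txn, err := db.BeginTx(ctx, nil)
		if err != nil {
			return nil, err
		}
		key, err := t.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, "signed_curve25519")
		if err != nil {
			_ = txn.Rollback()
			return nil, err
		}
		if key == nil {
			fmt.Println("no unclaimed keys left for this algorithm")
		}
		return key, txn.Commit()
	}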
121
keyserver/storage/cosmosdb/stale_device_lists.go
Normal file

@@ -0,0 +1,121 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"time"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/keyserver/storage/tables"
	"github.com/matrix-org/gomatrixserverlib"
)

var staleDeviceListsSchema = `
-- Stores whether a user's device lists are stale or not.
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
	user_id TEXT PRIMARY KEY NOT NULL,
	domain TEXT NOT NULL,
	is_stale BOOLEAN NOT NULL,
	ts_added_secs BIGINT NOT NULL
);

CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
`

const upsertStaleDeviceListSQL = "" +
	"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
	" VALUES ($1, $2, $3, $4)" +
	" ON CONFLICT (user_id)" +
	" DO UPDATE SET is_stale = $3, ts_added_secs = $4"

const selectStaleDeviceListsWithDomainsSQL = "" +
	"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"

const selectStaleDeviceListsSQL = "" +
	"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"

type staleDeviceListsStatements struct {
	db                                    *sql.DB
	upsertStaleDeviceListStmt             *sql.Stmt
	selectStaleDeviceListsWithDomainsStmt *sql.Stmt
	selectStaleDeviceListsStmt            *sql.Stmt
}

func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
	s := &staleDeviceListsStatements{
		db: db,
	}
	_, err := db.Exec(staleDeviceListsSchema)
	if err != nil {
		return nil, err
	}
	if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
		return nil, err
	}
	if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
		return nil, err
	}
	if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
		return nil, err
	}
	return s, nil
}

func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
	_, domain, err := gomatrixserverlib.SplitID('@', userID)
	if err != nil {
		return err
	}
	_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
	return err
}

func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
	// we only query for 1 domain or all domains so optimise for those use cases
	if len(domains) == 0 {
		rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
		if err != nil {
			return nil, err
		}
		return rowsToUserIDs(ctx, rows)
	}
	var result []string
	for _, domain := range domains {
		rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
		if err != nil {
			return nil, err
		}
		userIDs, err := rowsToUserIDs(ctx, rows)
		if err != nil {
			return nil, err
		}
		result = append(result, userIDs...)
	}
	return result, nil
}

func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
	defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
	for rows.Next() {
		var userID string
		if err := rows.Scan(&userID); err != nil {
			return nil, err
		}
		result = append(result, userID)
	}
	return result, rows.Err()
}
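InsertStaleDeviceList derives the stored domain from the Matrix user ID before writing. A small sketch of that derivation using gomatrixserverlib.SplitID, the same call used in the file above:

	package main

	import (
		"fmt"

		"github.com/matrix-org/gomatrixserverlib"
	)

	func main() {
		// SplitID('@', ...) separates the localpart from the server name,
		// which is what the stale-device-list table stores as "domain".
		local, domain, err := gomatrixserverlib.SplitID('@', "@alice:remote.example.org")
		if err != nil {
			panic(err)
		}
		fmt.Println(local, domain) // alice remote.example.org
	}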
54
keyserver/storage/cosmosdb/storage.go
Normal file
54
keyserver/storage/cosmosdb/storage.go
Normal file
@@ -0,0 +1,54 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"github.com/matrix-org/dendrite/internal/cosmosdbutil"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/keyserver/storage/shared"
	"github.com/matrix-org/dendrite/setup/config"
)

func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
	dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
	db, err := sqlutil.Open(dbProperties)
	if err != nil {
		return nil, err
	}
	otk, err := NewSqliteOneTimeKeysTable(db)
	if err != nil {
		return nil, err
	}
	dk, err := NewSqliteDeviceKeysTable(db)
	if err != nil {
		return nil, err
	}
	kc, err := NewSqliteKeyChangesTable(db)
	if err != nil {
		return nil, err
	}
	sdl, err := NewSqliteStaleDeviceListsTable(db)
	if err != nil {
		return nil, err
	}
	return &shared.Database{
		DB:                    db,
		Writer:                sqlutil.NewExclusiveWriter(),
		OneTimeKeysTable:      otk,
		DeviceKeysTable:       dk,
		KeyChangesTable:       kc,
		StaleDeviceListsTable: sdl,
	}, nil
}
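A hedged sketch of how this constructor is reached from already-parsed options; the factory hunk below selects it when the connection string reports IsCosmosDB(). The function name and the logging are illustrative assumptions, and the CosmosDB connection string format itself is defined elsewhere in this change.

package example // illustrative sketch, not part of the change

import (
	"log"

	"github.com/matrix-org/dendrite/keyserver/storage/cosmosdb"
	"github.com/matrix-org/dendrite/setup/config"
)

// openKeyDB opens the keyserver database from parsed database options.
func openKeyDB(opts *config.DatabaseOptions) {
	db, err := cosmosdb.NewDatabase(opts)
	if err != nil {
		log.Fatalf("failed to open keyserver database: %v", err)
	}
	// db is a *shared.Database with the four cosmosdb-backed tables wired in.
	_ = db
}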
@@ -19,6 +19,7 @@ package storage
import (
	"fmt"

	"github.com/matrix-org/dendrite/keyserver/storage/cosmosdb"
	"github.com/matrix-org/dendrite/keyserver/storage/postgres"
	"github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
	"github.com/matrix-org/dendrite/setup/config"

@@ -28,6 +29,8 @@ import (
// and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
	switch {
	case dbProperties.ConnectionString.IsCosmosDB():
		return cosmosdb.NewDatabase(dbProperties)
	case dbProperties.ConnectionString.IsSQLite():
		return sqlite3.NewDatabase(dbProperties)
	case dbProperties.ConnectionString.IsPostgres():
@@ -23,6 +23,8 @@ import (

func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) {
	switch {
	case dbProperties.ConnectionString.IsCosmosDB():
		return nil, fmt.Errorf("can't use CosmosDB implementation")
	case dbProperties.ConnectionString.IsSQLite():
		return sqlite3.NewDatabase(dbProperties)
	case dbProperties.ConnectionString.IsPostgres():
150
mediaapi/storage/cosmosdb/media_repository_table.go
Normal file
|
|
@@ -0,0 +1,150 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const mediaSchema = `
|
||||
-- The media_repository table holds metadata for each media file stored and accessible to the local server,
|
||||
-- the actual file is stored separately.
|
||||
CREATE TABLE IF NOT EXISTS mediaapi_media_repository (
|
||||
-- The id used to refer to the media.
|
||||
-- For uploads to this server this is a base64-encoded sha256 hash of the file data
|
||||
-- For media from remote servers, this can be any unique identifier string
|
||||
media_id TEXT NOT NULL,
|
||||
-- The origin of the media as requested by the client. Should be a homeserver domain.
|
||||
media_origin TEXT NOT NULL,
|
||||
-- The MIME-type of the media file as specified when uploading.
|
||||
content_type TEXT NOT NULL,
|
||||
-- Size of the media file in bytes.
|
||||
file_size_bytes INTEGER NOT NULL,
|
||||
-- When the content was uploaded in UNIX epoch ms.
|
||||
creation_ts INTEGER NOT NULL,
|
||||
-- The file name with which the media was uploaded.
|
||||
upload_name TEXT NOT NULL,
|
||||
-- Alternate RFC 4648 unpadded base64 encoding string representation of a SHA-256 hash sum of the file data.
|
||||
base64hash TEXT NOT NULL,
|
||||
-- The user who uploaded the file. Should be a Matrix user ID.
|
||||
user_id TEXT NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_media_repository_index ON mediaapi_media_repository (media_id, media_origin);
|
||||
`
|
||||
|
||||
const insertMediaSQL = `
|
||||
INSERT INTO mediaapi_media_repository (media_id, media_origin, content_type, file_size_bytes, creation_ts, upload_name, base64hash, user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`
|
||||
|
||||
const selectMediaSQL = `
|
||||
SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user_id FROM mediaapi_media_repository WHERE media_id = $1 AND media_origin = $2
|
||||
`
|
||||
|
||||
const selectMediaByHashSQL = `
|
||||
SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_id FROM mediaapi_media_repository WHERE base64hash = $1 AND media_origin = $2
|
||||
`
|
||||
|
||||
type mediaStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertMediaStmt *sql.Stmt
|
||||
selectMediaStmt *sql.Stmt
|
||||
selectMediaByHashStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||
s.db = db
|
||||
s.writer = writer
|
||||
|
||||
_, err = db.Exec(mediaSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return statementList{
|
||||
{&s.insertMediaStmt, insertMediaSQL},
|
||||
{&s.selectMediaStmt, selectMediaSQL},
|
||||
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
func (s *mediaStatements) insertMedia(
|
||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||
) error {
|
||||
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertMediaStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
mediaMetadata.MediaID,
|
||||
mediaMetadata.Origin,
|
||||
mediaMetadata.ContentType,
|
||||
mediaMetadata.FileSizeBytes,
|
||||
mediaMetadata.CreationTimestamp,
|
||||
mediaMetadata.UploadName,
|
||||
mediaMetadata.Base64Hash,
|
||||
mediaMetadata.UserID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *mediaStatements) selectMedia(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata := types.MediaMetadata{
|
||||
MediaID: mediaID,
|
||||
Origin: mediaOrigin,
|
||||
}
|
||||
err := s.selectMediaStmt.QueryRowContext(
|
||||
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
|
||||
).Scan(
|
||||
&mediaMetadata.ContentType,
|
||||
&mediaMetadata.FileSizeBytes,
|
||||
&mediaMetadata.CreationTimestamp,
|
||||
&mediaMetadata.UploadName,
|
||||
&mediaMetadata.Base64Hash,
|
||||
&mediaMetadata.UserID,
|
||||
)
|
||||
return &mediaMetadata, err
|
||||
}
|
||||
|
||||
func (s *mediaStatements) selectMediaByHash(
|
||||
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata := types.MediaMetadata{
|
||||
Base64Hash: mediaHash,
|
||||
Origin: mediaOrigin,
|
||||
}
|
||||
err := s.selectMediaByHashStmt.QueryRowContext(
|
||||
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
|
||||
).Scan(
|
||||
&mediaMetadata.ContentType,
|
||||
&mediaMetadata.FileSizeBytes,
|
||||
&mediaMetadata.CreationTimestamp,
|
||||
&mediaMetadata.UploadName,
|
||||
&mediaMetadata.MediaID,
|
||||
&mediaMetadata.UserID,
|
||||
)
|
||||
return &mediaMetadata, err
|
||||
}
|
||||
38
mediaapi/storage/cosmosdb/prepare.go
Normal file
@@ -0,0 +1,38 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// FIXME: This should be made internal!

package cosmosdb

import (
	"database/sql"
)

// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
	statement **sql.Stmt
	sql       string
}

// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
	for _, statement := range s {
		if *statement.statement, err = db.Prepare(statement.sql); err != nil {
			return
		}
	}
	return
}
38
mediaapi/storage/cosmosdb/sql.go
Normal file
@@ -0,0 +1,38 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"database/sql"

	"github.com/matrix-org/dendrite/internal/sqlutil"
)

type statements struct {
	media     mediaStatements
	thumbnail thumbnailStatements
}

func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
	if err = s.media.prepare(db, writer); err != nil {
		return
	}
	if err = s.thumbnail.prepare(db, writer); err != nil {
		return
	}

	return
}
126
mediaapi/storage/cosmosdb/storage.go
Normal file
|
|
@@ -0,0 +1,126 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
// Import the SQLite database driver.
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// Database is used to store metadata about a repository of media files.
|
||||
type Database struct {
|
||||
statements statements
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
}
|
||||
|
||||
// Open opens the media repository metadata database.
|
||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
d := Database{
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
}
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.statements.prepare(d.db, d.writer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
|
||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||
func (d *Database) StoreMediaMetadata(
|
||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||
) error {
|
||||
return d.statements.media.insertMedia(ctx, mediaMetadata)
|
||||
}
|
||||
|
||||
// GetMediaMetadata returns metadata about media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this media.
|
||||
func (d *Database) GetMediaMetadata(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return mediaMetadata, err
|
||||
}
|
||||
|
||||
// GetMediaMetadataByHash returns metadata about media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this media.
|
||||
func (d *Database) GetMediaMetadataByHash(
|
||||
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return mediaMetadata, err
|
||||
}
|
||||
|
||||
// StoreThumbnail inserts the metadata about the thumbnail into the database.
|
||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||
func (d *Database) StoreThumbnail(
|
||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
) error {
|
||||
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
|
||||
}
|
||||
|
||||
// GetThumbnail returns metadata about a specific thumbnail.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this thumbnail.
|
||||
func (d *Database) GetThumbnail(
|
||||
ctx context.Context,
|
||||
mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
resizeMethod string,
|
||||
) (*types.ThumbnailMetadata, error) {
|
||||
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
|
||||
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
|
||||
)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return thumbnailMetadata, err
|
||||
}
|
||||
|
||||
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there are no thumbnails associated with this media.
|
||||
func (d *Database) GetThumbnails(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error) {
|
||||
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return thumbnails, err
|
||||
}
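For orientation, a minimal caller-side sketch of storing and reading back media metadata with the Database above. The function name and every field value are invented for illustration; only the method and field names come from the code in this file.

package example // illustrative sketch, not part of the change

import (
	"context"
	"fmt"

	"github.com/matrix-org/dendrite/mediaapi/storage/cosmosdb"
	"github.com/matrix-org/dendrite/mediaapi/types"
)

func storeAndFetch(ctx context.Context, db *cosmosdb.Database) error {
	meta := &types.MediaMetadata{
		MediaID:       "someMediaID", // invented example values
		Origin:        "example.org",
		ContentType:   "image/png",
		FileSizeBytes: 1024,
		UploadName:    "avatar.png",
		Base64Hash:    "someBase64Hash",
		UserID:        "@alice:example.org",
	}
	if err := db.StoreMediaMetadata(ctx, meta); err != nil {
		return err
	}
	// GetMediaMetadata returns nil metadata (and nil error) when nothing is stored.
	got, err := db.GetMediaMetadata(ctx, meta.MediaID, meta.Origin)
	if err != nil {
		return err
	}
	if got == nil {
		return fmt.Errorf("no metadata found for %q", meta.MediaID)
	}
	return nil
}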
|
||||
171
mediaapi/storage/cosmosdb/thumbnail_table.go
Normal file
|
|
@@ -0,0 +1,171 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const thumbnailSchema = `
|
||||
-- The mediaapi_thumbnail table holds metadata for each thumbnail file stored and accessible to the local server,
|
||||
-- the actual file is stored separately.
|
||||
CREATE TABLE IF NOT EXISTS mediaapi_thumbnail (
|
||||
media_id TEXT NOT NULL,
|
||||
media_origin TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
file_size_bytes INTEGER NOT NULL,
|
||||
creation_ts INTEGER NOT NULL,
|
||||
width INTEGER NOT NULL,
|
||||
height INTEGER NOT NULL,
|
||||
resize_method TEXT NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_thumbnail_index ON mediaapi_thumbnail (media_id, media_origin, width, height, resize_method);
|
||||
`
|
||||
|
||||
const insertThumbnailSQL = `
|
||||
INSERT INTO mediaapi_thumbnail (media_id, media_origin, content_type, file_size_bytes, creation_ts, width, height, resize_method)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`
|
||||
|
||||
// Note: this selects one specific thumbnail
|
||||
const selectThumbnailSQL = `
|
||||
SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 AND width = $3 AND height = $4 AND resize_method = $5
|
||||
`
|
||||
|
||||
// Note: this selects all thumbnails for a media_origin and media_id
|
||||
const selectThumbnailsSQL = `
|
||||
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
|
||||
`
|
||||
|
||||
type thumbnailStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertThumbnailStmt *sql.Stmt
|
||||
selectThumbnailStmt *sql.Stmt
|
||||
selectThumbnailsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||
_, err = db.Exec(thumbnailSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.db = db
|
||||
s.writer = writer
|
||||
|
||||
return statementList{
|
||||
{&s.insertThumbnailStmt, insertThumbnailSQL},
|
||||
{&s.selectThumbnailStmt, selectThumbnailSQL},
|
||||
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) insertThumbnail(
|
||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
) error {
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
thumbnailMetadata.MediaMetadata.ContentType,
|
||||
thumbnailMetadata.MediaMetadata.FileSizeBytes,
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp,
|
||||
thumbnailMetadata.ThumbnailSize.Width,
|
||||
thumbnailMetadata.ThumbnailSize.Height,
|
||||
thumbnailMetadata.ThumbnailSize.ResizeMethod,
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) selectThumbnail(
|
||||
ctx context.Context,
|
||||
mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
resizeMethod string,
|
||||
) (*types.ThumbnailMetadata, error) {
|
||||
thumbnailMetadata := types.ThumbnailMetadata{
|
||||
MediaMetadata: &types.MediaMetadata{
|
||||
MediaID: mediaID,
|
||||
Origin: mediaOrigin,
|
||||
},
|
||||
ThumbnailSize: types.ThumbnailSize{
|
||||
Width: width,
|
||||
Height: height,
|
||||
ResizeMethod: resizeMethod,
|
||||
},
|
||||
}
|
||||
err := s.selectThumbnailStmt.QueryRowContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
thumbnailMetadata.ThumbnailSize.Width,
|
||||
thumbnailMetadata.ThumbnailSize.Height,
|
||||
thumbnailMetadata.ThumbnailSize.ResizeMethod,
|
||||
).Scan(
|
||||
&thumbnailMetadata.MediaMetadata.ContentType,
|
||||
&thumbnailMetadata.MediaMetadata.FileSizeBytes,
|
||||
&thumbnailMetadata.MediaMetadata.CreationTimestamp,
|
||||
)
|
||||
return &thumbnailMetadata, err
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) selectThumbnails(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error) {
|
||||
rows, err := s.selectThumbnailsStmt.QueryContext(
|
||||
ctx, mediaID, mediaOrigin,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectThumbnails: rows.close() failed")
|
||||
|
||||
var thumbnails []*types.ThumbnailMetadata
|
||||
for rows.Next() {
|
||||
thumbnailMetadata := types.ThumbnailMetadata{
|
||||
MediaMetadata: &types.MediaMetadata{
|
||||
MediaID: mediaID,
|
||||
Origin: mediaOrigin,
|
||||
},
|
||||
}
|
||||
err = rows.Scan(
|
||||
&thumbnailMetadata.MediaMetadata.ContentType,
|
||||
&thumbnailMetadata.MediaMetadata.FileSizeBytes,
|
||||
&thumbnailMetadata.MediaMetadata.CreationTimestamp,
|
||||
&thumbnailMetadata.ThumbnailSize.Width,
|
||||
&thumbnailMetadata.ThumbnailSize.Height,
|
||||
&thumbnailMetadata.ThumbnailSize.ResizeMethod,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
thumbnails = append(thumbnails, &thumbnailMetadata)
|
||||
}
|
||||
|
||||
return thumbnails, rows.Err()
|
||||
}
|
||||
@@ -19,6 +19,7 @@ package storage
import (
	"fmt"

	"github.com/matrix-org/dendrite/mediaapi/storage/cosmosdb"
	"github.com/matrix-org/dendrite/mediaapi/storage/postgres"
	"github.com/matrix-org/dendrite/mediaapi/storage/sqlite3"
	"github.com/matrix-org/dendrite/setup/config"

@@ -27,6 +28,8 @@ import (
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
	switch {
	case dbProperties.ConnectionString.IsCosmosDB():
		return cosmosdb.Open(dbProperties)
	case dbProperties.ConnectionString.IsSQLite():
		return sqlite3.Open(dbProperties)
	case dbProperties.ConnectionString.IsPostgres():
@@ -24,6 +24,8 @@ import (
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
	switch {
	case dbProperties.ConnectionString.IsCosmosDB():
		return nil, fmt.Errorf("can't use CosmosDB implementation")
	case dbProperties.ConnectionString.IsSQLite():
		return sqlite3.Open(dbProperties)
	case dbProperties.ConnectionString.IsPostgres():
107
roomserver/storage/cosmosdb/event_json_table.go
Normal file
|
|
@@ -0,0 +1,107 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const eventJSONSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_event_json (
|
||||
event_nid INTEGER NOT NULL PRIMARY KEY,
|
||||
event_json TEXT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertEventJSONSQL = `
|
||||
INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
|
||||
`
|
||||
|
||||
// Bulk event JSON lookup by numeric event ID.
|
||||
// Sort by the numeric event ID.
|
||||
// This means that we can use binary search to look up by numeric event ID.
|
||||
const bulkSelectEventJSONSQL = `
|
||||
SELECT event_nid, event_json FROM roomserver_event_json
|
||||
WHERE event_nid IN ($1)
|
||||
ORDER BY event_nid ASC
|
||||
`
|
||||
|
||||
type eventJSONStatements struct {
|
||||
db *sql.DB
|
||||
insertEventJSONStmt *sql.Stmt
|
||||
bulkSelectEventJSONStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
|
||||
s := &eventJSONStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventJSONSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, shared.StatementList{
|
||||
{&s.insertEventJSONStmt, insertEventJSONSQL},
|
||||
{&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *eventJSONStatements) InsertEventJSON(
|
||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *eventJSONStatements) BulkSelectEventJSON(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) ([]tables.EventJSONPair, error) {
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventJSON: rows.close() failed")
|
||||
|
||||
// We know that we will only get as many results as event NIDs
|
||||
// because of the unique constraint on event NIDs.
|
||||
// So we can allocate an array of the correct size now.
|
||||
// We might get fewer results than NIDs so we adjust the length of the slice before returning it.
|
||||
results := make([]tables.EventJSONPair, len(eventNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
var eventNID int64
|
||||
if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.EventNID = types.EventNID(eventNID)
|
||||
}
|
||||
return results[:i], nil
|
||||
}
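BulkSelectEventJSON, like the other bulk selectors in these tables, rewrites the literal "($1)" in the query with one placeholder per argument via sqlutil.QueryVariadic. As a rough illustration of the idea only (a stand-in with the same shape, not the actual sqlutil helper):

package example // illustrative sketch, not part of the change

import (
	"fmt"
	"strings"
)

// buildPlaceholders mimics what a variadic placeholder generator has to do:
// produce "($1, $2, ..., $n)" so a query written with "IN ($1)" can bind n values.
func buildPlaceholders(n int) string {
	parts := make([]string, n)
	for i := range parts {
		parts[i] = fmt.Sprintf("$%d", i+1)
	}
	return "(" + strings.Join(parts, ", ") + ")"
}

func exampleExpand() string {
	const bulkSelect = "SELECT event_nid, event_json FROM roomserver_event_json WHERE event_nid IN ($1) ORDER BY event_nid ASC"
	// Expanding for three event NIDs yields "... IN ($1, $2, $3) ...".
	return strings.Replace(bulkSelect, "($1)", buildPlaceholders(3), 1)
}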
|
||||
163
roomserver/storage/cosmosdb/event_state_keys_table.go
Normal file
|
|
@@ -0,0 +1,163 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const eventStateKeysSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_event_state_keys (
|
||||
event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
event_state_key TEXT NOT NULL UNIQUE
|
||||
);
|
||||
INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
|
||||
VALUES (1, '')
|
||||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
|
||||
// Same as insertEventTypeNIDSQL
|
||||
const insertEventStateKeyNIDSQL = `
|
||||
INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
|
||||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
|
||||
const selectEventStateKeyNIDSQL = `
|
||||
SELECT event_state_key_nid FROM roomserver_event_state_keys
|
||||
WHERE event_state_key = $1
|
||||
`
|
||||
|
||||
// Bulk lookup from string state key to numeric ID for that state key.
|
||||
// Takes an array of strings as the query parameter.
|
||||
const bulkSelectEventStateKeySQL = `
|
||||
SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
|
||||
WHERE event_state_key IN ($1)
|
||||
`
|
||||
|
||||
// Bulk lookup from numeric ID to string state key for that state key.
|
||||
// Takes an array of strings as the query parameter.
|
||||
const bulkSelectEventStateKeyNIDSQL = `
|
||||
SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
|
||||
WHERE event_state_key_nid IN ($1)
|
||||
`
|
||||
|
||||
type eventStateKeyStatements struct {
|
||||
db *sql.DB
|
||||
insertEventStateKeyNIDStmt *sql.Stmt
|
||||
selectEventStateKeyNIDStmt *sql.Stmt
|
||||
bulkSelectEventStateKeyNIDStmt *sql.Stmt
|
||||
bulkSelectEventStateKeyStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
|
||||
s := &eventStateKeyStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventStateKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, shared.StatementList{
|
||||
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
|
||||
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
|
||||
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
|
||||
{&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) InsertEventStateKeyNID(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||
) (types.EventStateKeyNID, error) {
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, eventStateKey)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
eventStateKeyNID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) SelectEventStateKeyNID(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||
) (types.EventStateKeyNID, error) {
|
||||
var eventStateKeyNID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEventStateKeyNIDStmt)
|
||||
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
|
||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||
ctx context.Context, eventStateKeys []string,
|
||||
) (map[string]types.EventStateKeyNID, error) {
|
||||
iEventStateKeys := make([]interface{}, len(eventStateKeys))
|
||||
for k, v := range eventStateKeys {
|
||||
iEventStateKeys[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed")
|
||||
result := make(map[string]types.EventStateKeyNID, len(eventStateKeys))
|
||||
for rows.Next() {
|
||||
var stateKey string
|
||||
var stateKeyNID int64
|
||||
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error) {
|
||||
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
|
||||
for k, v := range eventStateKeyNIDs {
|
||||
iEventStateKeyNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed")
|
||||
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
|
||||
for rows.Next() {
|
||||
var stateKey string
|
||||
var stateKeyNID int64
|
||||
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
161
roomserver/storage/cosmosdb/event_types_table.go
Normal file
|
|
@@ -0,0 +1,161 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const eventTypesSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_event_types (
|
||||
event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
event_type TEXT NOT NULL UNIQUE
|
||||
);
|
||||
INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
|
||||
(1, 'm.room.create'),
|
||||
(2, 'm.room.power_levels'),
|
||||
(3, 'm.room.join_rules'),
|
||||
(4, 'm.room.third_party_invite'),
|
||||
(5, 'm.room.member'),
|
||||
(6, 'm.room.redaction'),
|
||||
(7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
|
||||
`
|
||||
|
||||
// Assign a new numeric event type ID.
|
||||
// The usual case is that the event type is not in the database.
|
||||
// In that case the ID will be assigned using the next value from the sequence.
|
||||
// We use `RETURNING` to tell postgres to return the assigned ID.
|
||||
// But it's possible that the type was added in a query that raced with us.
|
||||
// This will result in a conflict on the event_type_unique constraint, in this
|
||||
// case we do nothing. Postgresql won't return a row in that case so we rely on
|
||||
// the caller catching the sql.ErrNoRows error and running a select to get the row.
|
||||
// We could get postgresql to return the row on a conflict by updating the row
|
||||
// but it doesn't seem like a good idea to modify the rows just to make postgresql
|
||||
// return it. Modifying the rows will cause postgres to assign a new tuple for the
|
||||
// row even though the data doesn't change, resulting in unnecessary modifications
|
||||
// to the indexes.
|
||||
const insertEventTypeNIDSQL = `
|
||||
INSERT INTO roomserver_event_types (event_type) VALUES ($1)
|
||||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
|
||||
const insertEventTypeNIDResultSQL = `
|
||||
SELECT event_type_nid FROM roomserver_event_types
|
||||
WHERE rowid = last_insert_rowid();
|
||||
`
|
||||
|
||||
const selectEventTypeNIDSQL = `
|
||||
SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
|
||||
`
|
||||
|
||||
// Bulk lookup from string event type to numeric ID for that event type.
|
||||
// Takes an array of strings as the query parameter.
|
||||
const bulkSelectEventTypeNIDSQL = `
|
||||
SELECT event_type, event_type_nid FROM roomserver_event_types
|
||||
WHERE event_type IN ($1)
|
||||
`
|
||||
|
||||
type eventTypeStatements struct {
|
||||
db *sql.DB
|
||||
insertEventTypeNIDStmt *sql.Stmt
|
||||
insertEventTypeNIDResultStmt *sql.Stmt
|
||||
selectEventTypeNIDStmt *sql.Stmt
|
||||
bulkSelectEventTypeNIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
|
||||
s := &eventTypeStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventTypesSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
|
||||
{&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL},
|
||||
{&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
|
||||
{&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *eventTypeStatements) InsertEventTypeNID(
|
||||
ctx context.Context, txn *sql.Tx, eventType string,
|
||||
) (types.EventTypeNID, error) {
|
||||
var eventTypeNID int64
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt)
|
||||
resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt)
|
||||
_, err := insertStmt.ExecContext(ctx, eventType)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
|
||||
}
|
||||
if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil {
|
||||
return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err)
|
||||
}
|
||||
return types.EventTypeNID(eventTypeNID), err
|
||||
}
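The comment accompanying insertEventTypeNIDSQL above describes the intended caller-side contract: if the insert hits a conflict and surfaces sql.ErrNoRows, the caller falls back to a plain select. A rough sketch of that fallback, assuming the caller holds the table and transaction (illustrative only; the helper name is not part of this change):

package example // illustrative sketch, not part of the change

import (
	"context"
	"database/sql"

	"github.com/matrix-org/dendrite/roomserver/storage/tables"
	"github.com/matrix-org/dendrite/roomserver/types"
)

// assignEventTypeNID inserts the event type and, if another writer raced us
// and the insert returned no row, reads back the NID that writer assigned.
func assignEventTypeNID(ctx context.Context, txn *sql.Tx, tbl tables.EventTypes, eventType string) (types.EventTypeNID, error) {
	nid, err := tbl.InsertEventTypeNID(ctx, txn, eventType)
	if err == nil {
		return nid, nil
	}
	if err == sql.ErrNoRows {
		return tbl.SelectEventTypeNID(ctx, txn, eventType)
	}
	return 0, err
}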
|
||||
|
||||
func (s *eventTypeStatements) SelectEventTypeNID(
|
||||
ctx context.Context, tx *sql.Tx, eventType string,
|
||||
) (types.EventTypeNID, error) {
|
||||
var eventTypeNID int64
|
||||
selectStmt := sqlutil.TxStmt(tx, s.selectEventTypeNIDStmt)
|
||||
err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
|
||||
return types.EventTypeNID(eventTypeNID), err
|
||||
}
|
||||
|
||||
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
||||
ctx context.Context, eventTypes []string,
|
||||
) (map[string]types.EventTypeNID, error) {
|
||||
///////////////
|
||||
iEventTypes := make([]interface{}, len(eventTypes))
|
||||
for k, v := range eventTypes {
|
||||
iEventTypes[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventTypes)), 1)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
///////////////
|
||||
|
||||
rows, err := selectPrep.QueryContext(ctx, iEventTypes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed")
|
||||
|
||||
result := make(map[string]types.EventTypeNID, len(eventTypes))
|
||||
for rows.Next() {
|
||||
var eventType string
|
||||
var eventTypeNID int64
|
||||
if err := rows.Scan(&eventType, &eventTypeNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[eventType] = types.EventTypeNID(eventTypeNID)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
515
roomserver/storage/cosmosdb/events_table.go
Normal file
|
|
@@ -0,0 +1,515 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const eventsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_events (
|
||||
event_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
room_nid INTEGER NOT NULL,
|
||||
event_type_nid INTEGER NOT NULL,
|
||||
event_state_key_nid INTEGER NOT NULL,
|
||||
sent_to_output BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
state_snapshot_nid INTEGER NOT NULL DEFAULT 0,
|
||||
depth INTEGER NOT NULL,
|
||||
event_id TEXT NOT NULL UNIQUE,
|
||||
reference_sha256 BLOB NOT NULL,
|
||||
auth_event_nids TEXT NOT NULL DEFAULT '[]',
|
||||
is_rejected BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
`
|
||||
|
||||
const insertEventSQL = `
|
||||
INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
|
||||
const selectEventSQL = "" +
|
||||
"SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
|
||||
|
||||
// Bulk lookup of events by string ID.
|
||||
// Sort by the numeric IDs for event type and state key.
|
||||
// This means we can use binary search to look up entries by type and state key.
|
||||
const bulkSelectStateEventByIDSQL = "" +
|
||||
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
|
||||
" WHERE event_id IN ($1)" +
|
||||
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||
|
||||
const bulkSelectStateAtEventByIDSQL = "" +
|
||||
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
|
||||
" WHERE event_id IN ($1)"
|
||||
|
||||
const updateEventStateSQL = "" +
|
||||
"UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2"
|
||||
|
||||
const selectEventSentToOutputSQL = "" +
|
||||
"SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1"
|
||||
|
||||
const updateEventSentToOutputSQL = "" +
|
||||
"UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1"
|
||||
|
||||
const selectEventIDSQL = "" +
|
||||
"SELECT event_id FROM roomserver_events WHERE event_nid = $1"
|
||||
|
||||
const bulkSelectStateAtEventAndReferenceSQL = "" +
|
||||
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
|
||||
" FROM roomserver_events WHERE event_nid IN ($1)"
|
||||
|
||||
const bulkSelectEventReferenceSQL = "" +
|
||||
"SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)"
|
||||
|
||||
const bulkSelectEventIDSQL = "" +
|
||||
"SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)"
|
||||
|
||||
const bulkSelectEventNIDSQL = "" +
|
||||
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)"
|
||||
|
||||
const selectMaxEventDepthSQL = "" +
|
||||
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
|
||||
|
||||
const selectRoomNIDsForEventNIDsSQL = "" +
|
||||
"SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)"
|
||||
|
||||
type eventStatements struct {
|
||||
db *sql.DB
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventStmt *sql.Stmt
|
||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
||||
updateEventStateStmt *sql.Stmt
|
||||
selectEventSentToOutputStmt *sql.Stmt
|
||||
updateEventSentToOutputStmt *sql.Stmt
|
||||
selectEventIDStmt *sql.Stmt
|
||||
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
||||
bulkSelectEventReferenceStmt *sql.Stmt
|
||||
bulkSelectEventIDStmt *sql.Stmt
|
||||
bulkSelectEventNIDStmt *sql.Stmt
|
||||
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
|
||||
s := &eventStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertEventStmt, insertEventSQL},
|
||||
{&s.selectEventStmt, selectEventSQL},
|
||||
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
||||
{&s.updateEventStateStmt, updateEventStateSQL},
|
||||
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
|
||||
{&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL},
|
||||
{&s.selectEventIDStmt, selectEventIDSQL},
|
||||
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
|
||||
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
|
||||
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
|
||||
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
|
||||
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *eventStatements) InsertEvent(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
roomNID types.RoomNID,
|
||||
eventTypeNID types.EventTypeNID,
|
||||
eventStateKeyNID types.EventStateKeyNID,
|
||||
eventID string,
|
||||
referenceSHA256 []byte,
|
||||
authEventNIDs []types.EventNID,
|
||||
depth int64,
|
||||
isRejected bool,
|
||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||
// attempt to insert: the last_row_id is the event NID
|
||||
var eventNID int64
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||
result, err := insertStmt.ExecContext(
|
||||
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
modified, err := result.RowsAffected()
|
||||
if modified == 0 && err == nil {
|
||||
return 0, 0, sql.ErrNoRows
|
||||
}
|
||||
eventNID, err = result.LastInsertId()
|
||||
return types.EventNID(eventNID), 0, err
|
||||
}
|
||||
|
||||
func (s *eventStatements) SelectEvent(
|
||||
ctx context.Context, txn *sql.Tx, eventID string,
|
||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||
var eventNID int64
|
||||
var stateNID int64
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectEventStmt)
|
||||
err := selectStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID)
|
||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||
}
|
||||
|
||||
// bulkSelectStateEventByID looks up a list of state events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||
func (s *eventStatements) BulkSelectStateEventByID(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) ([]types.StateEntry, error) {
|
||||
///////////////
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
for k, v := range eventIDs {
|
||||
iEventIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
///////////////
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed")
|
||||
// We know that we will only get as many results as event IDs
|
||||
// because of the unique constraint on event IDs.
|
||||
// So we can allocate an array of the correct size now.
|
||||
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
|
||||
results := make([]types.StateEntry, len(eventIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
if err = rows.Scan(
|
||||
&result.EventTypeNID,
|
||||
&result.EventStateKeyNID,
|
||||
&result.EventNID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if i != len(eventIDs) {
|
||||
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
||||
// We don't know which ones were missing because we don't return the string IDs in the query.
|
||||
// However it should be possible to debug this by replaying queries or entries from the input kafka logs.
|
||||
// If this turns out to be impossible and we do need the debug information here, it would be better
|
||||
// to do it as a separate query rather than slowing down/complicating the internal case.
|
||||
return nil, types.MissingEventError(
|
||||
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)),
|
||||
)
|
||||
}
|
||||
return results, err
|
||||
}
|
||||
|
||||
// bulkSelectStateAtEventByID looks up the state at a list of events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||
func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) ([]types.StateAtEvent, error) {
|
||||
///////////////
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
for k, v := range eventIDs {
|
||||
iEventIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
///////////////
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventByID: rows.close() failed")
|
||||
results := make([]types.StateAtEvent, len(eventIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
if err = rows.Scan(
|
||||
&result.EventTypeNID,
|
||||
&result.EventStateKeyNID,
|
||||
&result.EventNID,
|
||||
&result.BeforeStateSnapshotNID,
|
||||
&result.IsRejected,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.BeforeStateSnapshotNID == 0 {
|
||||
return nil, types.MissingEventError(
|
||||
fmt.Sprintf("storage: missing state for event NID %d", result.EventNID),
|
||||
)
|
||||
}
|
||||
}
|
||||
if i != len(eventIDs) {
|
||||
return nil, types.MissingEventError(
|
||||
fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)),
|
||||
)
|
||||
}
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (s *eventStatements) UpdateEventState(
|
||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt)
|
||||
_, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *eventStatements) SelectEventSentToOutput(
|
||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
|
||||
) (sentToOutput bool, err error) {
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt)
|
||||
err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
|
||||
updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
|
||||
_, err := updateStmt.ExecContext(ctx, int64(eventNID))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *eventStatements) SelectEventID(
|
||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
|
||||
) (eventID string, err error) {
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectEventIDStmt)
|
||||
err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *eventStatements) BulkSelectStateAtEventAndReference(
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
) ([]types.StateAtEventAndReference, error) {
|
||||
///////////////
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//////////////
|
||||
|
||||
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err)
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed")
|
||||
results := make([]types.StateAtEventAndReference, len(eventNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
var (
|
||||
eventTypeNID int64
|
||||
eventStateKeyNID int64
|
||||
eventNID int64
|
||||
stateSnapshotNID int64
|
||||
eventID string
|
||||
eventSHA256 []byte
|
||||
)
|
||||
if err = rows.Scan(
|
||||
&eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := &results[i]
|
||||
result.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
||||
result.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
||||
result.EventNID = types.EventNID(eventNID)
|
||||
result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID)
|
||||
result.EventID = eventID
|
||||
result.EventSHA256 = eventSHA256
|
||||
}
|
||||
if i != len(eventNIDs) {
|
||||
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *eventStatements) BulkSelectEventReference(
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
) ([]gomatrixserverlib.EventReference, error) {
|
||||
///////////////
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
///////////////
|
||||
|
||||
selectStmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed")
|
||||
results := make([]gomatrixserverlib.EventReference, len(eventNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if i != len(eventNIDs) {
|
||||
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// BulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||
///////////////
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
///////////////
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed")
|
||||
results := make(map[types.EventNID]string, len(eventNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
var eventNID int64
|
||||
var eventID string
|
||||
if err = rows.Scan(&eventNID, &eventID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results[types.EventNID(eventNID)] = eventID
|
||||
}
|
||||
if i != len(eventNIDs) {
|
||||
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// BulkSelectEventNID returns a map from string event ID to numeric event ID.
|
||||
// If an event ID is not in the database then it is omitted from the map.
|
||||
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
|
||||
///////////////
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
for k, v := range eventIDs {
|
||||
iEventIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
///////////////
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed")
|
||||
results := make(map[string]types.EventNID, len(eventIDs))
|
||||
for rows.Next() {
|
||||
var eventID string
|
||||
var eventNID int64
|
||||
if err = rows.Scan(&eventID, &eventNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results[eventID] = types.EventNID(eventNID)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) {
|
||||
var result int64
|
||||
iEventIDs := make([]interface{}, len(eventNIDs))
|
||||
for i, v := range eventNIDs {
|
||||
iEventIDs[i] = v
|
||||
}
|
||||
sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
|
||||
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) (map[types.EventNID]types.RoomNID, error) {
|
||||
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
|
||||
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for i, v := range eventNIDs {
|
||||
iEventNIDs[i] = v
|
||||
}
|
||||
rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed")
|
||||
result := make(map[types.EventNID]types.RoomNID)
|
||||
for rows.Next() {
|
||||
var eventNID types.EventNID
|
||||
var roomNID types.RoomNID
|
||||
if err = rows.Scan(&eventNID, &roomNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[eventNID] = roomNID
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func eventNIDsAsArray(eventNIDs []types.EventNID) string {
|
||||
b, _ := json.Marshal(eventNIDs)
|
||||
return string(b)
|
||||
}
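A minimal sketch of the JSON round trip behind eventNIDsAsArray above, assuming it sits next to that helper in this package (where encoding/json and the types package are already imported): the latest_event_nids column stores event NIDs as a JSON array in TEXT form, so the same value can be decoded again by the latest-event readers.

func exampleEventNIDRoundTrip() ([]types.EventNID, error) {
	// Encode a set of event NIDs the same way the rooms table stores them.
	encoded := eventNIDsAsArray([]types.EventNID{101, 102, 103}) // "[101,102,103]"

	// Decode them again, as the latest-event readers do.
	var decoded []types.EventNID
	err := json.Unmarshal([]byte(encoded), &decoded)
	return decoded, err
}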
159
roomserver/storage/cosmosdb/invite_table.go
Normal file
@@ -0,0 +1,159 @@
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const inviteSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_invites (
|
||||
invite_event_id TEXT PRIMARY KEY,
|
||||
room_nid INTEGER NOT NULL,
|
||||
target_nid INTEGER NOT NULL,
|
||||
sender_nid INTEGER NOT NULL DEFAULT 0,
|
||||
retired BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
invite_event_json TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid)
|
||||
WHERE NOT retired;
|
||||
`
|
||||
const insertInviteEventSQL = "" +
|
||||
"INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
|
||||
" sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectInviteActiveForUserInRoomSQL = "" +
|
||||
"SELECT invite_event_id, sender_nid FROM roomserver_invites" +
|
||||
" WHERE target_nid = $1 AND room_nid = $2" +
|
||||
" AND NOT retired"
|
||||
|
||||
// Retire every active invite for a user in a room.
|
||||
// Ideally we'd know which invite events were retired by a given update so we
|
||||
// wouldn't need to remove every active invite.
|
||||
// However, the Matrix protocol doesn't give us a way to reliably identify the
|
||||
// invites that were retired, so we are forced to retire all of them.
|
||||
const updateInviteRetiredSQL = `
|
||||
UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
|
||||
`
|
||||
|
||||
const selectInvitesAboutToRetireSQL = `
|
||||
SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
|
||||
`
|
||||
|
||||
type inviteStatements struct {
|
||||
db *sql.DB
|
||||
insertInviteEventStmt *sql.Stmt
|
||||
selectInviteActiveForUserInRoomStmt *sql.Stmt
|
||||
updateInviteRetiredStmt *sql.Stmt
|
||||
selectInvitesAboutToRetireStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
|
||||
s := &inviteStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(inviteSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertInviteEventStmt, insertInviteEventSQL},
|
||||
{&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL},
|
||||
{&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
|
||||
{&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *inviteStatements) InsertInviteEvent(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
|
||||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||
inviteEventJSON []byte,
|
||||
) (bool, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
|
||||
result, err := stmt.ExecContext(
|
||||
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
count, err = result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count != 0, err
|
||||
}
|
||||
|
||||
func (s *inviteStatements) UpdateInviteRetired(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventIDs []string, err error) {
|
||||
// gather all the event IDs we will retire
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var inviteEventID string
|
||||
if err = rows.Scan(&inviteEventID); err != nil {
|
||||
return
|
||||
}
|
||||
eventIDs = append(eventIDs, inviteEventID)
|
||||
}
|
||||
// now retire the invites
|
||||
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
|
||||
return
|
||||
}
|
||||
|
||||
// SelectInviteActiveForUserInRoom returns the sender state key NIDs and event IDs of the active invites for a user in a room.
|
||||
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
||||
ctx context.Context,
|
||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||
) ([]types.EventStateKeyNID, []string, error) {
|
||||
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
||||
ctx, targetUserNID, roomNID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
|
||||
var result []types.EventStateKeyNID
|
||||
var eventIDs []string
|
||||
for rows.Next() {
|
||||
var eventID string
|
||||
var senderUserNID int64
|
||||
if err := rows.Scan(&eventID, &senderUserNID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
result = append(result, types.EventStateKeyNID(senderUserNID))
|
||||
eventIDs = append(eventIDs, eventID)
|
||||
}
|
||||
return result, eventIDs, nil
|
||||
}
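A minimal usage sketch for the invites table above, assuming it lives in the same cosmosdb package, with db an already-open *sql.DB and txn an in-flight transaction from the roomserver storage layer; the identifiers and NIDs are illustrative. It records an invite and later retires every active invite for that user in the room, as the comment on updateInviteRetiredSQL describes.

func exampleInviteLifecycle(ctx context.Context, db *sql.DB, txn *sql.Tx) error {
	invites, err := NewSqliteInvitesTable(db)
	if err != nil {
		return err
	}
	// Record an invite; the bool reports whether a new row was actually written.
	newInvite, err := invites.InsertInviteEvent(
		ctx, txn, "$invite1:example.org", types.RoomNID(1),
		types.EventStateKeyNID(2) /* target */, types.EventStateKeyNID(3) /* sender */,
		[]byte(`{"type":"m.room.member"}`),
	)
	if err != nil || !newInvite {
		return err
	}
	// Later, e.g. on join or leave, retire all of the user's active invites in the room.
	retiredEventIDs, err := invites.UpdateInviteRetired(ctx, txn, types.RoomNID(1), types.EventStateKeyNID(2))
	_ = retiredEventIDs // the invite event IDs that were retired
	return err
}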
306
roomserver/storage/cosmosdb/membership_table.go
Normal file
@@ -0,0 +1,306 @@
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const membershipSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_membership (
|
||||
room_nid INTEGER NOT NULL,
|
||||
target_nid INTEGER NOT NULL,
|
||||
sender_nid INTEGER NOT NULL DEFAULT 0,
|
||||
membership_nid INTEGER NOT NULL DEFAULT 1,
|
||||
event_nid INTEGER NOT NULL DEFAULT 0,
|
||||
target_local BOOLEAN NOT NULL DEFAULT false,
|
||||
forgotten BOOLEAN NOT NULL DEFAULT false,
|
||||
UNIQUE (room_nid, target_nid)
|
||||
);
|
||||
`
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
// Insert a row into the membership table so that it can be locked by the
|
||||
// SELECT FOR UPDATE
|
||||
const insertMembershipSQL = "" +
|
||||
"INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
|
||||
" VALUES ($1, $2, $3)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectMembershipFromRoomAndTargetSQL = "" +
|
||||
"SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1 AND target_nid = $2"
|
||||
|
||||
const selectMembershipsFromRoomAndMembershipSQL = "" +
|
||||
"SELECT event_nid FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
|
||||
|
||||
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
|
||||
"SELECT event_nid FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1 AND membership_nid = $2" +
|
||||
" AND target_local = true and forgotten = false"
|
||||
|
||||
const selectMembershipsFromRoomSQL = "" +
|
||||
"SELECT event_nid FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1 and forgotten = false"
|
||||
|
||||
const selectLocalMembershipsFromRoomSQL = "" +
|
||||
"SELECT event_nid FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1" +
|
||||
" AND target_local = true and forgotten = false"
|
||||
|
||||
const selectMembershipForUpdateSQL = "" +
|
||||
"SELECT membership_nid FROM roomserver_membership" +
|
||||
" WHERE room_nid = $1 AND target_nid = $2"
|
||||
|
||||
const updateMembershipSQL = "" +
|
||||
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
|
||||
" WHERE room_nid = $5 AND target_nid = $6"
|
||||
|
||||
const updateMembershipForgetRoom = "" +
|
||||
"UPDATE roomserver_membership SET forgotten = $1" +
|
||||
" WHERE room_nid = $2 AND target_nid = $3"
|
||||
|
||||
const selectRoomsWithMembershipSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
||||
|
||||
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||
// joined to. Since this information is used to populate the user directory, we will
|
||||
// only return users that the user would ordinarily be able to see anyway.
|
||||
var selectKnownUsersSQL = "" +
|
||||
"SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
|
||||
"roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||
" WHERE room_nid IN (" +
|
||||
" SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||
") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
|
||||
|
||||
type membershipStatements struct {
|
||||
db *sql.DB
|
||||
insertMembershipStmt *sql.Stmt
|
||||
selectMembershipForUpdateStmt *sql.Stmt
|
||||
selectMembershipFromRoomAndTargetStmt *sql.Stmt
|
||||
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
|
||||
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
|
||||
selectMembershipsFromRoomStmt *sql.Stmt
|
||||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||
selectRoomsWithMembershipStmt *sql.Stmt
|
||||
updateMembershipStmt *sql.Stmt
|
||||
selectKnownUsersStmt *sql.Stmt
|
||||
updateMembershipForgetRoomStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||
s := &membershipStatements{
|
||||
db: db,
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertMembershipStmt, insertMembershipSQL},
|
||||
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
|
||||
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
|
||||
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
|
||||
{&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
|
||||
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
||||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
||||
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *membershipStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(membershipSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) InsertMembership(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
localTarget bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipForUpdate(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (membership tables.MembershipState, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt)
|
||||
err = stmt.QueryRowContext(
|
||||
ctx, roomNID, targetUserNID,
|
||||
).Scan(&membership)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
||||
ctx context.Context,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
|
||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
||||
ctx, roomNID, targetUserNID,
|
||||
).Scan(&membership, &eventNID, &forgotten)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||
ctx context.Context,
|
||||
roomNID types.RoomNID, localOnly bool,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
var selectStmt *sql.Stmt
|
||||
if localOnly {
|
||||
selectStmt = s.selectLocalMembershipsFromRoomStmt
|
||||
} else {
|
||||
selectStmt = s.selectMembershipsFromRoomStmt
|
||||
}
|
||||
rows, err := selectStmt.QueryContext(ctx, roomNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoom: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var eNID types.EventNID
|
||||
if err = rows.Scan(&eNID); err != nil {
|
||||
return
|
||||
}
|
||||
eventNIDs = append(eventNIDs, eNID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||
ctx context.Context,
|
||||
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
var stmt *sql.Stmt
|
||||
if localOnly {
|
||||
stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt
|
||||
} else {
|
||||
stmt = s.selectMembershipsFromRoomAndMembershipStmt
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, membership)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsFromRoomAndMembership: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var eNID types.EventNID
|
||||
if err = rows.Scan(&eNID); err != nil {
|
||||
return
|
||||
}
|
||||
eventNIDs = append(eventNIDs, eNID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) UpdateMembership(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
||||
eventNID types.EventNID, forgotten bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
|
||||
var roomNIDs []types.RoomNID
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
if err := rows.Scan(&roomNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomNIDs = append(roomNIDs, roomNID)
|
||||
}
|
||||
return roomNIDs, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
}
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
||||
result := make(map[types.EventStateKeyNID]int)
|
||||
for rows.Next() {
|
||||
var userID types.EventStateKeyNID
|
||||
var count int
|
||||
if err := rows.Scan(&userID, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[userID] = count
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := []string{}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *membershipStatements) UpdateForgetMembership(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
forget bool,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
|
||||
ctx, forget, roomNID, targetUserNID,
|
||||
)
|
||||
return err
|
||||
}
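A short sketch of the variadic-placeholder pattern used by SelectJoinedUsersSetForRooms above (and by the other bulk selects in this package): the literal "($1)" in the SQL is swapped for a placeholder list sized to the input slice before the query is run. The exact placeholder text is whatever sqlutil.QueryVariadic produces elsewhere in Dendrite; this is only an illustration of the expansion step.

func exampleJoinedUsersQuery(roomNIDs []types.RoomNID) (string, []interface{}) {
	// Build one placeholder per room NID, e.g. "($1, $2, $3)" for three rooms.
	query := strings.Replace(
		selectJoinedUsersSetForRoomsSQL, "($1)",
		sqlutil.QueryVariadic(len(roomNIDs)), 1,
	)
	// database/sql wants []interface{} for the variadic arguments.
	args := make([]interface{}, len(roomNIDs))
	for i, v := range roomNIDs {
		args[i] = v
	}
	return query, args
}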
131
roomserver/storage/cosmosdb/previous_events_table.go
Normal file
@@ -0,0 +1,131 @@
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
// TODO: previous_reference_sha256 was NOT NULL before but it broke sytest because
|
||||
// sytest sends no SHA256 sums in the prev_events references in the soft-fail tests.
|
||||
// In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it
|
||||
// seems to care that it's empty and therefore hits a NOT NULL constraint on insert.
|
||||
// We should really work out what the right thing to do here is.
|
||||
const previousEventSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_previous_events (
|
||||
previous_event_id TEXT NOT NULL,
|
||||
previous_reference_sha256 BLOB,
|
||||
event_nids TEXT NOT NULL,
|
||||
UNIQUE (previous_event_id, previous_reference_sha256)
|
||||
);
|
||||
`
|
||||
|
||||
// Insert an entry into the previous_events table.
|
||||
// If there is already an entry indicating that an event references that previous event then
|
||||
// add the event NID to the list to indicate that this event references that previous event as well.
|
||||
// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
|
||||
// The lock is necessary to avoid data races when checking whether an event is already referenced by another event.
|
||||
const insertPreviousEventSQL = `
|
||||
INSERT OR REPLACE INTO roomserver_previous_events
|
||||
(previous_event_id, previous_reference_sha256, event_nids)
|
||||
VALUES ($1, $2, $3)
|
||||
`
|
||||
|
||||
const selectPreviousEventNIDsSQL = `
|
||||
SELECT event_nids FROM roomserver_previous_events
|
||||
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
`
|
||||
|
||||
// Check if the event is referenced by another event in the table.
|
||||
// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
|
||||
const selectPreviousEventExistsSQL = `
|
||||
SELECT 1 FROM roomserver_previous_events
|
||||
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
`
|
||||
|
||||
type previousEventStatements struct {
|
||||
db *sql.DB
|
||||
insertPreviousEventStmt *sql.Stmt
|
||||
selectPreviousEventNIDsStmt *sql.Stmt
|
||||
selectPreviousEventExistsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
|
||||
s := &previousEventStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(previousEventSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
|
||||
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
|
||||
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *previousEventStatements) InsertPreviousEvent(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
previousEventID string,
|
||||
previousEventReferenceSHA256 []byte,
|
||||
eventNID types.EventNID,
|
||||
) error {
|
||||
var eventNIDs string
|
||||
eventNIDAsString := fmt.Sprintf("%d", eventNID)
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventNIDsStmt)
err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
|
||||
}
|
||||
var nids []string
|
||||
if eventNIDs != "" {
|
||||
nids = strings.Split(eventNIDs, ",")
|
||||
for _, nid := range nids {
|
||||
if nid == eventNIDAsString {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
eventNIDs = strings.Join(append(nids, eventNIDAsString), ",")
|
||||
} else {
|
||||
eventNIDs = eventNIDAsString
|
||||
}
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
||||
_, err = insertStmt.ExecContext(
|
||||
ctx, previousEventID, previousEventReferenceSHA256, eventNIDs,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the event reference exists
|
||||
// Returns sql.ErrNoRows if the event reference doesn't exist.
|
||||
func (s *previousEventStatements) SelectPreviousEventExists(
|
||||
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
|
||||
) error {
|
||||
var ok int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
|
||||
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
|
||||
}
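A small sketch of the event_nids bookkeeping that InsertPreviousEvent performs above: the column holds a comma-separated list of referencing event NIDs, so a second referencing event extends the existing list rather than adding a row. The literal NIDs here are purely illustrative.

func exampleAppendReferencingNID(existing string, eventNID types.EventNID) string {
	nidAsString := fmt.Sprintf("%d", eventNID)
	if existing == "" {
		return nidAsString // first reference, e.g. "12"
	}
	for _, nid := range strings.Split(existing, ",") {
		if nid == nidAsString {
			return existing // this event already references the previous event
		}
	}
	// e.g. "12" becomes "12,45" when event NID 45 also references the previous event.
	return existing + "," + nidAsString
}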
105
roomserver/storage/cosmosdb/published_table.go
Normal file
@@ -0,0 +1,105 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
|
||||
const publishedSchema = `
|
||||
-- Stores which rooms are published in the room directory
|
||||
CREATE TABLE IF NOT EXISTS roomserver_published (
|
||||
-- The room ID of the room
|
||||
room_id TEXT NOT NULL PRIMARY KEY,
|
||||
-- Whether it is published or not
|
||||
published BOOLEAN NOT NULL DEFAULT false
|
||||
);
|
||||
`
|
||||
|
||||
const upsertPublishedSQL = "" +
|
||||
"INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
|
||||
|
||||
const selectAllPublishedSQL = "" +
|
||||
"SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
|
||||
|
||||
const selectPublishedSQL = "" +
|
||||
"SELECT published FROM roomserver_published WHERE room_id = $1"
|
||||
|
||||
type publishedStatements struct {
|
||||
db *sql.DB
|
||||
upsertPublishedStmt *sql.Stmt
|
||||
selectAllPublishedStmt *sql.Stmt
|
||||
selectPublishedStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
|
||||
s := &publishedStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(publishedSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, shared.StatementList{
|
||||
{&s.upsertPublishedStmt, upsertPublishedSQL},
|
||||
{&s.selectAllPublishedStmt, selectAllPublishedSQL},
|
||||
{&s.selectPublishedStmt, selectPublishedSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *publishedStatements) UpsertRoomPublished(
|
||||
ctx context.Context, txn *sql.Tx, roomID string, published bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomID, published)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
) (published bool, err error) {
|
||||
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *publishedStatements) SelectAllPublishedRooms(
|
||||
ctx context.Context, published bool,
|
||||
) ([]string, error) {
|
||||
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAllPublishedStmt: rows.close() failed")
|
||||
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
}
|
||||
return roomIDs, rows.Err()
|
||||
}
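A minimal usage sketch for the published table above, assumed to live in the same package with db an open *sql.DB; passing a nil transaction is assumed to fall back to the prepared statements via sqlutil.TxStmt. It publishes a (hypothetical) room and then lists everything currently published.

func examplePublishRoom(ctx context.Context, db *sql.DB) ([]string, error) {
	table, err := NewSqlitePublishedTable(db)
	if err != nil {
		return nil, err
	}
	// Mark the room as published in the room directory.
	if err := table.UpsertRoomPublished(ctx, nil, "!room:example.org", true); err != nil {
		return nil, err
	}
	// Fetch the room IDs that are currently published.
	return table.SelectAllPublishedRooms(ctx, true)
}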
123
roomserver/storage/cosmosdb/redactions_table.go
Normal file
@@ -0,0 +1,123 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
|
||||
const redactionsSchema = `
|
||||
-- Stores information about the redacted state of events.
|
||||
-- We need to track redactions rather than blindly updating the event JSON table on receipt of a redaction
|
||||
-- because we might receive the redaction BEFORE we receive the event which it redacts (think backfill).
|
||||
CREATE TABLE IF NOT EXISTS roomserver_redactions (
|
||||
redaction_event_id TEXT PRIMARY KEY,
|
||||
redacts_event_id TEXT NOT NULL,
|
||||
-- Initially FALSE, set to TRUE when the redaction has been validated according to rooms v3+ spec
|
||||
-- https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
|
||||
validated BOOLEAN NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertRedactionSQL = "" +
|
||||
"INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
|
||||
" VALUES ($1, $2, $3)"
|
||||
|
||||
const selectRedactionInfoByRedactionEventIDSQL = "" +
|
||||
"SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
|
||||
" WHERE redaction_event_id = $1"
|
||||
|
||||
const selectRedactionInfoByEventBeingRedactedSQL = "" +
|
||||
"SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
|
||||
" WHERE redacts_event_id = $1"
|
||||
|
||||
const markRedactionValidatedSQL = "" +
|
||||
" UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
|
||||
|
||||
type redactionStatements struct {
|
||||
db *sql.DB
|
||||
insertRedactionStmt *sql.Stmt
|
||||
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
|
||||
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
|
||||
markRedactionValidatedStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
|
||||
s := &redactionStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(redactionsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertRedactionStmt, insertRedactionSQL},
|
||||
{&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL},
|
||||
{&s.selectRedactionInfoByEventBeingRedactedStmt, selectRedactionInfoByEventBeingRedactedSQL},
|
||||
{&s.markRedactionValidatedStmt, markRedactionValidatedSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *redactionStatements) InsertRedaction(
|
||||
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
|
||||
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
|
||||
ctx context.Context, txn *sql.Tx, redactionEventID string,
|
||||
) (info *tables.RedactionInfo, err error) {
|
||||
info = &tables.RedactionInfo{}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByRedactionEventIDStmt)
|
||||
err = stmt.QueryRowContext(ctx, redactionEventID).Scan(
|
||||
&info.RedactionEventID, &info.RedactsEventID, &info.Validated,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
info = nil
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
|
||||
ctx context.Context, txn *sql.Tx, eventID string,
|
||||
) (info *tables.RedactionInfo, err error) {
|
||||
info = &tables.RedactionInfo{}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRedactionInfoByEventBeingRedactedStmt)
|
||||
err = stmt.QueryRowContext(ctx, eventID).Scan(
|
||||
&info.RedactionEventID, &info.RedactsEventID, &info.Validated,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
info = nil
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *redactionStatements) MarkRedactionValidated(
|
||||
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
|
||||
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
|
||||
return err
|
||||
}
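A brief sketch of the flow the redactions schema comment above describes, assumed to sit in the same package with illustrative event IDs: a redaction can arrive before the event it redacts, so it is stored unvalidated first and only marked validated once the target event has been seen and checked.

func exampleRedactionFlow(ctx context.Context, db *sql.DB, txn *sql.Tx) error {
	table, err := NewSqliteRedactionsTable(db)
	if err != nil {
		return err
	}
	// Store the redaction even though the redacted event may not have arrived yet.
	err = table.InsertRedaction(ctx, txn, tables.RedactionInfo{
		RedactionEventID: "$redaction:example.org",
		RedactsEventID:   "$original:example.org",
		Validated:        false,
	})
	if err != nil {
		return err
	}
	// Once the target event arrives and the redaction passes the auth checks,
	// flip the validated flag.
	return table.MarkRedactionValidated(ctx, txn, "$redaction:example.org", true)
}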
141
roomserver/storage/cosmosdb/room_aliases_table.go
Normal file
@@ -0,0 +1,141 @@
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
|
||||
const roomAliasesSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_room_aliases (
|
||||
alias TEXT NOT NULL PRIMARY KEY,
|
||||
room_id TEXT NOT NULL,
|
||||
creator_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
|
||||
`
|
||||
|
||||
const insertRoomAliasSQL = `
|
||||
INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
|
||||
`
|
||||
|
||||
const selectRoomIDFromAliasSQL = `
|
||||
SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
|
||||
`
|
||||
|
||||
const selectAliasesFromRoomIDSQL = `
|
||||
SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
|
||||
`
|
||||
|
||||
const selectCreatorIDFromAliasSQL = `
|
||||
SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
|
||||
`
|
||||
|
||||
const deleteRoomAliasSQL = `
|
||||
DELETE FROM roomserver_room_aliases WHERE alias = $1
|
||||
`
|
||||
|
||||
type roomAliasesStatements struct {
|
||||
db *sql.DB
|
||||
insertRoomAliasStmt *sql.Stmt
|
||||
selectRoomIDFromAliasStmt *sql.Stmt
|
||||
selectAliasesFromRoomIDStmt *sql.Stmt
|
||||
selectCreatorIDFromAliasStmt *sql.Stmt
|
||||
deleteRoomAliasStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||
s := &roomAliasesStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(roomAliasesSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, shared.StatementList{
|
||||
{&s.insertRoomAliasStmt, insertRoomAliasSQL},
|
||||
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
|
||||
{&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL},
|
||||
{&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL},
|
||||
{&s.deleteRoomAliasStmt, deleteRoomAliasSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) InsertRoomAlias(
|
||||
ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt)
|
||||
_, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
) (roomID string, err error) {
|
||||
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
) (aliases []string, err error) {
|
||||
aliases = []string{}
|
||||
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var alias string
|
||||
if err = rows.Scan(&alias); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
aliases = append(aliases, alias)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
) (creatorID string, err error) {
|
||||
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) DeleteRoomAlias(
|
||||
ctx context.Context, txn *sql.Tx, alias string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt)
|
||||
_, err := stmt.ExecContext(ctx, alias)
|
||||
return err
|
||||
}
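A minimal usage sketch for the aliases table above, in the same package and with hypothetical identifiers: create an alias, resolve it back to a room ID, and remove it again. Note from the code above that resolving an unknown alias returns an empty room ID rather than an error.

func exampleAliasLifecycle(ctx context.Context, db *sql.DB, txn *sql.Tx) error {
	aliases, err := NewSqliteRoomAliasesTable(db)
	if err != nil {
		return err
	}
	if err := aliases.InsertRoomAlias(ctx, txn, "#room:example.org", "!abc:example.org", "@alice:example.org"); err != nil {
		return err
	}
	roomID, err := aliases.SelectRoomIDFromAlias(ctx, "#room:example.org")
	if err != nil || roomID == "" {
		return err
	}
	return aliases.DeleteRoomAlias(ctx, txn, "#room:example.org")
}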
296
roomserver/storage/cosmosdb/rooms_table.go
Normal file
@@ -0,0 +1,296 @@
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const roomsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_rooms (
|
||||
room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
room_id TEXT NOT NULL UNIQUE,
|
||||
latest_event_nids TEXT NOT NULL DEFAULT '[]',
|
||||
last_event_sent_nid INTEGER NOT NULL DEFAULT 0,
|
||||
state_snapshot_nid INTEGER NOT NULL DEFAULT 0,
|
||||
room_version TEXT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
// Same as insertEventTypeNIDSQL
|
||||
const insertRoomNIDSQL = `
|
||||
INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
|
||||
const selectRoomNIDSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
|
||||
|
||||
const selectLatestEventNIDsSQL = "" +
|
||||
"SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
|
||||
|
||||
const selectLatestEventNIDsForUpdateSQL = "" +
|
||||
"SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
|
||||
|
||||
const updateLatestEventNIDsSQL = "" +
|
||||
"UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
|
||||
|
||||
const selectRoomVersionsForRoomNIDsSQL = "" +
|
||||
"SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||
|
||||
const selectRoomInfoSQL = "" +
|
||||
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||
|
||||
const selectRoomIDsSQL = "" +
|
||||
"SELECT room_id FROM roomserver_rooms"
|
||||
|
||||
const bulkSelectRoomIDsSQL = "" +
|
||||
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||
|
||||
const bulkSelectRoomNIDsSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
|
||||
|
||||
type roomStatements struct {
|
||||
db *sql.DB
|
||||
insertRoomNIDStmt *sql.Stmt
|
||||
selectRoomNIDStmt *sql.Stmt
|
||||
selectLatestEventNIDsStmt *sql.Stmt
|
||||
selectLatestEventNIDsForUpdateStmt *sql.Stmt
|
||||
updateLatestEventNIDsStmt *sql.Stmt
|
||||
//selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||
selectRoomInfoStmt *sql.Stmt
|
||||
selectRoomIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
s := &roomStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(roomsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, shared.StatementList{
|
||||
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
|
||||
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
|
||||
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
|
||||
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
|
||||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
|
||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
}
|
||||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||
var info types.RoomInfo
|
||||
var latestNIDsJSON string
|
||||
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
|
||||
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var latestNIDs []int64
|
||||
if err = json.Unmarshal([]byte(latestNIDsJSON), &latestNIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.IsStub = len(latestNIDs) == 0
|
||||
return &info, err
|
||||
}
|
||||
|
||||
func (s *roomStatements) InsertRoomNID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (roomNID types.RoomNID, err error) {
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
|
||||
_, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
|
||||
}
|
||||
roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("s.SelectRoomNID: %w", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomNID(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (types.RoomNID, error) {
|
||||
var roomNID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomNIDStmt)
|
||||
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
|
||||
return types.RoomNID(roomNID), err
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectLatestEventNIDs(
|
||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
|
||||
) ([]types.EventNID, types.StateSnapshotNID, error) {
|
||||
var eventNIDs []types.EventNID
|
||||
var nidsJSON string
|
||||
var stateSnapshotNID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
|
||||
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &stateSnapshotNID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectLatestEventsNIDsForUpdate(
|
||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
|
||||
) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
|
||||
var eventNIDs []types.EventNID
|
||||
var nidsJSON string
|
||||
var lastEventSentNID int64
|
||||
var stateSnapshotNID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt)
|
||||
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nidsJSON, &lastEventSentNID, &stateSnapshotNID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
if err := json.Unmarshal([]byte(nidsJSON), &eventNIDs); err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) UpdateLatestEventNIDs(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
roomNID types.RoomNID,
|
||||
eventNIDs []types.EventNID,
|
||||
lastEventSentNID types.EventNID,
|
||||
stateSnapshotNID types.StateSnapshotNID,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
eventNIDsAsArray(eventNIDs),
|
||||
int64(lastEventSentNID),
|
||||
int64(stateSnapshotNID),
|
||||
roomNID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||
ctx context.Context, roomNIDs []types.RoomNID,
|
||||
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
|
||||
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
}
|
||||
rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
|
||||
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
var roomVersion gomatrixserverlib.RoomVersion
|
||||
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[roomNID] = roomVersion
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
}
|
||||
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
}
|
||||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
||||
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||
for i, v := range roomIDs {
|
||||
iRoomIDs[i] = v
|
||||
}
|
||||
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
||||
var roomNIDs []types.RoomNID
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
if err = rows.Scan(&roomNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomNIDs = append(roomNIDs, roomNID)
|
||||
}
|
||||
return roomNIDs, nil
|
||||
}
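A minimal sketch of how the rooms table above is typically exercised, in the same package; the room ID and version string are illustrative. It allocates a room NID on first sight of a room ID (the insert is idempotent thanks to ON CONFLICT DO NOTHING), then reads the room info back, where IsStub reports that no events have been stored yet because latest_event_nids is still '[]'.

func exampleRoomInsertAndInfo(ctx context.Context, db *sql.DB, txn *sql.Tx) (*types.RoomInfo, error) {
	rooms, err := NewSqliteRoomsTable(db)
	if err != nil {
		return nil, err
	}
	// Insert the room, or do nothing if it already exists; either way the NID is looked up.
	if _, err := rooms.InsertRoomNID(ctx, txn, "!abc:example.org", gomatrixserverlib.RoomVersion("6")); err != nil {
		return nil, err
	}
	// SelectRoomInfo returns nil without an error when the room is unknown.
	return rooms.SelectRoomInfo(ctx, "!abc:example.org")
}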
289
roomserver/storage/cosmosdb/state_block_table.go
Normal file
@@ -0,0 +1,289 @@
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
const stateDataSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_state_block (
|
||||
state_block_nid INTEGER NOT NULL,
|
||||
event_type_nid INTEGER NOT NULL,
|
||||
event_state_key_nid INTEGER NOT NULL,
|
||||
event_nid INTEGER NOT NULL,
|
||||
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
|
||||
);
|
||||
`
|
||||
|
||||
const insertStateDataSQL = "" +
|
||||
"INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
|
||||
" VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectNextStateBlockNIDSQL = `
|
||||
SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
|
||||
`
|
||||
|
||||
// Bulk state lookup by numeric state block ID.
|
||||
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
|
||||
// This means that all the entries for a given state_block_nid will appear
|
||||
// together in the list and those entries will be sorted by event_type_nid
|
||||
// and event_state_key_nid. This property makes it easier to merge two
|
||||
// state data blocks together.
|
||||
const bulkSelectStateBlockEntriesSQL = "" +
|
||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||
|
||||
// Bulk state lookup by numeric state block ID.
|
||||
// Filters the rows in each block to the requested types and state keys.
|
||||
// We would like to restrict to particular (event type, state key) pairs but we are
|
||||
// restricted by the query language to pull the cross product of a list
|
||||
// of types and a list of state_keys. So we have to filter the result in the
|
||||
// application to restrict it to the list of event types and state keys we
|
||||
// actually wanted.
|
||||
const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
||||
" AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
|
||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||
|
||||
type stateBlockStatements struct {
|
||||
db *sql.DB
|
||||
insertStateDataStmt *sql.Stmt
|
||||
selectNextStateBlockNIDStmt *sql.Stmt
|
||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||
s := &stateBlockStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(stateDataSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertStateDataStmt, insertStateDataSQL},
|
||||
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
|
||||
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
|
||||
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkInsertStateData(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
entries []types.StateEntry,
|
||||
) (types.StateBlockNID, error) {
|
||||
if len(entries) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var stateBlockNID types.StateBlockNID
|
||||
err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
|
||||
ctx,
|
||||
int64(stateBlockNID),
|
||||
int64(entry.EventTypeNID),
|
||||
int64(entry.EventStateKeyNID),
|
||||
int64(entry.EventNID),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return stateBlockNID, err
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
||||
) ([]types.StateEntryList, error) {
|
||||
nids := make([]interface{}, len(stateBlockNIDs))
|
||||
for k, v := range stateBlockNIDs {
|
||||
nids[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
|
||||
|
||||
results := make([]types.StateEntryList, len(stateBlockNIDs))
|
||||
// current is a pointer to the StateEntryList to append the state entries to.
|
||||
var current *types.StateEntryList
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
var (
|
||||
stateBlockNID int64
|
||||
eventTypeNID int64
|
||||
eventStateKeyNID int64
|
||||
eventNID int64
|
||||
entry types.StateEntry
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
||||
entry.EventNID = types.EventNID(eventNID)
|
||||
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
||||
// The state entry row is for a different state data block to the current one.
|
||||
// So we start appending to the next entry in the list.
|
||||
current = &results[i]
|
||||
current.StateBlockNID = types.StateBlockNID(stateBlockNID)
|
||||
i++
|
||||
}
|
||||
current.StateEntries = append(current.StateEntries, entry)
|
||||
}
|
||||
if i != len(nids) {
|
||||
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids))
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
|
||||
ctx context.Context,
|
||||
stateBlockNIDs []types.StateBlockNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntryList, error) {
|
||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
|
||||
sort.Sort(tuples)
|
||||
|
||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
|
||||
sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
|
||||
sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
|
||||
|
||||
var params []interface{}
|
||||
for _, val := range stateBlockNIDs {
|
||||
params = append(params, int64(val))
|
||||
}
|
||||
for _, val := range eventTypeNIDArray {
|
||||
params = append(params, val)
|
||||
}
|
||||
for _, val := range eventStateKeyNIDArray {
|
||||
params = append(params, val)
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(
|
||||
ctx,
|
||||
sqlStatement,
|
||||
params...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
|
||||
|
||||
var results []types.StateEntryList
|
||||
var current types.StateEntryList
|
||||
for rows.Next() {
|
||||
var (
|
||||
stateBlockNID int64
|
||||
eventTypeNID int64
|
||||
eventStateKeyNID int64
|
||||
eventNID int64
|
||||
entry types.StateEntry
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
||||
entry.EventNID = types.EventNID(eventNID)
|
||||
|
||||
// We can use binary search here because we sorted the tuples earlier
|
||||
if !tuples.contains(entry.StateKeyTuple) {
|
||||
// The select will return the cross product of types and state keys.
|
||||
// So we need to check whether the entry's (type, state key) tuple was actually requested.
|
||||
continue
|
||||
}
|
||||
|
||||
if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
||||
// The state entry row is for a different state data block to the current one.
|
||||
// So we append the current entry to the results and start adding to a new one.
|
||||
// The first time through the loop current will be empty.
|
||||
if current.StateEntries != nil {
|
||||
results = append(results, current)
|
||||
}
|
||||
current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
|
||||
}
|
||||
current.StateEntries = append(current.StateEntries, entry)
|
||||
}
|
||||
// Add the last entry to the list if it is not empty.
|
||||
if current.StateEntries != nil {
|
||||
results = append(results, current)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
type stateKeyTupleSorter []types.StateKeyTuple
|
||||
|
||||
func (s stateKeyTupleSorter) Len() int { return len(s) }
|
||||
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
||||
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
||||
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
|
||||
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
||||
return i < len(s) && s[i] == value
|
||||
}
|
||||
|
||||
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
||||
// Assumes that the list is sorted.
|
||||
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
|
||||
eventTypeNIDs = make([]int64, len(s))
|
||||
eventStateKeyNIDs = make([]int64, len(s))
|
||||
for i := range s {
|
||||
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
||||
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
||||
}
|
||||
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
||||
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
||||
return
|
||||
}
|
||||
|
||||
type int64Sorter []int64
|
||||
|
||||
func (s int64Sorter) Len() int { return len(s) }
|
||||
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
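Editor's note: BulkInsertStateData above allocates the next block number with the IFNULL(MAX(state_block_nid), 0) + 1 query and then inserts one row per state entry under that number. A rough, self-contained illustration of that allocation pattern against an in-memory SQLite database (the table here is a simplified stand-in, not the real schema):

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3"
)

func main() {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Simplified stand-in for roomserver_state_block.
	if _, err = db.Exec(`CREATE TABLE state_block (state_block_nid INTEGER NOT NULL, event_nid INTEGER NOT NULL)`); err != nil {
		log.Fatal(err)
	}

	insertBlock := func(eventNIDs []int64) (int64, error) {
		// Allocate the next block NID: 1 for an empty table, MAX+1 otherwise.
		var nid int64
		if err := db.QueryRow(`SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM state_block`).Scan(&nid); err != nil {
			return 0, err
		}
		// One insert per entry, all sharing the freshly allocated block NID.
		for _, ev := range eventNIDs {
			if _, err := db.Exec(`INSERT INTO state_block (state_block_nid, event_nid) VALUES ($1, $2)`, nid, ev); err != nil {
				return 0, err
			}
		}
		return nid, nil
	}

	first, err := insertBlock([]int64{101, 102})
	if err != nil {
		log.Fatal(err)
	}
	second, err := insertBlock([]int64{103})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(first, second) // 1 2
}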
86
roomserver/storage/cosmosdb/state_block_table_test.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
func TestStateKeyTupleSorter(t *testing.T) {
|
||||
input := stateKeyTupleSorter{
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
}
|
||||
want := []types.StateKeyTuple{
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
}
|
||||
doNotWant := []types.StateKeyTuple{
|
||||
{EventTypeNID: 0, EventStateKeyNID: 0},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 3},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 3, EventStateKeyNID: 1},
|
||||
}
|
||||
wantTypeNIDs := []int64{1, 2}
|
||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
||||
|
||||
// Sort the input and check it's in the right order.
|
||||
sort.Sort(input)
|
||||
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
|
||||
|
||||
for i := range want {
|
||||
if input[i] != want[i] {
|
||||
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
||||
}
|
||||
|
||||
if !input.contains(want[i]) {
|
||||
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
||||
}
|
||||
}
|
||||
|
||||
for i := range doNotWant {
|
||||
if input.contains(doNotWant[i]) {
|
||||
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
|
||||
for i := range wantTypeNIDs {
|
||||
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
||||
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
||||
}
|
||||
|
||||
for i := range wantStateKeyNIDs {
|
||||
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
||||
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
126
roomserver/storage/cosmosdb/state_snapshot_table.go
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const stateSnapshotSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
|
||||
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
room_nid INTEGER NOT NULL,
|
||||
state_block_nids TEXT NOT NULL DEFAULT '[]'
|
||||
);
|
||||
`
|
||||
|
||||
const insertStateSQL = `
|
||||
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
|
||||
VALUES ($1, $2);`
|
||||
|
||||
// Bulk state data NID lookup.
|
||||
// Sorting by state_snapshot_nid means we can use binary search over the result
|
||||
// to lookup the state data NIDs for a state snapshot NID.
|
||||
const bulkSelectStateBlockNIDsSQL = "" +
|
||||
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
|
||||
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
|
||||
|
||||
type stateSnapshotStatements struct {
|
||||
db *sql.DB
|
||||
insertStateStmt *sql.Stmt
|
||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
s := &stateSnapshotStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(stateSnapshotSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertStateStmt, insertStateSQL},
|
||||
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *stateSnapshotStatements) InsertState(
|
||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
|
||||
) (stateNID types.StateSnapshotNID, err error) {
|
||||
stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
lastRowID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stateNID = types.StateSnapshotNID(lastRowID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||
) ([]types.StateBlockNIDList, error) {
|
||||
nids := make([]interface{}, len(stateNIDs))
|
||||
for k, v := range stateNIDs {
|
||||
nids[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
|
||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
var stateBlockNIDsJSON string
|
||||
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if i != len(stateNIDs) {
|
||||
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
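Editor's note: the snapshot table stores its list of state block NIDs as a JSON array in the state_block_nids TEXT column, which is why InsertState marshals the slice and BulkSelectStateBlockNIDs unmarshals it again. A small round-trip sketch of that encoding (StateBlockNID here stands in for types.StateBlockNID, assumed to be an integer-based type):

package main

import (
	"encoding/json"
	"fmt"
	"log"
)

// StateBlockNID stands in for types.StateBlockNID.
type StateBlockNID int64

func main() {
	in := []StateBlockNID{11, 42, 97}

	// What InsertState writes into the state_block_nids column.
	encoded, err := json.Marshal(in)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(encoded)) // [11,42,97]

	// What BulkSelectStateBlockNIDs does with the column value it scans back.
	var out []StateBlockNID
	if err := json.Unmarshal(encoded, &out); err != nil {
		log.Fatal(err)
	}
	fmt.Println(out) // [11 42 97]
}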
189
roomserver/storage/cosmosdb/storage.go
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A Database is used to store room events and stream offsets.
|
||||
type Database struct {
|
||||
shared.Database
|
||||
}
|
||||
|
||||
// Open opens the roomserver database for the CosmosDB storage implementation.
|
||||
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
var d Database
|
||||
var db *sql.DB
|
||||
var err error
|
||||
if db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//db.Exec("PRAGMA journal_mode=WAL;")
|
||||
//db.Exec("PRAGMA read_uncommitted = true;")
|
||||
|
||||
// FIXME: We are leaking connections somewhere. Setting this to 2 will eventually
|
||||
// cause the roomserver to be unresponsive to new events because something will
|
||||
// acquire the global mutex and never unlock it because it is waiting for a connection
|
||||
// which it will never obtain.
|
||||
db.SetMaxOpenConns(20)
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
||||
ms := membershipStatements{}
|
||||
if err := ms.execSchema(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadAddForgottenColumn(m)
|
||||
if err := m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.prepare(db, cache); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
// nolint: gocyclo
|
||||
func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
|
||||
var err error
|
||||
eventStateKeys, err := NewSqliteEventStateKeysTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
eventTypes, err := NewSqliteEventTypesTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
eventJSON, err := NewSqliteEventJSONTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
events, err := NewSqliteEventsTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rooms, err := NewSqliteRoomsTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transactions, err := NewSqliteTransactionsTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateBlock, err := NewSqliteStateBlockTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateSnapshot, err := NewSqliteStateSnapshotTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prevEvents, err := NewSqlitePrevEventsTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roomAliases, err := NewSqliteRoomAliasesTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
invites, err := NewSqliteInvitesTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
membership, err := NewSqliteMembershipTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
published, err := NewSqlitePublishedTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
redactions, err := NewSqliteRedactionsTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: db,
|
||||
Cache: cache,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
EventsTable: events,
|
||||
EventTypesTable: eventTypes,
|
||||
EventStateKeysTable: eventStateKeys,
|
||||
EventJSONTable: eventJSON,
|
||||
RoomsTable: rooms,
|
||||
TransactionsTable: transactions,
|
||||
StateBlockTable: stateBlock,
|
||||
StateSnapshotTable: stateSnapshot,
|
||||
PrevEventsTable: prevEvents,
|
||||
RoomAliasesTable: roomAliases,
|
||||
InvitesTable: invites,
|
||||
MembershipTable: membership,
|
||||
PublishedTable: published,
|
||||
RedactionsTable: redactions,
|
||||
GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) SupportsConcurrentRoomInputs() bool {
|
||||
// This isn't supported in SQLite mode yet because of issues with
|
||||
// database locks.
|
||||
// TODO: Look at this again - the problem is probably to do with
|
||||
// the membership updaters and latest events updaters.
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *Database) GetLatestEventsForUpdate(
|
||||
ctx context.Context, roomInfo types.RoomInfo,
|
||||
) (*shared.LatestEventsUpdater, error) {
|
||||
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
||||
// multiple write transactions on sqlite. The code will perform additional
|
||||
// write transactions independent of this one which will consistently cause
|
||||
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
||||
// same DB anyway, and we only execute updates sequentially, the only worries
|
||||
// are for rolling back when things go wrong. (atomicity)
|
||||
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo)
|
||||
}
|
||||
|
||||
func (d *Database) MembershipUpdater(
|
||||
ctx context.Context, roomID, targetUserID string,
|
||||
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (*shared.MembershipUpdater, error) {
|
||||
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
||||
// multiple write transactions on sqlite. The code will perform additional
|
||||
// write transactions independent of this one which will consistently cause
|
||||
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
||||
// same DB anyway, and we only execute updates sequentially, the only worries
|
||||
// are for rolling back when things go wrong. (atomicity)
|
||||
return shared.NewMembershipUpdater(ctx, &d.Database, nil, roomID, targetUserID, targetLocal, roomVersion)
|
||||
}
|
||||
91
roomserver/storage/cosmosdb/transactions_table.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
|
||||
const transactionsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_transactions (
|
||||
transaction_id TEXT NOT NULL,
|
||||
session_id INTEGER NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
PRIMARY KEY (transaction_id, session_id, user_id)
|
||||
);
|
||||
`
|
||||
const insertTransactionSQL = `
|
||||
INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`
|
||||
|
||||
const selectTransactionEventIDSQL = `
|
||||
SELECT event_id FROM roomserver_transactions
|
||||
WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
|
||||
`
|
||||
|
||||
type transactionStatements struct {
|
||||
db *sql.DB
|
||||
insertTransactionStmt *sql.Stmt
|
||||
selectTransactionEventIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
|
||||
s := &transactionStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(transactionsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertTransactionStmt, insertTransactionSQL},
|
||||
{&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *transactionStatements) InsertTransaction(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
transactionID string,
|
||||
sessionID int64,
|
||||
userID string,
|
||||
eventID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, transactionID, sessionID, userID, eventID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *transactionStatements) SelectTransactionEventID(
|
||||
ctx context.Context,
|
||||
transactionID string,
|
||||
sessionID int64,
|
||||
userID string,
|
||||
) (eventID string, err error) {
|
||||
err = s.selectTransactionEventIDStmt.QueryRowContext(
|
||||
ctx, transactionID, sessionID, userID,
|
||||
).Scan(&eventID)
|
||||
return
|
||||
}
|
||||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/cosmosdb"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
|
@ -28,6 +29,8 @@ import (
|
|||
// Open opens a database connection.
|
||||
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsCosmosDB():
|
||||
return cosmosdb.Open(dbProperties, cache)
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.Open(dbProperties, cache)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ import (
|
|||
// NewPublicRoomsServerDatabase opens a database connection.
|
||||
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsCosmosDB():
|
||||
return nil, fmt.Errorf("can't use CosmosDB implementation")
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.Open(dbProperties, cache)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
|
|
|
|||
|
|
@ -139,6 +139,11 @@ func (d DataSource) IsPostgres() bool {
|
|||
return !d.IsSQLite()
|
||||
}
|
||||
|
||||
func (d DataSource) IsCosmosDB() bool {
|
||||
// NOTE: assumes CosmosDB connection strings always begin with the "cosmosdb:" prefix.
|
||||
return strings.HasPrefix(string(d), "cosmosdb:")
|
||||
}
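Editor's note: with this helper, the storage Open functions added in this commit can branch on the connection string scheme. A rough illustration of the detection; the exact cosmosdb: connection string format shown below is a hypothetical example, only the prefix matters to the check:

package main

import (
	"fmt"
	"strings"
)

type DataSource string

// IsCosmosDB mirrors the check added above: it only looks at the prefix.
func (d DataSource) IsCosmosDB() bool {
	return strings.HasPrefix(string(d), "cosmosdb:")
}

func main() {
	for _, ds := range []DataSource{
		"cosmosdb:https://example.documents.azure.com:443/", // hypothetical format
		"file:roomserver.db",
		"postgres://dendrite@localhost/dendrite",
	} {
		fmt.Printf("%-55s cosmosdb=%v\n", ds, ds.IsCosmosDB())
	}
}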
|
||||
|
||||
// A Topic in kafka.
|
||||
type Topic string
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package kafka
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/naffka"
|
||||
|
|
@ -46,6 +47,10 @@ func setupNaffka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) {
|
|||
if naffkaInstance != nil {
|
||||
return naffkaInstance, naffkaInstance
|
||||
}
|
||||
if cfg.Database.ConnectionString.IsCosmosDB() {
|
||||
cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString)
|
||||
}
|
||||
|
||||
naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString))
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panic("Failed to setup naffka database")
|
||||
|
|
|
|||
|
|
@ -59,6 +59,9 @@ type DB struct {
|
|||
|
||||
// NewDatabase loads the database for msc2836
|
||||
func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||
if dbOpts.ConnectionString.IsCosmosDB() {
|
||||
return newCosmosDBDatabase(dbOpts)
|
||||
}
|
||||
if dbOpts.ConnectionString.IsPostgres() {
|
||||
return newPostgresDatabase(dbOpts)
|
||||
}
|
||||
|
|
@ -225,6 +228,86 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
|||
return &d, nil
|
||||
}
|
||||
|
||||
func newCosmosDBDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||
d := DB{
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
}
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbOpts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = d.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS msc2836_edges (
|
||||
parent_event_id TEXT NOT NULL,
|
||||
child_event_id TEXT NOT NULL,
|
||||
rel_type TEXT NOT NULL,
|
||||
parent_room_id TEXT NOT NULL,
|
||||
parent_servers TEXT NOT NULL,
|
||||
UNIQUE (parent_event_id, child_event_id, rel_type)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS msc2836_nodes (
|
||||
event_id TEXT PRIMARY KEY NOT NULL,
|
||||
origin_server_ts BIGINT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
unsigned_children_count BIGINT NOT NULL,
|
||||
unsigned_children_hash TEXT NOT NULL,
|
||||
explored SMALLINT NOT NULL
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.insertEdgeStmt, err = d.db.Prepare(`
|
||||
INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.insertNodeStmt, err = d.db.Prepare(`
|
||||
INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored)
|
||||
VALUES($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT DO NOTHING
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectChildrenQuery := `
|
||||
SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges
|
||||
LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id
|
||||
WHERE parent_event_id = $1 AND rel_type = $2
|
||||
ORDER BY origin_server_ts
|
||||
`
|
||||
if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.selectParentForChildStmt, err = d.db.Prepare(`
|
||||
SELECT parent_event_id, parent_room_id FROM msc2836_edges
|
||||
WHERE child_event_id = $1 AND rel_type = $2
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.updateChildMetadataStmt, err = d.db.Prepare(`
|
||||
UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.selectChildMetadataStmt, err = d.db.Prepare(`
|
||||
SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.updateChildMetadataExploredStmt, err = d.db.Prepare(`
|
||||
UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error {
|
||||
parent, child, relType := parentChildEventIDs(ev)
|
||||
if parent == "" || child == "" {
|
||||
|
|
|
|||
|
|
@ -47,6 +47,9 @@ type DB struct {
|
|||
|
||||
// NewDatabase loads the database for msc2836
|
||||
func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||
if dbOpts.ConnectionString.IsCosmosDB() {
|
||||
return newCosmosDBDatabase(dbOpts)
|
||||
}
|
||||
if dbOpts.ConnectionString.IsPostgres() {
|
||||
return newPostgresDatabase(dbOpts)
|
||||
}
|
||||
|
|
@ -133,6 +136,46 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
|||
return &d, err
|
||||
}
|
||||
|
||||
func newCosmosDBDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
|
||||
d := DB{
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
}
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbOpts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = d.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS msc2946_edges (
|
||||
room_version TEXT NOT NULL,
|
||||
-- the room ID of the event, the source of the arrow
|
||||
source_room_id TEXT NOT NULL,
|
||||
-- the target room ID, the arrow destination
|
||||
dest_room_id TEXT NOT NULL,
|
||||
-- the kind of relation, either child or parent (1,2)
|
||||
rel_type SMALLINT NOT NULL,
|
||||
event_json TEXT NOT NULL,
|
||||
UNIQUE (source_room_id, dest_room_id, rel_type)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.insertEdgeStmt, err = d.db.Prepare(`
|
||||
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.selectEdgesStmt, err = d.db.Prepare(`
|
||||
SELECT room_version, event_json FROM msc2946_edges
|
||||
WHERE source_room_id = $1 OR dest_room_id = $2
|
||||
`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, err
|
||||
}
|
||||
|
||||
func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error {
|
||||
target := SpaceTarget(he)
|
||||
if target == "" {
|
||||
|
|
|
|||
101
signingkeyserver/storage/cosmosdb/keydb.go
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"context"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
|
||||
// the public keys for other matrix servers.
|
||||
type Database struct {
|
||||
writer sqlutil.Writer
|
||||
statements serverKeyStatements
|
||||
}
|
||||
|
||||
// NewDatabase prepares a new key database.
|
||||
// It creates the necessary tables if they don't already exist.
|
||||
// It prepares all the SQL statements that it will use.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func NewDatabase(
|
||||
dbProperties *config.DatabaseOptions,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
serverKey ed25519.PublicKey,
|
||||
serverKeyID gomatrixserverlib.KeyID,
|
||||
) (*Database, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := &Database{
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
}
|
||||
err = d.statements.prepare(db, d.writer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// FetcherName implements KeyFetcher
|
||||
func (d Database) FetcherName() string {
|
||||
return "SqliteKeyDatabase"
|
||||
}
|
||||
|
||||
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
||||
func (d *Database) FetchKeys(
|
||||
ctx context.Context,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||
return d.statements.bulkSelectServerKeys(ctx, requests)
|
||||
}
|
||||
|
||||
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
||||
func (d *Database) StoreKeys(
|
||||
ctx context.Context,
|
||||
keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||
) error {
|
||||
// TODO: Inserting all the keys within a single transaction may
|
||||
// be more efficient since the transaction overhead can be quite
|
||||
// high for a single insert statement.
|
||||
var lastErr error
|
||||
for request, keys := range keyMap {
|
||||
if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil {
|
||||
// Rather than returning immediately on error we try to insert the
|
||||
// remaining keys.
|
||||
// Since we are inserting the keys outside of a transaction it is
|
||||
// possible for some of the inserts to succeed even though some
|
||||
// of the inserts have failed.
|
||||
// Ensuring that we always insert all the keys we can means that
|
||||
// this behaviour won't depend on the iteration order of the map.
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
159
signingkeyserver/storage/cosmosdb/server_key_table.go
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const serverKeysSchema = `
|
||||
-- A cache of signing keys downloaded from remote servers.
|
||||
CREATE TABLE IF NOT EXISTS keydb_server_keys (
|
||||
-- The name of the matrix server the key is for.
|
||||
server_name TEXT NOT NULL,
|
||||
-- The ID of the server key.
|
||||
server_key_id TEXT NOT NULL,
|
||||
-- Combined server name and key ID separated by the ASCII unit separator
|
||||
-- to make it easier to run bulk queries.
|
||||
server_name_and_key_id TEXT NOT NULL,
|
||||
-- When the key is valid until as a millisecond timestamp.
|
||||
-- 0 if this is an expired key (in which case expired_ts will be non-zero)
|
||||
valid_until_ts BIGINT NOT NULL,
|
||||
-- When the key expired as a millisecond timestamp.
|
||||
-- 0 if this is an active key (in which case valid_until_ts will be non-zero)
|
||||
expired_ts BIGINT NOT NULL,
|
||||
-- The base64-encoded public key.
|
||||
server_key TEXT NOT NULL,
|
||||
UNIQUE (server_name, server_key_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
|
||||
`
|
||||
|
||||
const bulkSelectServerKeysSQL = "" +
|
||||
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
|
||||
" server_key FROM keydb_server_keys" +
|
||||
" WHERE server_name_and_key_id IN ($1)"
|
||||
|
||||
const upsertServerKeysSQL = "" +
|
||||
"INSERT INTO keydb_server_keys (server_name, server_key_id," +
|
||||
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT (server_name, server_key_id)" +
|
||||
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
|
||||
|
||||
type serverKeyStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
bulkSelectServerKeysStmt *sql.Stmt
|
||||
upsertServerKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *serverKeyStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||
s.db = db
|
||||
s.writer = writer
|
||||
_, err = db.Exec(serverKeysSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerKeysSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *serverKeyStatements) bulkSelectServerKeys(
|
||||
ctx context.Context,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||
nameAndKeyIDs := make([]string, 0, len(requests))
|
||||
for request := range requests {
|
||||
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
||||
}
|
||||
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
|
||||
iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
|
||||
for i, v := range nameAndKeyIDs {
|
||||
iKeyIDs[i] = v
|
||||
}
|
||||
|
||||
err := sqlutil.RunLimitedVariablesQuery(
|
||||
ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
|
||||
func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var serverName string
|
||||
var keyID string
|
||||
var key string
|
||||
var validUntilTS int64
|
||||
var expiredTS int64
|
||||
if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
|
||||
return fmt.Errorf("bulkSelectServerKeys: %v", err)
|
||||
}
|
||||
r := gomatrixserverlib.PublicKeyLookupRequest{
|
||||
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||
KeyID: gomatrixserverlib.KeyID(keyID),
|
||||
}
|
||||
vk := gomatrixserverlib.VerifyKey{}
|
||||
err := vk.Key.Decode(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bulkSelectServerKeys: %v", err)
|
||||
}
|
||||
results[r] = gomatrixserverlib.PublicKeyLookupResult{
|
||||
VerifyKey: vk,
|
||||
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
|
||||
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
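Editor's note: bulkSelectServerKeys hands the composite key IDs to sqlutil.RunLimitedVariablesQuery so the IN (...) list never exceeds SQLite's bind-variable limit (sqlutil.SQLite3MaxVariables, 999 by default in SQLite). A standalone sketch of that batching idea; the real helper also expands the placeholders and runs the query once per batch, so this is only the chunking part under that assumption:

package main

import "fmt"

// runLimitedVariablesQuery sketches the batching behaviour assumed above:
// split the variables into chunks no larger than limit and hand each chunk
// to the callback, which would expand the IN (...) clause and run the query.
func runLimitedVariablesQuery(variables []interface{}, limit int, fn func(batch []interface{}) error) error {
	for start := 0; start < len(variables); start += limit {
		end := start + limit
		if end > len(variables) {
			end = len(variables)
		}
		if err := fn(variables[start:end]); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	keys := make([]interface{}, 2500)
	for i := range keys {
		keys[i] = fmt.Sprintf("key-%d", i)
	}
	_ = runLimitedVariablesQuery(keys, 999, func(batch []interface{}) error {
		fmt.Println("would query", len(batch), "server_name_and_key_id values")
		return nil
	})
	// Batches of 999, 999 and 502 values.
}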
|
||||
|
||||
func (s *serverKeyStatements) upsertServerKeys(
|
||||
ctx context.Context,
|
||||
request gomatrixserverlib.PublicKeyLookupRequest,
|
||||
key gomatrixserverlib.PublicKeyLookupResult,
|
||||
) error {
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
string(request.ServerName),
|
||||
string(request.KeyID),
|
||||
nameAndKeyID(request),
|
||||
key.ValidUntilTS,
|
||||
key.ExpiredTS,
|
||||
key.Key.Encode(),
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
|
||||
return string(request.ServerName) + "\x1F" + string(request.KeyID)
|
||||
}
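Editor's note: the server_name_and_key_id column pairs the server name and key ID with the ASCII unit separator (0x1F), so a bulk lookup only needs a single IN clause. A tiny illustration of how those composite keys are built; the split helper is purely illustrative, the table itself never needs it:

package main

import (
	"fmt"
	"strings"
)

// nameAndKeyID mirrors the helper above: server name and key ID joined by
// the ASCII unit separator (0x1F), which is not expected to appear in either value.
func nameAndKeyID(serverName, keyID string) string {
	return serverName + "\x1F" + keyID
}

func main() {
	composite := nameAndKeyID("matrix.org", "ed25519:auto")
	fmt.Printf("%q\n", composite) // "matrix.org\x1fed25519:auto"

	// Illustrative only: splitting the composite key back into its parts.
	parts := strings.SplitN(composite, "\x1F", 2)
	fmt.Println(parts[0], parts[1]) // matrix.org ed25519:auto
}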
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"golang.org/x/crypto/ed25519"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/signingkeyserver/storage/cosmosdb"
|
||||
"github.com/matrix-org/dendrite/signingkeyserver/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/signingkeyserver/storage/sqlite3"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -35,6 +36,8 @@ func NewDatabase(
|
|||
serverKeyID gomatrixserverlib.KeyID,
|
||||
) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsCosmosDB():
|
||||
return cosmosdb.NewDatabase(dbProperties, serverName, serverKey, serverKeyID)
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, serverKey, serverKeyID)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
|
|
|
|||
156
syncapi/storage/cosmosdb/account_data_table.go
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
CREATE TABLE IF NOT EXISTS syncapi_account_data_type (
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
UNIQUE (user_id, room_id, type)
|
||||
);
|
||||
`
|
||||
|
||||
const insertAccountDataSQL = "" +
|
||||
"INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" +
|
||||
" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
|
||||
" SET id = $5"
|
||||
|
||||
const selectAccountDataInRangeSQL = "" +
|
||||
"SELECT room_id, type FROM syncapi_account_data_type" +
|
||||
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
|
||||
" ORDER BY id ASC"
|
||||
|
||||
const selectMaxAccountDataIDSQL = "" +
|
||||
"SELECT MAX(id) FROM syncapi_account_data_type"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
selectMaxAccountDataIDStmt *sql.Stmt
|
||||
selectAccountDataInRangeStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
|
||||
s := &accountDataStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
}
|
||||
_, err := db.Exec(accountDataSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) InsertAccountData(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
userID, roomID, dataType string,
|
||||
) (pos types.StreamPosition, err error) {
|
||||
pos, err = s.streamIDStatements.nextAccountDataID(ctx, txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType, pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountDataInRange(
|
||||
ctx context.Context,
|
||||
userID string,
|
||||
r types.Range,
|
||||
accountDataFilterPart *gomatrixserverlib.EventFilter,
|
||||
) (data map[string][]string, err error) {
|
||||
data = make(map[string][]string)
|
||||
|
||||
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
|
||||
|
||||
var entries int
|
||||
|
||||
for rows.Next() {
|
||||
var dataType string
|
||||
var roomID string
|
||||
|
||||
if err = rows.Scan(&roomID, &dataType); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// check if we should add this by looking at the filter.
|
||||
// It would be nice if we could do this in SQL-land, but the mix of variadic
|
||||
// and positional parameters makes the query annoyingly hard to do; it's easier
|
||||
// and clearer to do it in Go-land. If there are no filters for [not]types then
|
||||
// this gets skipped.
|
||||
for _, includeType := range accountDataFilterPart.Types {
|
||||
if includeType != dataType { // TODO: wildcard support
|
||||
continue
|
||||
}
|
||||
}
|
||||
for _, excludeType := range accountDataFilterPart.NotTypes {
|
||||
if excludeType == dataType { // TODO: wildcard support
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(data[roomID]) > 0 {
|
||||
data[roomID] = append(data[roomID], dataType)
|
||||
} else {
|
||||
data[roomID] = []string{dataType}
|
||||
}
|
||||
entries++
|
||||
if entries >= accountDataFilterPart.Limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectMaxAccountDataID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (id int64, err error) {
|
||||
var nullableID sql.NullInt64
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
|
||||
if nullableID.Valid {
|
||||
id = nullableID.Int64
|
||||
}
|
||||
return
|
||||
}
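Editor's note: insertAccountDataSQL relies on an upsert, so re-inserting the same (user_id, room_id, type) row does not add a second row; it just moves the existing row to the newly allocated stream id, meaning each account data type appears only once per room, at its latest position. A small in-memory SQLite sketch of that behaviour, with a simplified table and hypothetical stream ids:

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3"
)

func main() {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	if _, err = db.Exec(`CREATE TABLE account_data (
		id INTEGER PRIMARY KEY,
		user_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL,
		UNIQUE (user_id, room_id, type))`); err != nil {
		log.Fatal(err)
	}

	upsert := `INSERT INTO account_data (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)
		ON CONFLICT (user_id, room_id, type) DO UPDATE SET id = $5`

	// First write lands at stream position 1, the rewrite moves it to 2.
	if _, err = db.Exec(upsert, 1, "@alice:example.org", "!room:example.org", "m.fully_read", 1); err != nil {
		log.Fatal(err)
	}
	if _, err = db.Exec(upsert, 2, "@alice:example.org", "!room:example.org", "m.fully_read", 2); err != nil {
		log.Fatal(err)
	}

	var rows, id int64
	if err = db.QueryRow(`SELECT COUNT(*), MAX(id) FROM account_data`).Scan(&rows, &id); err != nil {
		log.Fatal(err)
	}
	fmt.Println("rows:", rows, "latest stream id:", id) // rows: 1 latest stream id: 2
}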
|
||||
125
syncapi/storage/cosmosdb/backwards_extremities_table.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
// Copyright 2018 New Vector Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
)
|
||||
|
||||
const backwardExtremitiesSchema = `
|
||||
-- Stores output room events received from the roomserver.
|
||||
CREATE TABLE IF NOT EXISTS syncapi_backward_extremities (
|
||||
-- The 'room_id' key for the event.
|
||||
room_id TEXT NOT NULL,
|
||||
-- The event ID for the last known event. This is the backwards extremity.
|
||||
event_id TEXT NOT NULL,
|
||||
-- The prev_events for the last known event. This is used to update extremities.
|
||||
prev_event_id TEXT NOT NULL,
|
||||
PRIMARY KEY(room_id, event_id, prev_event_id)
|
||||
);
|
||||
`
|
||||
|
||||
const insertBackwardExtremitySQL = "" +
|
||||
"INSERT INTO syncapi_backward_extremities (room_id, event_id, prev_event_id)" +
|
||||
" VALUES ($1, $2, $3)" +
|
||||
" ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING"
|
||||
|
||||
const selectBackwardExtremitiesForRoomSQL = "" +
|
||||
"SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1"
|
||||
|
||||
const deleteBackwardExtremitySQL = "" +
|
||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
|
||||
|
||||
const deleteBackwardExtremitiesForRoomSQL = "" +
|
||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
|
||||
|
||||
type backwardExtremitiesStatements struct {
|
||||
db *sql.DB
|
||||
insertBackwardExtremityStmt *sql.Stmt
|
||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
deleteBackwardExtremityStmt *sql.Stmt
|
||||
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
||||
s := &backwardExtremitiesStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(backwardExtremitiesSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
||||
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
||||
ctx context.Context, roomID string,
|
||||
) (bwExtrems map[string][]string, err error) {
|
||||
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed")
|
||||
|
||||
bwExtrems = make(map[string][]string)
|
||||
for rows.Next() {
|
||||
var eID string
|
||||
var prevEventID string
|
||||
if err = rows.Scan(&eID, &prevEventID); err != nil {
|
||||
return
|
||||
}
|
||||
bwExtrems[eID] = append(bwExtrems[eID], prevEventID)
|
||||
}
|
||||
|
||||
return bwExtrems, rows.Err()
|
||||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
||||
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
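For orientation, a sketch of how the statements above are typically paired (an editor's illustration, assuming it lives in this package; recordBackwardExtremity is a hypothetical helper): an event whose prev_event is not known locally becomes a backwards extremity, and once the missing event is later backfilled, the extremity rows pointing at it are removed.
func recordBackwardExtremity(
ctx context.Context, txn *sql.Tx, table tables.BackwardsExtremities,
roomID, eventID, prevEventID string, havePrevEvent bool,
) error {
if !havePrevEvent {
// We don't have prev_event_id yet, so eventID marks the edge of our known history.
return table.InsertsBackwardExtremity(ctx, txn, roomID, eventID, prevEventID)
}
// The previously missing event has arrived; it is no longer a backwards extremity.
return table.DeleteBackwardExtremity(ctx, txn, roomID, prevEventID)
}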
324
syncapi/storage/cosmosdb/current_room_state_table.go
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const currentRoomStateSchema = `
|
||||
-- Stores the current room state for every room.
|
||||
CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
contains_url BOOL NOT NULL DEFAULT false,
|
||||
state_key TEXT NOT NULL,
|
||||
headered_event_json TEXT NOT NULL,
|
||||
membership TEXT,
|
||||
added_at BIGINT,
|
||||
UNIQUE (room_id, type, state_key)
|
||||
);
|
||||
-- for event deletion
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
|
||||
-- for querying membership states of users
|
||||
-- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
|
||||
-- for querying state by event IDs
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
|
||||
`
|
||||
|
||||
const upsertRoomStateSQL = "" +
|
||||
"INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
|
||||
" ON CONFLICT (room_id, type, state_key)" +
|
||||
" DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9"
|
||||
|
||||
const deleteRoomStateByEventIDSQL = "" +
|
||||
"DELETE FROM syncapi_current_room_state WHERE event_id = $1"
|
||||
|
||||
const DeleteRoomStateForRoomSQL = "" +
|
||||
"DELETE FROM syncapi_current_room_state WHERE event_id = $1"
|
||||
|
||||
const selectRoomIDsWithMembershipSQL = "" +
|
||||
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
|
||||
|
||||
const selectCurrentStateSQL = "" +
|
||||
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"
|
||||
// WHEN, ORDER BY and LIMIT will be added by prepareWithFilter
|
||||
|
||||
const selectJoinedUsersSQL = "" +
|
||||
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
|
||||
|
||||
const selectStateEventSQL = "" +
|
||||
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
||||
|
||||
const selectEventsWithEventIDsSQL = "" +
|
||||
// TODO: The session_id and transaction_id blanks are here because otherwise
|
||||
// the rowsToStreamEvents expects there to be exactly six columns. We need to
|
||||
// figure out if these really need to be in the DB, and if so, we need a
|
||||
// better permanent fix for this. - neilalexander, 2 Jan 2020
|
||||
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
|
||||
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
|
||||
|
||||
type currentRoomStateStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
DeleteRoomStateForRoomStmt *sql.Stmt
|
||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||
selectJoinedUsersStmt *sql.Stmt
|
||||
selectStateEventStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
|
||||
s := &currentRoomStateStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
}
|
||||
_, err := db.Exec(currentRoomStateSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.DeleteRoomStateForRoomStmt, err = db.Prepare(DeleteRoomStateForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// JoinedMemberLists returns a map of room ID to a list of joined user IDs.
|
||||
func (s *currentRoomStateStatements) SelectJoinedUsers(
|
||||
ctx context.Context,
|
||||
) (map[string][]string, error) {
|
||||
rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed")
|
||||
|
||||
result := make(map[string][]string)
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
var userID string
|
||||
if err := rows.Scan(&roomID, &userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users := result[roomID]
|
||||
users = append(users, userID)
|
||||
result[roomID] = users
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
||||
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
userID string,
|
||||
membership string, // nolint: unparam
|
||||
) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt)
|
||||
rows, err := stmt.QueryContext(ctx, userID, membership)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithMembership: rows.close() failed")
|
||||
|
||||
var result []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err := rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, roomID)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CurrentState returns all the current state events for the given room.
|
||||
func (s *currentRoomStateStatements) SelectCurrentState(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
stateFilter *gomatrixserverlib.StateFilter,
|
||||
excludeEventIDs []string,
|
||||
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
stmt, params, err := prepareWithFilters(
|
||||
s.db, txn, selectCurrentStateSQL,
|
||||
[]interface{}{
|
||||
roomID,
|
||||
},
|
||||
stateFilter.Senders, stateFilter.NotSenders,
|
||||
stateFilter.Types, stateFilter.NotTypes,
|
||||
excludeEventIDs, stateFilter.Limit, FilterOrderNone,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
}
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCurrentState: rows.close() failed")
|
||||
|
||||
return rowsToEvents(rows)
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
|
||||
ctx context.Context, txn *sql.Tx, eventID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
|
||||
_, err := stmt.ExecContext(ctx, eventID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) DeleteRoomStateForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) UpsertRoomState(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition,
|
||||
) error {
|
||||
// Parse content as JSON and search for an "url" key
|
||||
containsURL := false
|
||||
var content map[string]interface{}
|
||||
if json.Unmarshal(event.Content(), &content) == nil {
|
||||
// Set containsURL to true if url is present
|
||||
_, containsURL = content["url"]
|
||||
}
|
||||
|
||||
headeredJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// upsert state event
|
||||
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
|
||||
_, err = stmt.ExecContext(
|
||||
ctx,
|
||||
event.RoomID(),
|
||||
event.EventID(),
|
||||
event.Type(),
|
||||
event.Sender(),
|
||||
containsURL,
|
||||
*event.StateKey(),
|
||||
headeredJSON,
|
||||
membership,
|
||||
addedAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func minOfInts(a, b int) int {
|
||||
if a <= b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) ([]types.StreamEvent, error) {
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
for k, v := range eventIDs {
|
||||
iEventIDs[k] = v
|
||||
}
|
||||
res := make([]types.StreamEvent, 0, len(eventIDs))
|
||||
var start int
|
||||
for start < len(eventIDs) {
|
||||
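// Bind at most 999 IDs per query: 999 is SQLite's default bound-parameter
// limit (SQLITE_MAX_VARIABLE_NUMBER), which this table code keeps to.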
n := minOfInts(len(eventIDs)-start, 999)
|
||||
query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1)
|
||||
rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
start = start + n
|
||||
events, err := rowsToStreamEvents(rows)
|
||||
internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res = append(res, events...)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
result := []*gomatrixserverlib.HeaderedEvent{}
|
||||
for rows.Next() {
|
||||
var eventID string
|
||||
var eventBytes []byte
|
||||
if err := rows.Scan(&eventID, &eventBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: Handle redacted events
|
||||
var ev gomatrixserverlib.HeaderedEvent
|
||||
if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, &ev)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectStateEvent(
|
||||
ctx context.Context, roomID, evType, stateKey string,
|
||||
) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||
stmt := s.selectStateEventStmt
|
||||
var res []byte
|
||||
err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ev gomatrixserverlib.HeaderedEvent
|
||||
if err = json.Unmarshal(res, &ev); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ev, err
|
||||
}
|
||||
140
syncapi/storage/cosmosdb/filter_table.go
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
// Copyright 2017 Jan Christian Grünhage
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const filterSchema = `
|
||||
-- Stores data about filters
|
||||
CREATE TABLE IF NOT EXISTS syncapi_filter (
|
||||
-- The filter
|
||||
filter TEXT NOT NULL,
|
||||
-- The ID
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The localpart of the Matrix user ID associated to this filter
|
||||
localpart TEXT NOT NULL,
|
||||
|
||||
UNIQUE (id, localpart)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart);
|
||||
`
|
||||
|
||||
const selectFilterSQL = "" +
|
||||
"SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2"
|
||||
|
||||
const selectFilterIDByContentSQL = "" +
|
||||
"SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2"
|
||||
|
||||
const insertFilterSQL = "" +
|
||||
"INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
|
||||
|
||||
type filterStatements struct {
|
||||
db *sql.DB
|
||||
selectFilterStmt *sql.Stmt
|
||||
selectFilterIDByContentStmt *sql.Stmt
|
||||
insertFilterStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
|
||||
_, err := db.Exec(filterSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &filterStatements{
|
||||
db: db,
|
||||
}
|
||||
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *filterStatements) SelectFilter(
|
||||
ctx context.Context, localpart string, filterID string,
|
||||
) (*gomatrixserverlib.Filter, error) {
|
||||
// Retrieve filter from database (stored as canonical JSON)
|
||||
var filterData []byte
|
||||
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unmarshal JSON into Filter struct
|
||||
filter := gomatrixserverlib.DefaultFilter()
|
||||
if err = json.Unmarshal(filterData, &filter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &filter, nil
|
||||
}
|
||||
|
||||
func (s *filterStatements) InsertFilter(
|
||||
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
|
||||
) (filterID string, err error) {
|
||||
var existingFilterID string
|
||||
|
||||
// Serialise json
|
||||
filterJSON, err := json.Marshal(filter)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Remove whitespaces and sort JSON data
|
||||
// needed to prevent from inserting the same filter multiple times
|
||||
filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Check if filter already exists in the database using its localpart and content
|
||||
//
|
||||
// This can result in a race condition when two clients try to insert the
|
||||
// same filter and localpart at the same time, however this is not a
|
||||
// problem as both calls will result in the same filterID
|
||||
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
||||
localpart, filterJSON).Scan(&existingFilterID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
// If it does, return the existing ID
|
||||
if existingFilterID != "" {
|
||||
return existingFilterID, nil
|
||||
}
|
||||
|
||||
// Otherwise insert the filter and return the new ID
|
||||
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rowid, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
filterID = fmt.Sprintf("%d", rowid)
|
||||
return
|
||||
}
|
||||
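The canonical-JSON step above is what makes the duplicate check work: two byte-wise different encodings of the same filter canonicalise to the same value and therefore hit the same stored row. A small sketch of that property (an editor's illustration; filtersAreEquivalent is a hypothetical helper):
func filtersAreEquivalent(a, b []byte) (bool, error) {
ca, err := gomatrixserverlib.CanonicalJSON(a)
if err != nil {
return false, err
}
cb, err := gomatrixserverlib.CanonicalJSON(b)
if err != nil {
return false, err
}
return string(ca) == string(cb), nil
}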
82
syncapi/storage/cosmosdb/filtering.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
type FilterOrder int
|
||||
|
||||
const (
|
||||
FilterOrderNone = iota
|
||||
FilterOrderAsc
|
||||
FilterOrderDesc
|
||||
)
|
||||
|
||||
// prepareWithFilters returns a prepared statement with the
|
||||
// relevant filters included. It also includes an []interface{}
|
||||
// list of all the relevant parameters to pass straight to
|
||||
// QueryContext, QueryRowContext etc.
|
||||
// We don't take the filter object directly here because the
|
||||
// fields might come from either a StateFilter or an EventFilter,
|
||||
// and it's easier just to have the caller extract the relevant
|
||||
// parts.
|
||||
func prepareWithFilters(
|
||||
db *sql.DB, txn *sql.Tx, query string, params []interface{},
|
||||
senders, notsenders, types, nottypes []string, excludeEventIDs []string,
|
||||
limit int, order FilterOrder,
|
||||
) (*sql.Stmt, []interface{}, error) {
|
||||
offset := len(params)
|
||||
if count := len(senders); count > 0 {
|
||||
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range senders {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
}
|
||||
if count := len(notsenders); count > 0 {
|
||||
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range notsenders {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
}
|
||||
if count := len(types); count > 0 {
|
||||
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range types {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
}
|
||||
if count := len(nottypes); count > 0 {
|
||||
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range nottypes {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
}
|
||||
if count := len(excludeEventIDs); count > 0 {
|
||||
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range excludeEventIDs {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
}
|
||||
switch order {
|
||||
case FilterOrderAsc:
|
||||
query += " ORDER BY id ASC"
|
||||
case FilterOrderDesc:
|
||||
query += " ORDER BY id DESC"
|
||||
}
|
||||
query += fmt.Sprintf(" LIMIT $%d", offset+1)
|
||||
params = append(params, limit)
|
||||
|
||||
var stmt *sql.Stmt
|
||||
var err error
|
||||
if txn != nil {
|
||||
stmt, err = txn.Prepare(query)
|
||||
} else {
|
||||
stmt, err = db.Prepare(query)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("s.db.Prepare: %w", err)
|
||||
}
|
||||
return stmt, params, nil
|
||||
}
|
||||
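A usage sketch for prepareWithFilters (an editor's illustration; selectExampleEventIDsSQL and selectFilteredEventIDs are hypothetical, and the usual syncapi imports such as context, internal, types and gomatrixserverlib are assumed to be in scope): the fixed positional parameters go into the initial params slice, the filter slices extend both the query and the params, and the returned statement is then queried with the returned params verbatim.
const selectExampleEventIDsSQL = "SELECT event_id FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2"

func selectFilteredEventIDs(
ctx context.Context, db *sql.DB, roomID string, from types.StreamPosition,
filter *gomatrixserverlib.RoomEventFilter,
) ([]string, error) {
stmt, params, err := prepareWithFilters(
db, nil, selectExampleEventIDsSQL,
[]interface{}{roomID, from},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
nil, filter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, err
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectFilteredEventIDs: rows.close() failed")
var eventIDs []string
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
eventIDs = append(eventIDs, id)
}
return eventIDs, rows.Err()
}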
185
syncapi/storage/cosmosdb/invites_table.go
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const inviteEventsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS syncapi_invite_events (
|
||||
id INTEGER PRIMARY KEY,
|
||||
event_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
target_user_id TEXT NOT NULL,
|
||||
headered_event_json TEXT NOT NULL,
|
||||
deleted BOOL NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id);
|
||||
CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id);
|
||||
`
|
||||
|
||||
const insertInviteEventSQL = "" +
|
||||
"INSERT INTO syncapi_invite_events" +
|
||||
" (id, room_id, event_id, target_user_id, headered_event_json, deleted)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, false)"
|
||||
|
||||
const deleteInviteEventSQL = "" +
|
||||
"UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2"
|
||||
|
||||
const selectInviteEventsInRangeSQL = "" +
|
||||
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
|
||||
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
|
||||
" ORDER BY id DESC"
|
||||
|
||||
const selectMaxInviteIDSQL = "" +
|
||||
"SELECT MAX(id) FROM syncapi_invite_events"
|
||||
|
||||
type inviteEventsStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
insertInviteEventStmt *sql.Stmt
|
||||
selectInviteEventsInRangeStmt *sql.Stmt
|
||||
deleteInviteEventStmt *sql.Stmt
|
||||
selectMaxInviteIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
|
||||
s := &inviteEventsStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
}
|
||||
_, err := db.Exec(inviteEventsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *inviteEventsStatements) InsertInviteEvent(
|
||||
ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent,
|
||||
) (streamPos types.StreamPosition, err error) {
|
||||
streamPos, err = s.streamIDStatements.nextInviteID(ctx, txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var headeredJSON []byte
|
||||
headeredJSON, err = json.Marshal(inviteEvent)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
|
||||
_, err = stmt.ExecContext(
|
||||
ctx,
|
||||
streamPos,
|
||||
inviteEvent.RoomID(),
|
||||
inviteEvent.EventID(),
|
||||
*inviteEvent.StateKey(),
|
||||
headeredJSON,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *inviteEventsStatements) DeleteInviteEvent(
|
||||
ctx context.Context, txn *sql.Tx, inviteEventID string,
|
||||
) (types.StreamPosition, error) {
|
||||
streamPos, err := s.streamIDStatements.nextInviteID(ctx, txn)
|
||||
if err != nil {
|
||||
return streamPos, err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteInviteEventStmt)
|
||||
_, err = stmt.ExecContext(ctx, streamPos, inviteEventID)
|
||||
return streamPos, err
|
||||
}
|
||||
|
||||
// selectInviteEventsInRange returns a map of room ID to invite event for the
|
||||
// active invites for the target user ID in the supplied range.
|
||||
func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
||||
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
|
||||
) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
|
||||
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
|
||||
result := map[string]*gomatrixserverlib.HeaderedEvent{}
|
||||
retired := map[string]*gomatrixserverlib.HeaderedEvent{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
roomID string
|
||||
eventJSON []byte
|
||||
deleted bool
|
||||
)
|
||||
if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// If we have already seen this room, the earlier row has a higher stream
// position and takes priority (the query is ORDER BY id DESC), so drop this row.
|
||||
_, isRetired := retired[roomID]
|
||||
_, isInvited := result[roomID]
|
||||
if isRetired || isInvited {
|
||||
continue
|
||||
}
|
||||
|
||||
var event *gomatrixserverlib.HeaderedEvent
|
||||
if err := json.Unmarshal(eventJSON, &event); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if deleted {
|
||||
retired[roomID] = event
|
||||
} else {
|
||||
result[roomID] = event
|
||||
}
|
||||
}
|
||||
return result, retired, nil
|
||||
}
|
||||
|
||||
func (s *inviteEventsStatements) SelectMaxInviteID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (id int64, err error) {
|
||||
var nullableID sql.NullInt64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt)
|
||||
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
|
||||
if nullableID.Valid {
|
||||
id = nullableID.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
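A sketch of how the two maps returned by SelectInviteEventsInRange are usually consumed when building a sync response (an editor's illustration; splitInvites is a hypothetical helper): rooms in the first map still have a pending invite at the top of the range, while rooms in the second map had their invite answered or revoked within the range.
func splitInvites(
ctx context.Context, s *inviteEventsStatements, userID string, r types.Range,
) (pending, retired []string, err error) {
active, gone, err := s.SelectInviteEventsInRange(ctx, nil, userID, r)
if err != nil {
return nil, nil, err
}
for roomID := range active {
pending = append(pending, roomID)
}
for roomID := range gone {
retired = append(retired, roomID)
}
return pending, retired, nil
}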
119
syncapi/storage/cosmosdb/memberships_table.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// The memberships table is designed to track the last time that
|
||||
// the user was in a given state. This allows us to find out the
|
||||
// most recent time that a user was invited to, joined or left
|
||||
// a room, either by choice or otherwise. This is important for
|
||||
// building history visibility.
|
||||
|
||||
const membershipsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS syncapi_memberships (
|
||||
-- The 'room_id' key for the state event.
|
||||
room_id TEXT NOT NULL,
|
||||
-- The user ID (the state_key of the membership event)
|
||||
user_id TEXT NOT NULL,
|
||||
-- The status of the membership
|
||||
membership TEXT NOT NULL,
|
||||
-- The event ID that last changed the membership
|
||||
event_id TEXT NOT NULL,
|
||||
-- The stream position of the change
|
||||
stream_pos BIGINT NOT NULL,
|
||||
-- The topological position of the change in the room
|
||||
topological_pos BIGINT NOT NULL,
|
||||
-- Unique index
|
||||
UNIQUE (room_id, user_id, membership)
|
||||
);
|
||||
`
|
||||
|
||||
const upsertMembershipSQL = "" +
|
||||
"INSERT INTO syncapi_memberships (room_id, user_id, membership, event_id, stream_pos, topological_pos)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT (room_id, user_id, membership)" +
|
||||
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
|
||||
|
||||
const selectMembershipSQL = "" +
|
||||
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
|
||||
" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
|
||||
" ORDER BY stream_pos DESC" +
|
||||
" LIMIT 1"
|
||||
|
||||
type membershipsStatements struct {
|
||||
db *sql.DB
|
||||
upsertMembershipStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
||||
s := &membershipsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(membershipsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) UpsertMembership(
|
||||
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent,
|
||||
streamPos, topologicalPos types.StreamPosition,
|
||||
) error {
|
||||
membership, err := event.Membership()
|
||||
if err != nil {
|
||||
return fmt.Errorf("event.Membership: %w", err)
|
||||
}
|
||||
_, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext(
|
||||
ctx,
|
||||
event.RoomID(),
|
||||
*event.StateKey(),
|
||||
membership,
|
||||
event.EventID(),
|
||||
streamPos,
|
||||
topologicalPos,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectMembership(
|
||||
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
|
||||
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
|
||||
params := []interface{}{roomID, userID}
|
||||
for _, membership := range memberships {
|
||||
params = append(params, membership)
|
||||
}
|
||||
orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
|
||||
stmt, err := s.db.Prepare(orig)
|
||||
if err != nil {
|
||||
return "", 0, 0, err
|
||||
}
|
||||
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
|
||||
return
|
||||
}
|
||||
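A short usage sketch (an editor's illustration; lastJoinOrInvite is a hypothetical helper): history visibility checks typically want the most recent point at which a user was joined or invited to a room, which is exactly what the IN (...) expansion plus ORDER BY stream_pos DESC LIMIT 1 above gives.
func lastJoinOrInvite(
ctx context.Context, s *membershipsStatements, roomID, userID string,
) (eventID string, streamPos, topoPos types.StreamPosition, err error) {
return s.SelectMembership(ctx, nil, roomID, userID, []string{"join", "invite"})
}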
477
syncapi/storage/cosmosdb/output_room_events_table.go
Normal file
|
|
@ -0,0 +1,477 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const outputRoomEventsSchema = `
|
||||
-- Stores output room events received from the roomserver.
|
||||
CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
event_id TEXT NOT NULL UNIQUE,
|
||||
room_id TEXT NOT NULL,
|
||||
headered_event_json TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
contains_url BOOL NOT NULL,
|
||||
add_state_ids TEXT, -- JSON encoded string array
|
||||
remove_state_ids TEXT, -- JSON encoded string array
|
||||
session_id BIGINT,
|
||||
transaction_id TEXT,
|
||||
exclude_from_sync BOOL NOT NULL DEFAULT FALSE
|
||||
);
|
||||
`
|
||||
|
||||
const insertEventSQL = "" +
|
||||
"INSERT INTO syncapi_output_room_events (" +
|
||||
"id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
|
||||
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
|
||||
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
|
||||
|
||||
const selectEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
|
||||
|
||||
const selectRecentEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
" WHERE room_id = $1 AND id > $2 AND id <= $3"
|
||||
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||
|
||||
const selectRecentEventsForSyncSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
|
||||
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||
|
||||
const selectEarlyEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
" WHERE room_id = $1 AND id > $2 AND id <= $3"
|
||||
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||
|
||||
const selectMaxEventIDSQL = "" +
|
||||
"SELECT MAX(id) FROM syncapi_output_room_events"
|
||||
|
||||
const updateEventJSONSQL = "" +
|
||||
"UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
|
||||
|
||||
const selectStateInRangeSQL = "" +
|
||||
"SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
|
||||
" FROM syncapi_output_room_events" +
|
||||
" WHERE (id > $1 AND id <= $2)" +
|
||||
" AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
|
||||
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||
|
||||
const deleteEventsForRoomSQL = "" +
|
||||
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
|
||||
|
||||
type outputRoomEventsStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventsStmt *sql.Stmt
|
||||
selectMaxEventIDStmt *sql.Stmt
|
||||
updateEventJSONStmt *sql.Stmt
|
||||
deleteEventsForRoomStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
|
||||
s := &outputRoomEventsStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
}
|
||||
_, err := db.Exec(outputRoomEventsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteEventsForRoomStmt, err = db.Prepare(deleteEventsForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
|
||||
headeredJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
|
||||
return err
|
||||
}
|
||||
|
||||
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
|
||||
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
|
||||
// two positions, only the most recent state is returned.
|
||||
func (s *outputRoomEventsStatements) SelectStateInRange(
|
||||
ctx context.Context, txn *sql.Tx, r types.Range,
|
||||
stateFilter *gomatrixserverlib.StateFilter,
|
||||
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
|
||||
stmt, params, err := prepareWithFilters(
|
||||
s.db, txn, selectStateInRangeSQL,
|
||||
[]interface{}{
|
||||
r.Low(), r.High(),
|
||||
},
|
||||
stateFilter.Senders, stateFilter.NotSenders,
|
||||
stateFilter.Types, stateFilter.NotTypes,
|
||||
nil, stateFilter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
}
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close() // nolint: errcheck
|
||||
// Fetch all the state change events for all rooms between the two positions then loop each event and:
|
||||
// - Keep a cache of the event by ID (99% of state change events are for the event itself)
|
||||
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes
|
||||
// For each room, map cumulative event IDs to events and return. This may need to do a batch SELECT based on event ID
|
||||
// if they aren't in the event ID cache. We don't handle state deletion yet.
|
||||
eventIDToEvent := make(map[string]types.StreamEvent)
|
||||
|
||||
// RoomID => A set (map[string]bool) of state event IDs which are between the two positions
|
||||
stateNeeded := make(map[string]map[string]bool)
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
streamPos types.StreamPosition
|
||||
eventBytes []byte
|
||||
excludeFromSync bool
|
||||
addIDsJSON string
|
||||
delIDsJSON string
|
||||
)
|
||||
if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDsJSON, &delIDsJSON); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
addIDs, delIDs, err := unmarshalStateIDs(addIDsJSON, delIDsJSON)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Sanity check for deleted state and whine if we see it. We don't need to do anything
|
||||
// since it'll just mark the event as not being needed.
|
||||
if len(addIDs) < len(delIDs) {
|
||||
log.WithFields(log.Fields{
|
||||
"since": r.From,
|
||||
"current": r.To,
|
||||
"adds": addIDsJSON,
|
||||
"dels": delIDsJSON,
|
||||
}).Warn("StateBetween: ignoring deleted state")
|
||||
}
|
||||
|
||||
// TODO: Handle redacted events
|
||||
var ev gomatrixserverlib.HeaderedEvent
|
||||
if err := json.Unmarshal(eventBytes, &ev); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
needSet := stateNeeded[ev.RoomID()]
|
||||
if needSet == nil { // make set if required
|
||||
needSet = make(map[string]bool)
|
||||
}
|
||||
for _, id := range delIDs {
|
||||
needSet[id] = false
|
||||
}
|
||||
for _, id := range addIDs {
|
||||
needSet[id] = true
|
||||
}
|
||||
stateNeeded[ev.RoomID()] = needSet
|
||||
|
||||
eventIDToEvent[ev.EventID()] = types.StreamEvent{
|
||||
HeaderedEvent: &ev,
|
||||
StreamPosition: streamPos,
|
||||
ExcludeFromSync: excludeFromSync,
|
||||
}
|
||||
}
|
||||
|
||||
return stateNeeded, eventIDToEvent, nil
|
||||
}
|
||||
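To make the shape of the two return values concrete, a sketch of how a caller combines them (an editor's illustration; resolveStateDeltas is a hypothetical helper): per room, keep only the event IDs still marked as needed, take them from the in-memory cache where possible, and collect the rest for the follow-up batch SELECT mentioned in the comment above.
func resolveStateDeltas(
stateNeeded map[string]map[string]bool,
eventIDToEvent map[string]types.StreamEvent,
) (resolved map[string][]types.StreamEvent, missing []string) {
resolved = make(map[string][]types.StreamEvent)
for roomID, needSet := range stateNeeded {
for eventID, needed := range needSet {
if !needed {
// The state entry was added and then removed again within the range.
continue
}
if ev, ok := eventIDToEvent[eventID]; ok {
resolved[roomID] = append(resolved[roomID], ev)
} else {
missing = append(missing, eventID)
}
}
}
return resolved, missing
}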
|
||||
// MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied,
|
||||
// then this function should only ever be used at startup, as it will race with inserting events if it is
|
||||
// done afterwards. If there are no inserted events, 0 is returned.
|
||||
func (s *outputRoomEventsStatements) SelectMaxEventID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (id int64, err error) {
|
||||
var nullableID sql.NullInt64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt)
|
||||
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
|
||||
if nullableID.Valid {
|
||||
id = nullableID.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// InsertEvent into the output_room_events table. addState and removeState are optional lists of state event IDs. Returns the position
|
||||
// of the inserted event.
|
||||
func (s *outputRoomEventsStatements) InsertEvent(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
|
||||
transactionID *api.TransactionID, excludeFromSync bool,
|
||||
) (types.StreamPosition, error) {
|
||||
var txnID *string
|
||||
var sessionID *int64
|
||||
if transactionID != nil {
|
||||
sessionID = &transactionID.SessionID
|
||||
txnID = &transactionID.TransactionID
|
||||
}
|
||||
|
||||
// Parse content as JSON and search for an "url" key
|
||||
containsURL := false
|
||||
var content map[string]interface{}
|
||||
if json.Unmarshal(event.Content(), &content) == nil {
|
||||
// Set containsURL to true if url is present
|
||||
_, containsURL = content["url"]
|
||||
}
|
||||
|
||||
var headeredJSON []byte
|
||||
headeredJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var addStateJSON, removeStateJSON []byte
|
||||
if len(addState) > 0 {
|
||||
addStateJSON, err = json.Marshal(addState)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("json.Marshal(addState): %w", err)
|
||||
}
|
||||
if len(removeState) > 0 {
|
||||
removeStateJSON, err = json.Marshal(removeState)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("json.Marshal(removeState): %w", err)
|
||||
}
|
||||
|
||||
streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||
_, err = insertStmt.ExecContext(
|
||||
ctx,
|
||||
streamPos,
|
||||
event.RoomID(),
|
||||
event.EventID(),
|
||||
headeredJSON,
|
||||
event.Type(),
|
||||
event.Sender(),
|
||||
containsURL,
|
||||
string(addStateJSON),
|
||||
string(removeStateJSON),
|
||||
sessionID,
|
||||
txnID,
|
||||
excludeFromSync,
|
||||
excludeFromSync,
|
||||
)
|
||||
return streamPos, err
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) SelectRecentEvents(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
|
||||
chronologicalOrder bool, onlySyncEvents bool,
|
||||
) ([]types.StreamEvent, bool, error) {
|
||||
var query string
|
||||
if onlySyncEvents {
|
||||
query = selectRecentEventsForSyncSQL
|
||||
} else {
|
||||
query = selectRecentEventsSQL
|
||||
}
|
||||
|
||||
stmt, params, err := prepareWithFilters(
|
||||
s.db, txn, query,
|
||||
[]interface{}{
|
||||
roomID, r.Low(), r.High(),
|
||||
},
|
||||
eventFilter.Senders, eventFilter.NotSenders,
|
||||
eventFilter.Types, eventFilter.NotTypes,
|
||||
nil, eventFilter.Limit+1, FilterOrderDesc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
}
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed")
|
||||
events, err := rowsToStreamEvents(rows)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if chronologicalOrder {
|
||||
// The events need to be returned from oldest to latest, which isn't
|
||||
// necessarily the way the SQL query returns them, so a sort is necessary to
|
||||
// ensure the events are in the right order in the slice.
|
||||
sort.SliceStable(events, func(i int, j int) bool {
|
||||
return events[i].StreamPosition < events[j].StreamPosition
|
||||
})
|
||||
}
|
||||
// We queried for one more than the limit, so if we got an extra row back, mark limited=true.
|
||||
limited := false
|
||||
if len(events) > eventFilter.Limit {
|
||||
limited = true
|
||||
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
|
||||
if chronologicalOrder {
|
||||
events = events[1:]
|
||||
} else {
|
||||
events = events[:len(events)-1]
|
||||
}
|
||||
}
|
||||
return events, limited, nil
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
|
||||
) ([]types.StreamEvent, error) {
|
||||
stmt, params, err := prepareWithFilters(
|
||||
s.db, txn, selectEarlyEventsSQL,
|
||||
[]interface{}{
|
||||
roomID, r.Low(), r.High(),
|
||||
},
|
||||
eventFilter.Senders, eventFilter.NotSenders,
|
||||
eventFilter.Types, eventFilter.NotTypes,
|
||||
nil, eventFilter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed")
|
||||
events, err := rowsToStreamEvents(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The events need to be returned from oldest to latest, which isn't
|
||||
// necessarily the way the SQL query returns them, so a sort is necessary to
|
||||
// ensure the events are in the right order in the slice.
|
||||
sort.SliceStable(events, func(i int, j int) bool {
|
||||
return events[i].StreamPosition < events[j].StreamPosition
|
||||
})
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// selectEvents returns the events for the given event IDs. If an event is
|
||||
// missing from the database, it will be omitted.
|
||||
func (s *outputRoomEventsStatements) SelectEvents(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) ([]types.StreamEvent, error) {
|
||||
var returnEvents []types.StreamEvent
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||
for _, eventID := range eventIDs {
|
||||
rows, err := stmt.QueryContext(ctx, eventID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
|
||||
returnEvents = append(returnEvents, streamEvents...)
|
||||
}
|
||||
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||
}
|
||||
return returnEvents, nil
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
||||
func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
|
||||
var result []types.StreamEvent
|
||||
for rows.Next() {
|
||||
var (
|
||||
eventID string
|
||||
streamPos types.StreamPosition
|
||||
eventBytes []byte
|
||||
excludeFromSync bool
|
||||
sessionID *int64
|
||||
txnID *string
|
||||
transactionID *api.TransactionID
|
||||
)
|
||||
if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: Handle redacted events
|
||||
var ev gomatrixserverlib.HeaderedEvent
|
||||
if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sessionID != nil && txnID != nil {
|
||||
transactionID = &api.TransactionID{
|
||||
SessionID: *sessionID,
|
||||
TransactionID: *txnID,
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, types.StreamEvent{
|
||||
HeaderedEvent: &ev,
|
||||
StreamPosition: streamPos,
|
||||
TransactionID: transactionID,
|
||||
ExcludeFromSync: excludeFromSync,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs []string, err error) {
|
||||
if len(addIDsJSON) > 0 {
|
||||
if err = json.Unmarshal([]byte(addIDsJSON), &addIDs); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(delIDsJSON) > 0 {
|
||||
if err = json.Unmarshal([]byte(delIDsJSON), &delIDs); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
179
syncapi/storage/cosmosdb/output_room_events_topology_table.go
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
// Copyright 2018 New Vector Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const outputRoomEventsTopologySchema = `
|
||||
-- Stores output room events received from the roomserver.
|
||||
CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology (
|
||||
event_id TEXT PRIMARY KEY,
|
||||
topological_position BIGINT NOT NULL,
|
||||
stream_position BIGINT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
|
||||
UNIQUE(topological_position, room_id, stream_position)
|
||||
);
|
||||
-- The topological order will be used in events selection and ordering
|
||||
-- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id);
|
||||
`
|
||||
|
||||
const insertEventInTopologySQL = "" +
|
||||
"INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" +
|
||||
" VALUES ($1, $2, $3, $4)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectEventIDsInRangeASCSQL = "" +
|
||||
"SELECT event_id FROM syncapi_output_room_events_topology" +
|
||||
" WHERE room_id = $1 AND (" +
|
||||
"(topological_position > $2 AND topological_position < $3) OR" +
|
||||
"(topological_position = $4 AND stream_position <= $5)" +
|
||||
") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
|
||||
|
||||
const selectEventIDsInRangeDESCSQL = "" +
|
||||
"SELECT event_id FROM syncapi_output_room_events_topology" +
|
||||
" WHERE room_id = $1 AND (" +
|
||||
"(topological_position > $2 AND topological_position < $3) OR" +
|
||||
"(topological_position = $4 AND stream_position <= $5)" +
|
||||
") ORDER BY topological_position DESC, stream_position DESC LIMIT $6"
|
||||
|
||||
const selectPositionInTopologySQL = "" +
|
||||
"SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
|
||||
" WHERE event_id = $1"
|
||||
|
||||
const selectMaxPositionInTopologySQL = "" +
|
||||
"SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" +
|
||||
" WHERE room_id = $1 ORDER BY stream_position DESC"
|
||||
|
||||
const deleteTopologyForRoomSQL = "" +
|
||||
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
|
||||
|
||||
type outputRoomEventsTopologyStatements struct {
|
||||
db *sql.DB
|
||||
insertEventInTopologyStmt *sql.Stmt
|
||||
selectEventIDsInRangeASCStmt *sql.Stmt
|
||||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||
selectPositionInTopologyStmt *sql.Stmt
|
||||
selectMaxPositionInTopologyStmt *sql.Stmt
|
||||
deleteTopologyForRoomStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
|
||||
s := &outputRoomEventsTopologyStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(outputRoomEventsTopologySchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// InsertEventInTopology inserts the given event in the room's topology, based
|
||||
// on the event's depth.
|
||||
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
|
||||
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
|
||||
) (types.StreamPosition, error) {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt).ExecContext(
|
||||
ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
|
||||
)
|
||||
return types.StreamPosition(event.Depth()), err
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
minDepth, maxDepth, maxStreamPos types.StreamPosition,
|
||||
limit int, chronologicalOrder bool,
|
||||
) (eventIDs []string, err error) {
|
||||
// Decide on the selection's order according to whether chronological order
|
||||
// is requested or not.
|
||||
var stmt *sql.Stmt
|
||||
if chronologicalOrder {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
|
||||
} else {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
|
||||
}
|
||||
|
||||
// Query the event IDs.
|
||||
rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit)
|
||||
if err == sql.ErrNoRows {
|
||||
// If no event matched the request, return an empty slice.
|
||||
return []string{}, nil
|
||||
} else if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Return the IDs.
|
||||
var eventID string
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&eventID); err != nil {
|
||||
return
|
||||
}
|
||||
eventIDs = append(eventIDs, eventID)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// SelectPositionInTopology returns the position of a given event in the
|
||||
// topology of the room it belongs to.
|
||||
func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology(
|
||||
ctx context.Context, txn *sql.Tx, eventID string,
|
||||
) (pos types.StreamPosition, spos types.StreamPosition, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt)
|
||||
err = stmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (pos types.StreamPosition, spos types.StreamPosition, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt)
|
||||
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
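The table just added keys every event on a (depth, stream position) pair, and the range statements and ORDER BY clauses above all depend on that pairing. A minimal sketch of the ordering and range rule they implement, with illustrative names that are not taken from the diff:

package topologyexample

// topologyToken is the (depth, stream position) pair stored per event.
type topologyToken struct {
    Depth     int64
    StreamPos int64
}

// before reports whether a sorts ahead of b in topological order: lower depth
// first, then lower stream position as the tie-breaker, matching the ORDER BY
// in selectEventIDsInRangeASCSQL.
func before(a, b topologyToken) bool {
    if a.Depth != b.Depth {
        return a.Depth < b.Depth
    }
    return a.StreamPos < b.StreamPos
}

// withinRange mirrors the WHERE clause of both range statements: strictly
// between the two depths, or at the boundary depth with a stream position at
// or below the supplied maximum.
func withinRange(t topologyToken, minDepth, maxDepth, maxStreamPos int64) bool {
    return (t.Depth > minDepth && t.Depth < maxDepth) ||
        (t.Depth == maxDepth && t.StreamPos <= maxStreamPos)
}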
206
syncapi/storage/cosmosdb/peeks_table.go
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
)
|
||||
|
||||
const peeksSchema = `
|
||||
CREATE TABLE IF NOT EXISTS syncapi_peeks (
|
||||
id INTEGER,
|
||||
room_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
deleted BOOL NOT NULL DEFAULT false,
|
||||
-- When the peek was created in UNIX epoch ms.
|
||||
creation_ts INTEGER NOT NULL,
|
||||
UNIQUE(room_id, user_id, device_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS syncapi_peeks_room_id_idx ON syncapi_peeks(room_id);
|
||||
CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_id, device_id);
|
||||
`
|
||||
|
||||
const insertPeekSQL = "" +
|
||||
"INSERT OR REPLACE INTO syncapi_peeks" +
|
||||
" (id, room_id, user_id, device_id, creation_ts, deleted)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, false)"
|
||||
|
||||
const deletePeekSQL = "" +
|
||||
"UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4"
|
||||
|
||||
const deletePeeksSQL = "" +
|
||||
"UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3"
|
||||
|
||||
// we care about all the peeks which were created in this range, deleted in this range,
|
||||
// or were created before this range but haven't been deleted yet.
|
||||
// BEWARE: sqlite chokes on out of order substitution strings.
|
||||
const selectPeeksInRangeSQL = "" +
|
||||
"SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))"
|
||||
|
||||
const selectPeekingDevicesSQL = "" +
|
||||
"SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false"
|
||||
|
||||
const selectMaxPeekIDSQL = "" +
|
||||
"SELECT MAX(id) FROM syncapi_peeks"
|
||||
|
||||
type peekStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
insertPeekStmt *sql.Stmt
|
||||
deletePeekStmt *sql.Stmt
|
||||
deletePeeksStmt *sql.Stmt
|
||||
selectPeeksInRangeStmt *sql.Stmt
|
||||
selectPeekingDevicesStmt *sql.Stmt
|
||||
selectMaxPeekIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
|
||||
_, err := db.Exec(peeksSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &peekStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
}
|
||||
if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *peekStatements) InsertPeek(
|
||||
ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
|
||||
) (streamPos types.StreamPosition, err error) {
|
||||
streamPos, err = s.streamIDStatements.nextPDUID(ctx, txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
_, err = sqlutil.TxStmt(txn, s.insertPeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID, nowMilli)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *peekStatements) DeletePeek(
|
||||
ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
|
||||
) (streamPos types.StreamPosition, err error) {
|
||||
streamPos, err = s.streamIDStatements.nextPDUID(ctx, txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = sqlutil.TxStmt(txn, s.deletePeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *peekStatements) DeletePeeks(
|
||||
ctx context.Context, txn *sql.Tx, roomID, userID string,
|
||||
) (types.StreamPosition, error) {
|
||||
streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if numAffected == 0 {
|
||||
return 0, sql.ErrNoRows
|
||||
}
|
||||
return streamPos, nil
|
||||
}
|
||||
|
||||
func (s *peekStatements) SelectPeeksInRange(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID string, r types.Range,
|
||||
) (peeks []types.Peek, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeksInRange: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
peek := types.Peek{}
|
||||
var id types.StreamPosition
|
||||
if err = rows.Scan(&id, &peek.RoomID, &peek.Deleted); err != nil {
|
||||
return
|
||||
}
|
||||
peek.New = (id > r.Low() && id <= r.High()) && !peek.Deleted
|
||||
peeks = append(peeks, peek)
|
||||
}
|
||||
|
||||
return peeks, rows.Err()
|
||||
}
|
||||
|
||||
func (s *peekStatements) SelectPeekingDevices(
|
||||
ctx context.Context,
|
||||
) (peekingDevices map[string][]types.PeekingDevice, err error) {
|
||||
rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectPeekingDevices: rows.close() failed")
|
||||
|
||||
result := make(map[string][]types.PeekingDevice)
|
||||
for rows.Next() {
|
||||
var roomID, userID, deviceID string
|
||||
if err := rows.Scan(&roomID, &userID, &deviceID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
devices := result[roomID]
|
||||
devices = append(devices, types.PeekingDevice{UserID: userID, DeviceID: deviceID})
|
||||
result[roomID] = devices
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *peekStatements) SelectMaxPeekID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (id int64, err error) {
|
||||
var nullableID sql.NullInt64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt)
|
||||
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
|
||||
if nullableID.Valid {
|
||||
id = nullableID.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
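selectPeeksInRangeSQL and the peek.New computation above encode one rule: return peeks that were still live at the low end of the range plus anything created or deleted inside it, and flag a peek as new only when it was created inside the range and not deleted. A small sketch of that rule as pure functions, with types simplified for illustration:

package peekexample

// peekRow mirrors the columns scanned by SelectPeeksInRange.
type peekRow struct {
    ID      int64
    Deleted bool
}

// selected reports whether a row satisfies the WHERE clause of
// selectPeeksInRangeSQL for the range (low, high]: either still live at or
// before the low bound, or touched somewhere inside the range.
func selected(p peekRow, low, high int64) bool {
    return (p.ID <= low && !p.Deleted) || (p.ID > low && p.ID <= high)
}

// isNew matches the peek.New computation: the peek was created inside the
// range and has not been deleted since.
func isNew(p peekRow, low, high int64) bool {
    return p.ID > low && p.ID <= high && !p.Deleted
}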
141
syncapi/storage/cosmosdb/receipt_table.go
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const receiptsSchema = `
|
||||
-- Stores data about receipts
|
||||
CREATE TABLE IF NOT EXISTS syncapi_receipts (
|
||||
-- The stream position at which this receipt was written
|
||||
id BIGINT,
|
||||
room_id TEXT NOT NULL,
|
||||
receipt_type TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
receipt_ts BIGINT NOT NULL,
|
||||
CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id_idx ON syncapi_receipts(room_id);
|
||||
`
|
||||
|
||||
const upsertReceipt = "" +
|
||||
"INSERT INTO syncapi_receipts" +
|
||||
" (id, room_id, receipt_type, user_id, event_id, receipt_ts)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT (room_id, receipt_type, user_id)" +
|
||||
" DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9"
|
||||
|
||||
const selectRoomReceipts = "" +
|
||||
"SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
|
||||
" FROM syncapi_receipts" +
|
||||
" WHERE id > $1 and room_id in ($2)"
|
||||
|
||||
const selectMaxReceiptIDSQL = "" +
|
||||
"SELECT MAX(id) FROM syncapi_receipts"
|
||||
|
||||
type receiptStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
upsertReceipt *sql.Stmt
|
||||
selectRoomReceipts *sql.Stmt
|
||||
selectMaxReceiptID *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
|
||||
_, err := db.Exec(receiptsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := &receiptStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
}
|
||||
if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil {
|
||||
return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err)
|
||||
}
|
||||
if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil {
|
||||
return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err)
|
||||
}
|
||||
if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil {
|
||||
return nil, fmt.Errorf("unable to prepare selectMaxReceiptID statement: %w", err)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// UpsertReceipt inserts or updates the receipt for the given user, room and receipt type
|
||||
func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) {
|
||||
pos, err = r.streamIDStatements.nextReceiptID(ctx, txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, r.upsertReceipt)
|
||||
_, err = stmt.ExecContext(ctx, pos, roomId, receiptType, userId, eventId, timestamp, pos, eventId, timestamp)
|
||||
return
|
||||
}
|
||||
|
||||
// SelectRoomReceiptsAfter selects all receipts for the given rooms after the given stream position
|
||||
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
|
||||
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
|
||||
lastPos := streamPos
|
||||
params := make([]interface{}, len(roomIDs)+1)
|
||||
params[0] = streamPos
|
||||
for k, v := range roomIDs {
|
||||
params[k+1] = v
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx, selectSQL, params...)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
|
||||
var res []api.OutputReceiptEvent
|
||||
for rows.Next() {
|
||||
r := api.OutputReceiptEvent{}
|
||||
var id types.StreamPosition
|
||||
err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
|
||||
if err != nil {
|
||||
return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
|
||||
}
|
||||
res = append(res, r)
|
||||
if id > lastPos {
|
||||
lastPos = id
|
||||
}
|
||||
}
|
||||
return lastPos, res, rows.Err()
|
||||
}
|
||||
|
||||
func (s *receiptStatements) SelectMaxReceiptID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (id int64, err error) {
|
||||
var nullableID sql.NullInt64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)
|
||||
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
|
||||
if nullableID.Valid {
|
||||
id = nullableID.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
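SelectRoomReceiptsAfter has to filter on a variable number of room IDs, so it rewrites the ($2) placeholder into ($2, $3, ...) before executing. A self-contained sketch of that expansion; the helper below is illustrative, while the diff itself relies on sqlutil.QueryVariadicOffset:

package receiptexample

import (
    "fmt"
    "strings"
)

// variadicPlaceholders returns "($2, $3, ..., $n)" for count values, starting
// the numbering after `offset` earlier parameters, mimicking what
// sqlutil.QueryVariadicOffset produces.
func variadicPlaceholders(count, offset int) string {
    parts := make([]string, count)
    for i := 0; i < count; i++ {
        parts[i] = fmt.Sprintf("$%d", offset+i+1)
    }
    return "(" + strings.Join(parts, ", ") + ")"
}

// buildRoomReceiptsQuery expands the IN clause for the given rooms and returns
// the SQL plus the ordered parameter list, in the same shape as
// SelectRoomReceiptsAfter.
func buildRoomReceiptsQuery(streamPos int64, roomIDs []string) (string, []interface{}) {
    base := "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
        " FROM syncapi_receipts WHERE id > $1 AND room_id IN ($2)"
    query := strings.Replace(base, "($2)", variadicPlaceholders(len(roomIDs), 1), 1)
    params := make([]interface{}, 0, len(roomIDs)+1)
    params = append(params, streamPos)
    for _, roomID := range roomIDs {
        params = append(params, roomID)
    }
    return query, params
}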
160
syncapi/storage/cosmosdb/send_to_device_table.go
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const sendToDeviceSchema = `
|
||||
-- Stores send-to-device messages.
|
||||
CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
|
||||
-- The ID that uniquely identifies this message.
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The user ID to send the message to.
|
||||
user_id TEXT NOT NULL,
|
||||
-- The device ID to send the message to.
|
||||
device_id TEXT NOT NULL,
|
||||
-- The event content JSON.
|
||||
content TEXT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertSendToDeviceMessageSQL = `
|
||||
INSERT INTO syncapi_send_to_device (user_id, device_id, content)
|
||||
VALUES ($1, $2, $3)
|
||||
`
|
||||
|
||||
const selectSendToDeviceMessagesSQL = `
|
||||
SELECT id, user_id, device_id, content
|
||||
FROM syncapi_send_to_device
|
||||
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
|
||||
ORDER BY id DESC
|
||||
`
|
||||
|
||||
const deleteSendToDeviceMessagesSQL = `
|
||||
DELETE FROM syncapi_send_to_device
|
||||
WHERE user_id = $1 AND device_id = $2 AND id < $3
|
||||
`
|
||||
|
||||
const selectMaxSendToDeviceIDSQL = "" +
|
||||
"SELECT MAX(id) FROM syncapi_send_to_device"
|
||||
|
||||
type sendToDeviceStatements struct {
|
||||
db *sql.DB
|
||||
insertSendToDeviceMessageStmt *sql.Stmt
|
||||
selectSendToDeviceMessagesStmt *sql.Stmt
|
||||
deleteSendToDeviceMessagesStmt *sql.Stmt
|
||||
selectMaxSendToDeviceIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||
s := &sendToDeviceStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(sendToDeviceSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
|
||||
) (pos types.StreamPosition, err error) {
|
||||
var result sql.Result
|
||||
result, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
if err != nil {
return 0, err
}
if p, err := result.LastInsertId(); err != nil {
return 0, err
} else {
pos = types.StreamPosition(p)
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
|
||||
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var id types.StreamPosition
|
||||
var userID, deviceID, content string
|
||||
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
|
||||
return
|
||||
}
|
||||
if id > lastPos {
|
||||
lastPos = id
|
||||
}
|
||||
event := types.SendToDeviceEvent{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
}
|
||||
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
|
||||
continue
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
if lastPos == 0 {
|
||||
lastPos = to
|
||||
}
|
||||
return lastPos, events, rows.Err()
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (id int64, err error) {
|
||||
var nullableID sql.NullInt64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
|
||||
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
|
||||
if nullableID.Valid {
|
||||
id = nullableID.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
94
syncapi/storage/cosmosdb/stream_id_table.go
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
)
|
||||
|
||||
const streamIDTableSchema = `
|
||||
-- Global stream ID counter, used by other tables.
|
||||
CREATE TABLE IF NOT EXISTS syncapi_stream_id (
|
||||
stream_name TEXT NOT NULL PRIMARY KEY,
|
||||
stream_id INT DEFAULT 0,
|
||||
|
||||
UNIQUE(stream_name)
|
||||
);
|
||||
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0)
|
||||
ON CONFLICT DO NOTHING;
|
||||
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0)
|
||||
ON CONFLICT DO NOTHING;
|
||||
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("accountdata", 0)
|
||||
ON CONFLICT DO NOTHING;
|
||||
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0)
|
||||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
|
||||
const increaseStreamIDStmt = "" +
|
||||
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1"
|
||||
|
||||
const selectStreamIDStmt = "" +
|
||||
"SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
|
||||
|
||||
type streamIDStatements struct {
|
||||
db *sql.DB
|
||||
increaseStreamIDStmt *sql.Stmt
|
||||
selectStreamIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(streamIDTableSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
|
||||
if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
|
||||
return
|
||||
}
|
||||
err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
|
||||
if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil {
|
||||
return
|
||||
}
|
||||
err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
|
||||
if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil {
|
||||
return
|
||||
}
|
||||
err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
|
||||
if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil {
|
||||
return
|
||||
}
|
||||
err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
|
||||
return
|
||||
}
|
||||
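Each nextXXXID helper above bumps the named counter and reads it back on the same transaction, so two concurrent writers can never be handed the same position. A minimal sketch of that allocator pattern using only database/sql (package and function names here are illustrative):

package streamexample

import (
    "context"
    "database/sql"
)

// nextStreamID bumps the named counter and returns the new value. Both
// statements run on the supplied transaction, so the increment and the read
// are atomic with respect to other writers.
func nextStreamID(ctx context.Context, txn *sql.Tx, streamName string) (int64, error) {
    if _, err := txn.ExecContext(ctx,
        "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1",
        streamName,
    ); err != nil {
        return 0, err
    }
    var id int64
    err := txn.QueryRowContext(ctx,
        "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1",
        streamName,
    ).Scan(&id)
    return id, err
}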
130
syncapi/storage/cosmosdb/syncserver.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"database/sql"
|
||||
|
||||
// Import the sqlite3 package
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
|
||||
)
|
||||
|
||||
// SyncServerDatasource represents a sync server datasource which manages
|
||||
// both the database for PDUs and caches for EDUs.
|
||||
type SyncServerDatasource struct {
|
||||
shared.Database
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
sqlutil.PartitionOffsetStatements
|
||||
streamID streamIDStatements
|
||||
}
|
||||
|
||||
// NewDatabase creates a new sync server database
|
||||
// nolint: gocyclo
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
|
||||
dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
|
||||
var d SyncServerDatasource
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.writer = sqlutil.NewExclusiveWriter()
|
||||
if err = d.prepare(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
|
||||
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = d.streamID.prepare(d.db); err != nil {
|
||||
return err
|
||||
}
|
||||
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
events, err := NewSqliteEventsTable(d.db, &d.streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
invites, err := NewSqliteInvitesTable(d.db, &d.streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peeks, err := NewSqlitePeeksTable(d.db, &d.streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
topology, err := NewSqliteTopologyTable(d.db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sendToDevice, err := NewSqliteSendToDeviceTable(d.db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filter, err := NewSqliteFilterTable(d.db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
receipts, err := NewSqliteReceiptsTable(d.db, &d.streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
memberships, err := NewSqliteMembershipsTable(d.db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadFixSequences(m)
|
||||
deltas.LoadRemoveSendToDeviceSentColumn(m)
|
||||
if err = m.RunDeltas(d.db, dbProperties); err != nil {
|
||||
return err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: d.db,
|
||||
Writer: d.writer,
|
||||
Invites: invites,
|
||||
Peeks: peeks,
|
||||
AccountData: accountData,
|
||||
OutputEvents: events,
|
||||
BackwardExtremities: bwExtrem,
|
||||
CurrentRoomState: roomState,
|
||||
Topology: topology,
|
||||
Filter: filter,
|
||||
SendToDevice: sendToDevice,
|
||||
Receipts: receipts,
|
||||
Memberships: memberships,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/cosmosdb"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||
)
|
||||
|
|
@ -27,6 +28,8 @@ import (
|
|||
// NewSyncServerDatasource opens a database connection.
|
||||
func NewSyncServerDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsCosmosDB():
|
||||
return cosmosdb.NewDatabase(dbProperties)
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ import (
|
|||
// NewSyncServerDatasource opens a database connection.
|
||||
func NewSyncServerDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsCosmosDB():
|
||||
return nil, fmt.Errorf("can't use CosmosDB implementation")
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
|
|
|
|||
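Both factory hunks dispatch purely on the connection string, so switching a deployment to the new backend is a configuration change. A hedged sketch of that dispatch shape; the concrete URI scheme recognised by ConnectionString.IsCosmosDB() is not shown in this diff, so the prefixes below are placeholders only:

package storageexample

import (
    "fmt"
    "strings"
)

// backendFor mimics the switch in NewSyncServerDatasource: pick a storage
// implementation from the connection string. The prefixes here are
// placeholders; the real checks are ConnectionString.IsCosmosDB(),
// IsSQLite() and IsPostgres().
func backendFor(connectionString string) (string, error) {
    switch {
    case strings.HasPrefix(connectionString, "cosmosdb:"):
        return "cosmosdb", nil
    case strings.HasPrefix(connectionString, "file:"):
        return "sqlite3", nil
    case strings.HasPrefix(connectionString, "postgres:"):
        return "postgres", nil
    default:
        return "", fmt.Errorf("unknown connection string %q", connectionString)
    }
}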
134
userapi/storage/accounts/cosmosdb/account_data_table.go
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
-- Stores account data.
|
||||
CREATE TABLE IF NOT EXISTS account_data (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL,
|
||||
-- The room ID for this data (empty string if not specific to a room)
|
||||
room_id TEXT,
|
||||
-- The account data type
|
||||
type TEXT NOT NULL,
|
||||
-- The account data content
|
||||
content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(localpart, room_id, type)
|
||||
);
|
||||
`
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
selectAccountDataStmt *sql.Stmt
|
||||
selectAccountDataByTypeStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(accountDataSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) insertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) selectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
) (
|
||||
/* global */ map[string]json.RawMessage,
|
||||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
global := map[string]json.RawMessage{}
|
||||
rooms := map[string]map[string]json.RawMessage{}
|
||||
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
var dataType string
|
||||
var content []byte
|
||||
|
||||
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if roomID != "" {
|
||||
if _, ok := rooms[roomID]; !ok {
|
||||
rooms[roomID] = map[string]json.RawMessage{}
|
||||
}
|
||||
rooms[roomID][dataType] = content
|
||||
} else {
|
||||
global[dataType] = content
|
||||
}
|
||||
}
|
||||
|
||||
return global, rooms, nil
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) selectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return
|
||||
}
|
||||
data = json.RawMessage(bytes)
|
||||
return
|
||||
}
|
||||
187
userapi/storage/accounts/cosmosdb/accounts_table.go
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const accountsSchema = `
|
||||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS account_accounts (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||
password_hash TEXT,
|
||||
-- Identifies which application service this account belongs to, if any.
|
||||
appservice_id TEXT,
|
||||
-- Whether this account has been deactivated
|
||||
is_deactivated BOOLEAN DEFAULT 0
|
||||
-- TODO:
|
||||
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COUNT(localpart) FROM account_accounts"
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
insertAccountStmt *sql.Stmt
|
||||
updatePasswordStmt *sql.Stmt
|
||||
deactivateAccountStmt *sql.Stmt
|
||||
selectAccountByLocalpartStmt *sql.Stmt
|
||||
selectPasswordHashStmt *sql.Stmt
|
||||
selectNewNumericLocalpartStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *accountsStatements) execSchema(db *sql.DB) error {
|
||||
_, err := db.Exec(accountsSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
s.serverName = server
|
||||
return
|
||||
}
|
||||
|
||||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) insertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := s.insertAccountStmt
|
||||
|
||||
var err error
|
||||
if appserviceID == "" {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||
} else {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Account{
|
||||
Localpart: localpart,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
ServerName: s.serverName,
|
||||
AppServiceID: appserviceID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) updatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) deactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if appserviceIDPtr.Valid {
|
||||
acc.AppServiceID = appserviceIDPtr.String
|
||||
}
|
||||
|
||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
acc.ServerName = s.serverName
|
||||
|
||||
return &acc, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) selectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
return
|
||||
}
|
||||
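selectNewNumericLocalpart proposes the next numeric localpart from a plain row count, so a registration that races another one (or lands on a manually chosen numeric localpart) can hit the primary-key constraint and needs to retry. A minimal sketch of such a retry loop; the callback names stand in for selectNewNumericLocalpart, insertAccount and isConstraintError and are not part of the diff:

package accountexample

import (
    "context"
    "fmt"
)

// registerNumeric keeps proposing count-based numeric localparts until an
// insert succeeds, retrying whenever the localpart's unique constraint fires.
func registerNumeric(
    ctx context.Context,
    nextNumeric func(context.Context) (int64, error),
    insert func(context.Context, string) error,
    isConstraintError func(error) bool,
) (string, error) {
    for {
        id, err := nextNumeric(ctx)
        if err != nil {
            return "", err
        }
        localpart := fmt.Sprintf("%d", id)
        switch err := insert(ctx, localpart); {
        case err == nil:
            return localpart, nil
        case isConstraintError(err):
            // Another registration took this localpart first; try the next one.
            continue
        default:
            return "", err
        }
    }
}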
21
userapi/storage/accounts/cosmosdb/constraint_wasm.go
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// +build wasm
|
||||
|
||||
package cosmosdb
|
||||
|
||||
func isConstraintError(err error) bool {
|
||||
return false
|
||||
}
|
||||
86
userapi/storage/accounts/cosmosdb/openid_table.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const openIDTokenSchema = `
|
||||
-- Stores OpenID tokens issued to accounts.
|
||||
CREATE TABLE IF NOT EXISTS open_id_tokens (
|
||||
-- The value of the token issued to a user
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The Matrix user ID for this account
|
||||
localpart TEXT NOT NULL,
|
||||
-- When the token expires, as a unix timestamp (ms resolution).
|
||||
token_expires_at_ms BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertTokenSQL = "" +
|
||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
|
||||
type tokenStatements struct {
|
||||
db *sql.DB
|
||||
insertTokenStmt *sql.Stmt
|
||||
selectTokenStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(openIDTokenSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
|
||||
return
|
||||
}
|
||||
s.serverName = server
|
||||
return
|
||||
}
|
||||
|
||||
// insertToken inserts a new OpenID Connect token to the DB.
|
||||
// Returns an error if the token already exists.
|
||||
func (s *tokenStatements) insertToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
return
|
||||
}
|
||||
|
||||
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
||||
// Returns the existing token's attributes, or err if no token is found
|
||||
func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
||||
ctx context.Context,
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &openIDTokenAttrs, nil
|
||||
}
|
||||
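insertToken stores an absolute expiry in token_expires_at_ms; the caller derives it from the configured token lifetime. A minimal sketch of that computation (the lifetime handling here is illustrative, not taken from the diff):

package openidexample

import "time"

// expiresAtMS converts a configured token lifetime into the absolute
// unix-millisecond expiry that insertToken stores in token_expires_at_ms.
func expiresAtMS(lifetime time.Duration) int64 {
    return time.Now().UnixNano()/int64(time.Millisecond) + lifetime.Milliseconds()
}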
143
userapi/storage/accounts/cosmosdb/profile_table.go
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const profilesSchema = `
|
||||
-- Stores data about accounts profiles.
|
||||
CREATE TABLE IF NOT EXISTS account_profiles (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
-- The display name for this account
|
||||
display_name TEXT,
|
||||
-- The URL of the avatar for this account
|
||||
avatar_url TEXT
|
||||
);
|
||||
`
|
||||
|
||||
const insertProfileSQL = "" +
|
||||
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectProfileByLocalpartSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
||||
|
||||
const setAvatarURLSQL = "" +
|
||||
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||
|
||||
const setDisplayNameSQL = "" +
|
||||
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||
|
||||
const selectProfilesBySearchSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
|
||||
type profilesStatements struct {
|
||||
db *sql.DB
|
||||
insertProfileStmt *sql.Stmt
|
||||
selectProfileByLocalpartStmt *sql.Stmt
|
||||
setAvatarURLStmt *sql.Stmt
|
||||
setDisplayNameStmt *sql.Stmt
|
||||
selectProfilesBySearchStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(profilesSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) insertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) selectProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) (*authtypes.Profile, error) {
|
||||
var profile authtypes.Profile
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
func (s *profilesStatements) setAvatarURL(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||
_, err = stmt.ExecContext(ctx, avatarURL, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) setDisplayName(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||
_, err = stmt.ExecContext(ctx, displayName, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) selectProfilesBySearch(
|
||||
ctx context.Context, searchString string, limit int,
|
||||
) ([]authtypes.Profile, error) {
|
||||
var profiles []authtypes.Profile
|
||||
// The fmt.Sprintf directive below is building a parameter for the
|
||||
// "LIKE" condition in the SQL query. %% escapes the % char, so the
|
||||
// statement in the end will look like "LIKE %searchString%".
|
||||
rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var profile authtypes.Profile
|
||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profiles = append(profiles, profile)
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
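selectProfilesBySearch wraps the search string in % wildcards for the LIKE clause; note that % or _ characters inside the user's input are not escaped by the code above and therefore act as wildcards too. A small sketch that additionally escapes them, which would also require a matching ESCAPE '\' clause in the SQL (illustrative, not part of the diff):

package profileexample

import (
    "fmt"
    "strings"
)

// likePattern wraps searchString in wildcards for a LIKE clause, escaping any
// literal %, _ or \ so user input cannot act as a wildcard. The statement
// using it would need a matching ESCAPE '\' clause.
func likePattern(searchString string) string {
    escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(searchString)
    return fmt.Sprintf("%%%s%%", escaped)
}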
410
userapi/storage/accounts/cosmosdb/storage.go
Normal file
|
|
@ -0,0 +1,410 @@
|
|||
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"strconv"
	"sync"
	"time"

	"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
	"github.com/matrix-org/dendrite/internal/cosmosdbutil"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/userapi/api"
	"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
	"github.com/matrix-org/gomatrixserverlib"
	"golang.org/x/crypto/bcrypt"
)

// Database represents an account database
type Database struct {
	db     *sql.DB
	writer sqlutil.Writer

	sqlutil.PartitionOffsetStatements
	accounts              accountsStatements
	profiles              profilesStatements
	accountDatas          accountDataStatements
	threepids             threepidStatements
	openIDTokens          tokenStatements
	serverName            gomatrixserverlib.ServerName
	bcryptCost            int
	openIDTokenLifetimeMS int64

	accountsMu     sync.Mutex
	profilesMu     sync.Mutex
	accountDatasMu sync.Mutex
	threepidsMu    sync.Mutex
}

// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
	dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
	db, err := sqlutil.Open(dbProperties)
	if err != nil {
		return nil, err
	}
	d := &Database{
		serverName:            serverName,
		db:                    db,
		writer:                sqlutil.NewExclusiveWriter(),
		bcryptCost:            bcryptCost,
		openIDTokenLifetimeMS: openIDTokenLifetimeMS,
	}

	// Create tables before executing migrations so we don't fail if the table is missing,
	// and THEN prepare statements so we don't fail due to referencing new columns
	if err = d.accounts.execSchema(db); err != nil {
		return nil, err
	}
	m := sqlutil.NewMigrations()
	deltas.LoadIsActive(m)
	if err = m.RunDeltas(db, dbProperties); err != nil {
		return nil, err
	}

	partitions := sqlutil.PartitionOffsetStatements{}
	if err = partitions.Prepare(db, d.writer, "account"); err != nil {
		return nil, err
	}
	if err = d.accounts.prepare(db, serverName); err != nil {
		return nil, err
	}
	if err = d.profiles.prepare(db); err != nil {
		return nil, err
	}
	if err = d.accountDatas.prepare(db); err != nil {
		return nil, err
	}
	if err = d.threepids.prepare(db); err != nil {
		return nil, err
	}
	if err = d.openIDTokens.prepare(db, serverName); err != nil {
		return nil, err
	}

	return d, nil
}

// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
	ctx context.Context, localpart, plaintextPassword string,
) (*api.Account, error) {
	hash, err := d.accounts.selectPasswordHash(ctx, localpart)
	if err != nil {
		return nil, err
	}
	if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
		return nil, err
	}
	return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
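
For context, a small sketch (not part of the diff) of how a caller might distinguish the failure modes of GetAccountByPassword above; the function name and wiring here are assumptions for illustration only:

// --- illustrative sketch, not part of this commit ---
package example

import (
	"database/sql"
	"errors"

	"golang.org/x/crypto/bcrypt"
)

// classifyLoginError shows how the two error paths in GetAccountByPassword
// surface to callers: selectPasswordHash yields sql.ErrNoRows for an unknown
// localpart, while bcrypt.CompareHashAndPassword yields
// bcrypt.ErrMismatchedHashAndPassword for a wrong password.
func classifyLoginError(err error) string {
	switch {
	case err == nil:
		return "authenticated"
	case errors.Is(err, sql.ErrNoRows):
		return "unknown user"
	case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
		return "wrong password"
	default:
		return "storage error"
	}
}
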

// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
	ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
	return d.profiles.selectProfileByLocalpart(ctx, localpart)
}

// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query.
func (d *Database) SetAvatarURL(
	ctx context.Context, localpart string, avatarURL string,
) error {
	d.profilesMu.Lock()
	defer d.profilesMu.Unlock()
	return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL)
	})
}

// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query.
func (d *Database) SetDisplayName(
	ctx context.Context, localpart string, displayName string,
) error {
	d.profilesMu.Lock()
	defer d.profilesMu.Unlock()
	return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
	})
}

// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
	ctx context.Context, localpart, plaintextPassword string,
) error {
	hash, err := d.hashPassword(plaintextPassword)
	if err != nil {
		return err
	}
	err = d.accounts.updatePassword(ctx, localpart, hash)
	return err
}

// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {
	// We need to lock so we sequentially create numeric localparts. If we don't, two calls to
	// this function will cause the same number to be selected and one will fail with 'database is locked'
	// when the first txn upgrades to a write txn. We also need to lock the account creation else we can
	// race with CreateAccount.
	// We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
	d.profilesMu.Lock()
	d.accountDatasMu.Lock()
	d.accountsMu.Lock()
	defer d.profilesMu.Unlock()
	defer d.accountDatasMu.Unlock()
	defer d.accountsMu.Unlock()
	err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		var numLocalpart int64
		numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
		if err != nil {
			return err
		}
		localpart := strconv.FormatInt(numLocalpart, 10)
		acc, err = d.createAccount(ctx, txn, localpart, "", "")
		return err
	})
	return acc, err
}

// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount(
	ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *api.Account, err error) {
	// Create one account at a time else we can get 'database is locked'.
	d.profilesMu.Lock()
	d.accountDatasMu.Lock()
	d.accountsMu.Lock()
	defer d.profilesMu.Unlock()
	defer d.accountDatasMu.Unlock()
	defer d.accountsMu.Unlock()
	err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
		return err
	})
	return
}

// WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount(
	ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*api.Account, error) {
	var err error
	var account *api.Account
	// Generate a password hash if this is not a password-less user
	hash := ""
	if plaintextPassword != "" {
		hash, err = d.hashPassword(plaintextPassword)
		if err != nil {
			return nil, err
		}
	}
	if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
		return nil, sqlutil.ErrUserExists
	}
	if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
		return nil, err
	}
	if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
		"global": {
			"content": [],
			"override": [],
			"room": [],
			"sender": [],
			"underride": []
		}
	}`)); err != nil {
		return nil, err
	}
	return account, nil
}

// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string.
// If account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content.
// Returns a SQL error if there was an issue with the insertion/update.
func (d *Database) SaveAccountData(
	ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
	d.accountDatasMu.Lock()
	defer d.accountDatasMu.Unlock()
	return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
	})
}

// GetAccountData returns account data related to a given localpart.
// If no account data could be found, returns empty maps.
// Returns an error if there was an issue with the retrieval.
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
	global map[string]json.RawMessage,
	rooms map[string]map[string]json.RawMessage,
	err error,
) {
	return d.accountDatas.selectAccountData(ctx, localpart)
}

// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil.
// Returns an error if there was an issue with the retrieval.
func (d *Database) GetAccountDataByType(
	ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
	return d.accountDatas.selectAccountDataByType(
		ctx, localpart, roomID, dataType,
	)
}

// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
	ctx context.Context,
) (int64, error) {
	return d.accounts.selectNewNumericLocalpart(ctx, nil)
}

func (d *Database) hashPassword(plaintext string) (hash string, err error) {
	hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
	return string(hashBytes), err
}

// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")

// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
	ctx context.Context, threepid, localpart, medium string,
) (err error) {
	d.threepidsMu.Lock()
	defer d.threepidsMu.Unlock()
	return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		user, err := d.threepids.selectLocalpartForThreePID(
			ctx, txn, threepid, medium,
		)
		if err != nil {
			return err
		}

		if len(user) > 0 {
			return Err3PIDInUse
		}

		return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
	})
}
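
As an aside, a hedged sketch (not part of this commit) of the check-then-insert flow SaveThreePIDAssociation performs above, using an in-memory map in place of the account_threepid table purely for illustration:

// --- illustrative sketch, not part of this commit ---
package example

import "errors"

var errThreePIDInUse = errors.New("third-party identifier already in use")

// associate3PID mirrors the shape of SaveThreePIDAssociation: look up the
// (threepid, medium) key first, and only insert if it is not already taken.
// In the real code the exclusive writer serializes the lookup and insert,
// so two concurrent saves of the same 3PID cannot both pass the check.
func associate3PID(assoc map[[2]string]string, threepid, medium, localpart string) error {
	key := [2]string{threepid, medium}
	if _, taken := assoc[key]; taken {
		return errThreePIDInUse
	}
	assoc[key] = localpart
	return nil
}
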

// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
	ctx context.Context, threepid string, medium string,
) (err error) {
	d.threepidsMu.Lock()
	defer d.threepidsMu.Unlock()
	return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		return d.threepids.deleteThreePID(ctx, txn, threepid, medium)
	})
}

// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party identifier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
	ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
	return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}

// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
	ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
	return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}

// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
	_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
	if err == sql.ErrNoRows {
		return true, nil
	}
	return false, err
}

// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
	return d.accounts.selectAccountByLocalpart(ctx, localpart)
}

// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
	return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}

// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
	return d.accounts.deactivateAccount(ctx, localpart)
}

// CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken(
	ctx context.Context,
	token, localpart string,
) (int64, error) {
	expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
	err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
	})
	return expiresAtMS, err
}
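
A brief worked example (not part of the diff) of the expiry arithmetic in CreateOpenIDToken above; the one-hour lifetime is only an assumed example value:

// --- illustrative sketch, not part of this commit ---
package example

import "time"

// tokenExpiryMS reproduces the arithmetic in CreateOpenIDToken: the current
// wall-clock time converted to milliseconds plus the configured lifetime.
func tokenExpiryMS(lifetimeMS int64) int64 {
	return time.Now().UnixNano()/int64(time.Millisecond) + lifetimeMS
}

// e.g. tokenExpiryMS(60 * 60 * 1000) yields a timestamp one hour from now,
// in the same millisecond units that insertToken stores.
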

// GetOpenIDTokenAttributes gets the attributes of an issued OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
	ctx context.Context,
	token string,
) (*api.OpenIDTokenAttributes, error) {
	return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}
133
userapi/storage/accounts/cosmosdb/threepid_table.go
Normal file

@@ -0,0 +1,133 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"

	"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)

const threepidSchema = `
-- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid (
	-- The third party identifier
	threepid TEXT NOT NULL,
	-- The 3PID medium
	medium TEXT NOT NULL DEFAULT 'email',
	-- The localpart of the Matrix user ID associated to this 3PID
	localpart TEXT NOT NULL,

	PRIMARY KEY(threepid, medium)
);

CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
`

const selectLocalpartForThreePIDSQL = "" +
	"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"

const selectThreePIDsForLocalpartSQL = "" +
	"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"

const insertThreePIDSQL = "" +
	"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"

const deleteThreePIDSQL = "" +
	"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"

type threepidStatements struct {
	db                              *sql.DB
	selectLocalpartForThreePIDStmt  *sql.Stmt
	selectThreePIDsForLocalpartStmt *sql.Stmt
	insertThreePIDStmt              *sql.Stmt
	deleteThreePIDStmt              *sql.Stmt
}

func (s *threepidStatements) prepare(db *sql.DB) (err error) {
	s.db = db
	_, err = db.Exec(threepidSchema)
	if err != nil {
		return
	}
	if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil {
		return
	}
	if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil {
		return
	}
	if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil {
		return
	}
	if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil {
		return
	}

	return
}

func (s *threepidStatements) selectLocalpartForThreePID(
	ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
	stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
	err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
	if err == sql.ErrNoRows {
		return "", nil
	}
	return
}

func (s *threepidStatements) selectThreePIDsForLocalpart(
	ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
	rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
	if err != nil {
		return
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed")

	threepids = []authtypes.ThreePID{}
	for rows.Next() {
		var threepid string
		var medium string
		if err = rows.Scan(&threepid, &medium); err != nil {
			return
		}
		threepids = append(threepids, authtypes.ThreePID{
			Address: threepid,
			Medium:  medium,
		})
	}
	return threepids, rows.Err()
}

func (s *threepidStatements) insertThreePID(
	ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
	stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
	_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
	return err
}

func (s *threepidStatements) deleteThreePID(
	ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
	stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
	_, err = stmt.ExecContext(ctx, threepid, medium)
	return err
}
@@ -20,6 +20,7 @@ import (
	"fmt"

	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/userapi/storage/accounts/cosmosdb"
	"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres"
	"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
	"github.com/matrix-org/gomatrixserverlib"

@@ -29,6 +30,8 @@ import (
// and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
	switch {
	case dbProperties.ConnectionString.IsCosmosDB():
		return cosmosdb.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
	case dbProperties.ConnectionString.IsSQLite():
		return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
	case dbProperties.ConnectionString.IsPostgres():

@@ -29,6 +29,8 @@ func NewDatabase(
	openIDTokenLifetimeMS int64,
) (Database, error) {
	switch {
	case dbProperties.ConnectionString.IsCosmosDB():
		return nil, fmt.Errorf("can't use CosmosDB implementation")
	case dbProperties.ConnectionString.IsSQLite():
		return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
	case dbProperties.ConnectionString.IsPostgres():
322
userapi/storage/devices/cosmosdb/devices_table.go
Normal file

@@ -0,0 +1,322 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"database/sql"
	"strings"
	"time"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/userapi/api"

	"github.com/matrix-org/dendrite/clientapi/userutil"
	"github.com/matrix-org/gomatrixserverlib"
)

const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;

-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
	access_token TEXT PRIMARY KEY,
	session_id INTEGER,
	device_id TEXT,
	localpart TEXT,
	created_ts BIGINT,
	display_name TEXT,
	last_seen_ts BIGINT,
	ip TEXT,
	user_agent TEXT,

	UNIQUE (localpart, device_id)
);
`

const insertDeviceSQL = "" +
	"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
	" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"

const selectDevicesCountSQL = "" +
	"SELECT COUNT(access_token) FROM device_devices"

const selectDeviceByTokenSQL = "" +
	"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"

const selectDeviceByIDSQL = "" +
	"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"

const selectDevicesByLocalpartSQL = "" +
	"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"

const updateDeviceNameSQL = "" +
	"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"

const deleteDeviceSQL = "" +
	"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"

const deleteDevicesByLocalpartSQL = "" +
	"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"

const deleteDevicesSQL = "" +
	"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"

const selectDevicesByIDSQL = "" +
	"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"

const updateDeviceLastSeen = "" +
	"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"

type devicesStatements struct {
	db                           *sql.DB
	writer                       sqlutil.Writer
	insertDeviceStmt             *sql.Stmt
	selectDevicesCountStmt       *sql.Stmt
	selectDeviceByTokenStmt      *sql.Stmt
	selectDeviceByIDStmt         *sql.Stmt
	selectDevicesByIDStmt        *sql.Stmt
	selectDevicesByLocalpartStmt *sql.Stmt
	updateDeviceNameStmt         *sql.Stmt
	updateDeviceLastSeenStmt     *sql.Stmt
	deleteDeviceStmt             *sql.Stmt
	deleteDevicesByLocalpartStmt *sql.Stmt
	serverName                   gomatrixserverlib.ServerName
}

func (s *devicesStatements) execSchema(db *sql.DB) error {
	_, err := db.Exec(devicesSchema)
	return err
}

func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
	s.db = db
	s.writer = writer
	if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
		return
	}
	if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
		return
	}
	if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
		return
	}
	if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
		return
	}
	if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
		return
	}
	if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
		return
	}
	if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
		return
	}
	if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
		return
	}
	if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
		return
	}
	if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
		return
	}
	s.serverName = server
	return
}

// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
	ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
	displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
	createdTimeMS := time.Now().UnixNano() / 1000000
	var sessionID int64
	countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
	insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
	if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
		return nil, err
	}
	sessionID++
	if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
		return nil, err
	}
	return &api.Device{
		ID:          id,
		UserID:      userutil.MakeUserID(localpart, s.serverName),
		AccessToken: accessToken,
		SessionID:   sessionID,
		LastSeenTS:  createdTimeMS,
		LastSeenIP:  ipAddr,
		UserAgent:   userAgent,
	}, nil
}

func (s *devicesStatements) deleteDevice(
	ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
	stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
	_, err := stmt.ExecContext(ctx, id, localpart)
	return err
}

func (s *devicesStatements) deleteDevices(
	ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
	orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
	prep, err := s.db.Prepare(orig)
	if err != nil {
		return err
	}
	stmt := sqlutil.TxStmt(txn, prep)
	params := make([]interface{}, len(devices)+1)
	params[0] = localpart
	for i, v := range devices {
		params[i+1] = v
	}
	_, err = stmt.ExecContext(ctx, params...)
	return err
}
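
To make the placeholder rewrite in deleteDevices above concrete, here is a hedged sketch (not part of this commit, and not the sqlutil API itself) of the expansion that replaces "($2)" with one numbered placeholder per device ID:

// --- illustrative sketch, not part of this commit ---
package example

import (
	"fmt"
	"strings"
)

// expandPlaceholders illustrates the rewrite deleteDevices performs: the
// "($2)" in deleteDevicesSQL becomes one placeholder per device ID, offset
// by one because $1 is the localpart. For three devices the statement reads:
//   DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2, $3, $4)
func expandPlaceholders(query string, n int) string {
	parts := make([]string, n)
	for i := 0; i < n; i++ {
		parts[i] = fmt.Sprintf("$%d", i+2)
	}
	return strings.Replace(query, "($2)", "("+strings.Join(parts, ", ")+")", 1)
}
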

func (s *devicesStatements) deleteDevicesByLocalpart(
	ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
	stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
	_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
	return err
}

func (s *devicesStatements) updateDeviceName(
	ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
	stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
	_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
	return err
}

func (s *devicesStatements) selectDeviceByToken(
	ctx context.Context, accessToken string,
) (*api.Device, error) {
	var dev api.Device
	var localpart string
	stmt := s.selectDeviceByTokenStmt
	err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
	if err == nil {
		dev.UserID = userutil.MakeUserID(localpart, s.serverName)
		dev.AccessToken = accessToken
	}
	return &dev, err
}

// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
	ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
	var dev api.Device
	var displayName sql.NullString
	stmt := s.selectDeviceByIDStmt
	err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
	if err == nil {
		dev.ID = deviceID
		dev.UserID = userutil.MakeUserID(localpart, s.serverName)
		if displayName.Valid {
			dev.DisplayName = displayName.String
		}
	}
	return &dev, err
}

func (s *devicesStatements) selectDevicesByLocalpart(
	ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
	devices := []api.Device{}
	rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)

	if err != nil {
		return devices, err
	}

	for rows.Next() {
		var dev api.Device
		var lastseents sql.NullInt64
		var id, displayname, ip, useragent sql.NullString
		err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
		if err != nil {
			return devices, err
		}
		if id.Valid {
			dev.ID = id.String
		}
		if displayname.Valid {
			dev.DisplayName = displayname.String
		}
		if lastseents.Valid {
			dev.LastSeenTS = lastseents.Int64
		}
		if ip.Valid {
			dev.LastSeenIP = ip.String
		}
		if useragent.Valid {
			dev.UserAgent = useragent.String
		}

		dev.UserID = userutil.MakeUserID(localpart, s.serverName)
		devices = append(devices, dev)
	}

	return devices, nil
}

func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
	sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
	iDeviceIDs := make([]interface{}, len(deviceIDs))
	for i := range deviceIDs {
		iDeviceIDs[i] = deviceIDs[i]
	}

	rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
	var devices []api.Device
	for rows.Next() {
		var dev api.Device
		var localpart string
		var displayName sql.NullString
		if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
			return nil, err
		}
		if displayName.Valid {
			dev.DisplayName = displayName.String
		}
		dev.UserID = userutil.MakeUserID(localpart, s.serverName)
		devices = append(devices, dev)
	}
	return devices, rows.Err()
}

func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
	lastSeenTs := time.Now().UnixNano() / 1000000
	stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
	_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
	return err
}
216
userapi/storage/devices/cosmosdb/storage.go
Normal file

@@ -0,0 +1,216 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cosmosdb

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/base64"

	"github.com/matrix-org/dendrite/internal/cosmosdbutil"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/userapi/api"
	"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
	"github.com/matrix-org/gomatrixserverlib"

	_ "github.com/mattn/go-sqlite3"
)

// The length of generated device IDs
var deviceIDByteLength = 6

// Database represents a device database.
type Database struct {
	db      *sql.DB
	writer  sqlutil.Writer
	devices devicesStatements
}

// NewDatabase creates a new device database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) {
	dbProperties.ConnectionString = cosmosdbutil.GetConnectionString(&dbProperties.ConnectionString)
	db, err := sqlutil.Open(dbProperties)
	if err != nil {
		return nil, err
	}
	writer := sqlutil.NewExclusiveWriter()
	d := devicesStatements{}

	// Create tables before executing migrations so we don't fail if the table is missing,
	// and THEN prepare statements so we don't fail due to referencing new columns
	if err = d.execSchema(db); err != nil {
		return nil, err
	}
	m := sqlutil.NewMigrations()
	deltas.LoadLastSeenTSIP(m)
	if err = m.RunDeltas(db, dbProperties); err != nil {
		return nil, err
	}
	if err = d.prepare(db, writer, serverName); err != nil {
		return nil, err
	}
	return &Database{db, writer, d}, nil
}

// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
	ctx context.Context, token string,
) (*api.Device, error) {
	return d.devices.selectDeviceByToken(ctx, token)
}

// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
	ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
	return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}

// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
	ctx context.Context, localpart string,
) ([]api.Device, error) {
	return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}

func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
	return d.devices.selectDevicesByID(ctx, deviceIDs)
}

// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
	ctx context.Context, localpart string, deviceID *string, accessToken string,
	displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
	if deviceID != nil {
		returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
			var err error
			// Revoke existing tokens for this device
			if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
				return err
			}

			dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
			return err
		})
	} else {
		// We generate device IDs in a loop in case it's already taken.
		// We cap this at going round 5 times to ensure we don't spin forever
		var newDeviceID string
		for i := 1; i <= 5; i++ {
			newDeviceID, returnErr = generateDeviceID()
			if returnErr != nil {
				return
			}

			returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
				var err error
				dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
				return err
			})
			if returnErr == nil {
				return
			}
		}
	}
	return
}

// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
	b := make([]byte, deviceIDByteLength)
	_, err := rand.Read(b)
	if err != nil {
		return "", err
	}
	// url-safe no padding
	return base64.RawURLEncoding.EncodeToString(b), nil
}
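
A small worked example (not part of the diff) of what generateDeviceID above produces: six random bytes encoded with URL-safe, unpadded base64 always yield an eight-character ID (6 bytes * 8 bits / 6 bits per base64 character = 8 characters).

// --- illustrative sketch, not part of this commit ---
package example

import (
	"crypto/rand"
	"encoding/base64"
)

// exampleDeviceID mirrors generateDeviceID: six random bytes encoded with
// base64.RawURLEncoding give an eight-character, URL-safe device ID.
func exampleDeviceID() (string, error) {
	b := make([]byte, 6)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}
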

// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
	ctx context.Context, localpart, deviceID string, displayName *string,
) error {
	return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
	})
}

// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
	ctx context.Context, deviceID, localpart string,
) error {
	return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
			return err
		}
		return nil
	})
}

// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
	ctx context.Context, localpart string, devices []string,
) error {
	return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
			return err
		}
		return nil
	})
}

// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
	ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
	err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
		if err != nil {
			return err
		}
		if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
			return err
		}
		return nil
	})
	return
}

// UpdateDeviceLastSeen updates the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
	return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
		return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
	})
}
@@ -20,6 +20,7 @@ import (
	"fmt"

	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/userapi/storage/devices/cosmosdb"
	"github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
	"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
	"github.com/matrix-org/gomatrixserverlib"

@@ -29,6 +30,8 @@ import (
// and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (Database, error) {
	switch {
	case dbProperties.ConnectionString.IsCosmosDB():
		return cosmosdb.NewDatabase(dbProperties, serverName)
	case dbProperties.ConnectionString.IsSQLite():
		return sqlite3.NewDatabase(dbProperties, serverName)
	case dbProperties.ConnectionString.IsPostgres():

@@ -27,6 +27,8 @@ func NewDatabase(
	serverName gomatrixserverlib.ServerName,
) (Database, error) {
	switch {
	case dbProperties.ConnectionString.IsCosmosDB():
		return nil, fmt.Errorf("can't use CosmosDB implementation")
	case dbProperties.ConnectionString.IsSQLite():
		return sqlite3.NewDatabase(dbProperties, serverName)
	case dbProperties.ConnectionString.IsPostgres():