Merge branch 'master' into kegan/dl-inbound

This commit is contained in:
Kegsay 2020-08-05 13:40:18 +01:00 committed by GitHub
commit 10f4f11a91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 74 additions and 49 deletions

View file

@ -18,6 +18,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -69,7 +70,7 @@ type joinedHostsStatements struct {
deleteJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -93,9 +94,6 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error)
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
return return
} }
if s.selectJoinedHostsForRoomsStmt, err = db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
return
}
return return
} }
@ -168,7 +166,8 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
iRoomIDs[i] = roomIDs[i] iRoomIDs[i] = roomIDs[i]
} }
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, iRoomIDs...) sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -59,6 +59,7 @@ const countStreamIDsForUserSQL = "" +
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
@ -68,6 +69,7 @@ type deviceKeysStatements struct {
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{ s := &deviceKeysStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(deviceKeysSchema) _, err := db.Exec(deviceKeysSchema)
if err != nil { if err != nil {
@ -165,6 +167,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
} }
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix() now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
@ -175,4 +178,5 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
} }
} }
return nil return nil
})
} }

View file

@ -21,6 +21,7 @@ import (
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -51,6 +52,7 @@ const selectKeyChangesSQL = "" +
type keyChangesStatements struct { type keyChangesStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
upsertKeyChangeStmt *sql.Stmt upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt
} }
@ -58,6 +60,7 @@ type keyChangesStatements struct {
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{ s := &keyChangesStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(keyChangesSchema) _, err := db.Exec(keyChangesSchema)
if err != nil { if err != nil {
@ -73,8 +76,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
} }
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
return err return err
})
} }
func (s *keyChangesStatements) SelectKeyChanges( func (s *keyChangesStatements) SelectKeyChanges(

View file

@ -60,6 +60,7 @@ const selectKeyByAlgorithmSQL = "" +
type oneTimeKeysStatements struct { type oneTimeKeysStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
upsertKeysStmt *sql.Stmt upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt selectKeysCountStmt *sql.Stmt
@ -70,6 +71,7 @@ type oneTimeKeysStatements struct {
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
s := &oneTimeKeysStatements{ s := &oneTimeKeysStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err := db.Exec(oneTimeKeysSchema) _, err := db.Exec(oneTimeKeysSchema)
if err != nil { if err != nil {
@ -150,7 +152,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
UserID: keys.UserID, UserID: keys.UserID,
KeyCount: make(map[string]int), KeyCount: make(map[string]int),
} }
return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { return counts, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
@ -183,14 +185,17 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
) (map[string]json.RawMessage, error) { ) (map[string]json.RawMessage, error) {
var keyID string var keyID string
var keyJSON string var keyJSON string
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil
} }
return nil, err return err
} }
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return err
})
return map[string]json.RawMessage{ return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON), algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err }, err

View file

@ -2,6 +2,10 @@ package storage
import ( import (
"context" "context"
"fmt"
"io/ioutil"
"log"
"os"
"reflect" "reflect"
"testing" "testing"
@ -11,6 +15,21 @@ import (
var ctx = context.Background() var ctx = context.Background()
func MustCreateDatabase(t *testing.T) (Database, func()) {
tmpfile, err := ioutil.TempFile("", "keyserver_storage_test")
if err != nil {
log.Fatal(err)
}
t.Logf("Database %s", tmpfile.Name())
db, err := NewDatabase(fmt.Sprintf("file://%s", tmpfile.Name()), nil)
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
return db, func() {
os.Remove(tmpfile.Name())
}
}
func MustNotError(t *testing.T, err error) { func MustNotError(t *testing.T, err error) {
t.Helper() t.Helper()
if err == nil { if err == nil {
@ -20,10 +39,8 @@ func MustNotError(t *testing.T, err error) {
} }
func TestKeyChanges(t *testing.T) { func TestKeyChanges(t *testing.T) {
db, err := NewDatabase("file::memory:", nil) db, clean := MustCreateDatabase(t)
if err != nil { defer clean()
t.Fatalf("Failed to NewDatabase: %s", err)
}
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
@ -40,10 +57,8 @@ func TestKeyChanges(t *testing.T) {
} }
func TestKeyChangesNoDupes(t *testing.T) { func TestKeyChangesNoDupes(t *testing.T) {
db, err := NewDatabase("file::memory:", nil) db, clean := MustCreateDatabase(t)
if err != nil { defer clean()
t.Fatalf("Failed to NewDatabase: %s", err)
}
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
@ -60,10 +75,8 @@ func TestKeyChangesNoDupes(t *testing.T) {
} }
func TestKeyChangesUpperLimit(t *testing.T) { func TestKeyChangesUpperLimit(t *testing.T) {
db, err := NewDatabase("file::memory:", nil) db, clean := MustCreateDatabase(t)
if err != nil { defer clean()
t.Fatalf("Failed to NewDatabase: %s", err)
}
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
@ -82,10 +95,9 @@ func TestKeyChangesUpperLimit(t *testing.T) {
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, // The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys. // and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) { func TestDeviceKeysStreamIDGeneration(t *testing.T) {
db, err := NewDatabase("file::memory:", nil) var err error
if err != nil { db, clean := MustCreateDatabase(t)
t.Fatalf("Failed to NewDatabase: %s", err) defer clean()
}
alice := "@alice:TestDeviceKeysStreamIDGeneration" alice := "@alice:TestDeviceKeysStreamIDGeneration"
bob := "@bob:TestDeviceKeysStreamIDGeneration" bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{ msgs := []api.DeviceMessage{