mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-22 14:21:55 -06:00
Add roomserver tests (3/4) (#2447)
* Add Room Aliases tests * Add Rooms table test * Move StateKeyTuplerSorter to the types package * Add StateBlock tests Some optimizations * Add State Snapshot tests Some optimization * Return []int64 and convert to pq.Int64Array for postgres * Move []types.EventNID back to rows.Next() * Update tests, rename SelectRoomIDs
This commit is contained in:
parent
6af35385ba
commit
05607d6b87
|
@ -264,11 +264,11 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
||||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
tuples := types.StateKeyTupleSorter(stateKeyTuples)
|
||||||
sort.Sort(tuples)
|
sort.Sort(tuples)
|
||||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
|
||||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
|
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), pq.Int64Array(eventTypeNIDArray), pq.Int64Array(eventStateKeyNIDArray))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,12 +61,12 @@ type roomAliasesStatements struct {
|
||||||
deleteRoomAliasStmt *sql.Stmt
|
deleteRoomAliasStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRoomAliasesTable(db *sql.DB) error {
|
func CreateRoomAliasesTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(roomAliasesSchema)
|
_, err := db.Exec(roomAliasesSchema)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||||
s := &roomAliasesStatements{}
|
s := &roomAliasesStatements{}
|
||||||
|
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
|
@ -108,8 +108,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
|
||||||
|
|
||||||
var aliases []string
|
var aliases []string
|
||||||
for rows.Next() {
|
|
||||||
var alias string
|
var alias string
|
||||||
|
for rows.Next() {
|
||||||
if err = rows.Scan(&alias); err != nil {
|
if err = rows.Scan(&alias); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -95,12 +95,12 @@ type roomStatements struct {
|
||||||
bulkSelectRoomNIDsStmt *sql.Stmt
|
bulkSelectRoomNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRoomsTable(db *sql.DB) error {
|
func CreateRoomsTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(roomsSchema)
|
_, err := db.Exec(roomsSchema)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
s := &roomStatements{}
|
s := &roomStatements{}
|
||||||
|
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
|
@ -117,7 +117,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
||||||
rows, err := stmt.QueryContext(ctx)
|
rows, err := stmt.QueryContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -125,8 +125,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||||
var roomIDs []string
|
var roomIDs []string
|
||||||
for rows.Next() {
|
|
||||||
var roomID string
|
var roomID string
|
||||||
|
for rows.Next() {
|
||||||
if err = rows.Scan(&roomID); err != nil {
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -231,9 +231,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
|
||||||
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
|
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
|
||||||
for rows.Next() {
|
|
||||||
var roomNID types.RoomNID
|
var roomNID types.RoomNID
|
||||||
var roomVersion gomatrixserverlib.RoomVersion
|
var roomVersion gomatrixserverlib.RoomVersion
|
||||||
|
for rows.Next() {
|
||||||
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
|
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -254,8 +254,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||||
var roomIDs []string
|
var roomIDs []string
|
||||||
for rows.Next() {
|
|
||||||
var roomID string
|
var roomID string
|
||||||
|
for rows.Next() {
|
||||||
if err = rows.Scan(&roomID); err != nil {
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -276,8 +276,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
||||||
var roomNIDs []types.RoomNID
|
var roomNIDs []types.RoomNID
|
||||||
for rows.Next() {
|
|
||||||
var roomNID types.RoomNID
|
var roomNID types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
if err = rows.Scan(&roomNID); err != nil {
|
if err = rows.Scan(&roomNID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
@ -71,12 +70,12 @@ type stateBlockStatements struct {
|
||||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createStateBlockTable(db *sql.DB) error {
|
func CreateStateBlockTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(stateDataSchema)
|
_, err := db.Exec(stateDataSchema)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||||
s := &stateBlockStatements{}
|
s := &stateBlockStatements{}
|
||||||
|
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
|
@ -90,9 +89,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
||||||
entries types.StateEntries,
|
entries types.StateEntries,
|
||||||
) (id types.StateBlockNID, err error) {
|
) (id types.StateBlockNID, err error) {
|
||||||
entries = entries[:util.SortAndUnique(entries)]
|
entries = entries[:util.SortAndUnique(entries)]
|
||||||
var nids types.EventNIDs
|
nids := make(types.EventNIDs, entries.Len())
|
||||||
for _, e := range entries {
|
for i := range entries {
|
||||||
nids = append(nids, e.EventNID)
|
nids[i] = entries[i].EventNID
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
|
||||||
err = stmt.QueryRowContext(
|
err = stmt.QueryRowContext(
|
||||||
|
@ -113,15 +112,15 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||||
|
|
||||||
results := make([][]types.EventNID, len(stateBlockNIDs))
|
results := make([][]types.EventNID, len(stateBlockNIDs))
|
||||||
i := 0
|
i := 0
|
||||||
for ; rows.Next(); i++ {
|
|
||||||
var stateBlockNID types.StateBlockNID
|
var stateBlockNID types.StateBlockNID
|
||||||
var result pq.Int64Array
|
var result pq.Int64Array
|
||||||
|
for ; rows.Next(); i++ {
|
||||||
if err = rows.Scan(&stateBlockNID, &result); err != nil {
|
if err = rows.Scan(&stateBlockNID, &result); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
r := []types.EventNID{}
|
r := make([]types.EventNID, len(result))
|
||||||
for _, e := range result {
|
for x := range result {
|
||||||
r = append(r, types.EventNID(e))
|
r[x] = types.EventNID(result[x])
|
||||||
}
|
}
|
||||||
results[i] = r
|
results[i] = r
|
||||||
}
|
}
|
||||||
|
@ -141,35 +140,3 @@ func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
|
||||||
}
|
}
|
||||||
return pq.Int64Array(nids)
|
return pq.Int64Array(nids)
|
||||||
}
|
}
|
||||||
|
|
||||||
type stateKeyTupleSorter []types.StateKeyTuple
|
|
||||||
|
|
||||||
func (s stateKeyTupleSorter) Len() int { return len(s) }
|
|
||||||
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
|
||||||
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
||||||
|
|
||||||
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
|
||||||
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
|
|
||||||
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
|
||||||
return i < len(s) && s[i] == value
|
|
||||||
}
|
|
||||||
|
|
||||||
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
|
||||||
// Assumes that the list is sorted.
|
|
||||||
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
|
|
||||||
eventTypeNIDs = make(pq.Int64Array, len(s))
|
|
||||||
eventStateKeyNIDs = make(pq.Int64Array, len(s))
|
|
||||||
for i := range s {
|
|
||||||
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
|
||||||
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
|
||||||
}
|
|
||||||
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
|
||||||
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type int64Sorter []int64
|
|
||||||
|
|
||||||
func (s int64Sorter) Len() int { return len(s) }
|
|
||||||
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
|
||||||
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
||||||
|
|
|
@ -1,86 +0,0 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package postgres
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sort"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStateKeyTupleSorter(t *testing.T) {
|
|
||||||
input := stateKeyTupleSorter{
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
|
||||||
}
|
|
||||||
want := []types.StateKeyTuple{
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
|
||||||
}
|
|
||||||
doNotWant := []types.StateKeyTuple{
|
|
||||||
{EventTypeNID: 0, EventStateKeyNID: 0},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 3},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 1},
|
|
||||||
{EventTypeNID: 3, EventStateKeyNID: 1},
|
|
||||||
}
|
|
||||||
wantTypeNIDs := []int64{1, 2}
|
|
||||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
|
||||||
|
|
||||||
// Sort the input and check it's in the right order.
|
|
||||||
sort.Sort(input)
|
|
||||||
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
|
|
||||||
|
|
||||||
for i := range want {
|
|
||||||
if input[i] != want[i] {
|
|
||||||
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
if !input.contains(want[i]) {
|
|
||||||
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range doNotWant {
|
|
||||||
if input.contains(doNotWant[i]) {
|
|
||||||
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range wantTypeNIDs {
|
|
||||||
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
|
||||||
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range wantStateKeyNIDs {
|
|
||||||
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -77,12 +77,12 @@ type stateSnapshotStatements struct {
|
||||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createStateSnapshotTable(db *sql.DB) error {
|
func CreateStateSnapshotTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(stateSnapshotSchema)
|
_, err := db.Exec(stateSnapshotSchema)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||||
s := &stateSnapshotStatements{}
|
s := &stateSnapshotStatements{}
|
||||||
|
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
|
@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState(
|
||||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
|
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
|
||||||
) (stateNID types.StateSnapshotNID, err error) {
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
nids = nids[:util.SortAndUnique(nids)]
|
nids = nids[:util.SortAndUnique(nids)]
|
||||||
var id int64
|
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID)
|
||||||
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
stateNID = types.StateSnapshotNID(id)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||||
defer rows.Close() // nolint: errcheck
|
defer rows.Close() // nolint: errcheck
|
||||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||||
i := 0
|
i := 0
|
||||||
|
var stateBlockNIDs pq.Int64Array
|
||||||
for ; rows.Next(); i++ {
|
for ; rows.Next(); i++ {
|
||||||
result := &results[i]
|
result := &results[i]
|
||||||
var stateBlockNIDs pq.Int64Array
|
|
||||||
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
|
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,19 +80,19 @@ func (d *Database) create(db *sql.DB) error {
|
||||||
if err := CreateEventsTable(db); err != nil {
|
if err := CreateEventsTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createRoomsTable(db); err != nil {
|
if err := CreateRoomsTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createStateBlockTable(db); err != nil {
|
if err := CreateStateBlockTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createStateSnapshotTable(db); err != nil {
|
if err := CreateStateSnapshotTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := CreatePrevEventsTable(db); err != nil {
|
if err := CreatePrevEventsTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createRoomAliasesTable(db); err != nil {
|
if err := CreateRoomAliasesTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := CreateInvitesTable(db); err != nil {
|
if err := CreateInvitesTable(db); err != nil {
|
||||||
|
@ -128,15 +128,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
rooms, err := prepareRoomsTable(db)
|
rooms, err := PrepareRoomsTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stateBlock, err := prepareStateBlockTable(db)
|
stateBlock, err := PrepareStateBlockTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stateSnapshot, err := prepareStateSnapshotTable(db)
|
stateSnapshot, err := PrepareStateSnapshotTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -144,7 +144,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
roomAliases, err := prepareRoomAliasesTable(db)
|
roomAliases, err := PrepareRoomAliasesTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1216,7 +1216,7 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
|
||||||
|
|
||||||
// GetKnownRooms returns a list of all rooms we know about.
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||||
return d.RoomsTable.SelectRoomIDs(ctx, nil)
|
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForgetRoom sets a users room to forgotten
|
// ForgetRoom sets a users room to forgotten
|
||||||
|
|
|
@ -247,9 +247,9 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
||||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
tuples := types.StateKeyTupleSorter(stateKeyTuples)
|
||||||
sort.Sort(tuples)
|
sort.Sort(tuples)
|
||||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
|
||||||
params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray))
|
params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray))
|
||||||
selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
|
||||||
for _, v := range eventNIDs {
|
for _, v := range eventNIDs {
|
||||||
|
|
|
@ -63,12 +63,12 @@ type roomAliasesStatements struct {
|
||||||
deleteRoomAliasStmt *sql.Stmt
|
deleteRoomAliasStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRoomAliasesTable(db *sql.DB) error {
|
func CreateRoomAliasesTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(roomAliasesSchema)
|
_, err := db.Exec(roomAliasesSchema)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||||
s := &roomAliasesStatements{
|
s := &roomAliasesStatements{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
|
@ -113,8 +113,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
||||||
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
var alias string
|
var alias string
|
||||||
|
for rows.Next() {
|
||||||
if err = rows.Scan(&alias); err != nil {
|
if err = rows.Scan(&alias); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,12 +86,12 @@ type roomStatements struct {
|
||||||
selectRoomIDsStmt *sql.Stmt
|
selectRoomIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRoomsTable(db *sql.DB) error {
|
func CreateRoomsTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(roomsSchema)
|
_, err := db.Exec(roomsSchema)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
s := &roomStatements{
|
s := &roomStatements{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
|
@ -108,7 +108,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
||||||
rows, err := stmt.QueryContext(ctx)
|
rows, err := stmt.QueryContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -116,8 +116,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||||
var roomIDs []string
|
var roomIDs []string
|
||||||
for rows.Next() {
|
|
||||||
var roomID string
|
var roomID string
|
||||||
|
for rows.Next() {
|
||||||
if err = rows.Scan(&roomID); err != nil {
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -241,9 +241,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
|
||||||
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
|
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
|
||||||
for rows.Next() {
|
|
||||||
var roomNID types.RoomNID
|
var roomNID types.RoomNID
|
||||||
var roomVersion gomatrixserverlib.RoomVersion
|
var roomVersion gomatrixserverlib.RoomVersion
|
||||||
|
for rows.Next() {
|
||||||
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
|
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -270,8 +270,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||||
var roomIDs []string
|
var roomIDs []string
|
||||||
for rows.Next() {
|
|
||||||
var roomID string
|
var roomID string
|
||||||
|
for rows.Next() {
|
||||||
if err = rows.Scan(&roomID); err != nil {
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -298,8 +298,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
||||||
var roomNIDs []types.RoomNID
|
var roomNIDs []types.RoomNID
|
||||||
for rows.Next() {
|
|
||||||
var roomNID types.RoomNID
|
var roomNID types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
if err = rows.Scan(&roomNID); err != nil {
|
if err = rows.Scan(&roomNID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
@ -64,12 +63,12 @@ type stateBlockStatements struct {
|
||||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createStateBlockTable(db *sql.DB) error {
|
func CreateStateBlockTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(stateDataSchema)
|
_, err := db.Exec(stateDataSchema)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||||
s := &stateBlockStatements{
|
s := &stateBlockStatements{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
|
@ -85,9 +84,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
||||||
entries types.StateEntries,
|
entries types.StateEntries,
|
||||||
) (id types.StateBlockNID, err error) {
|
) (id types.StateBlockNID, err error) {
|
||||||
entries = entries[:util.SortAndUnique(entries)]
|
entries = entries[:util.SortAndUnique(entries)]
|
||||||
nids := types.EventNIDs{} // zero slice to not store 'null' in the DB
|
nids := make(types.EventNIDs, entries.Len())
|
||||||
for _, e := range entries {
|
for i := range entries {
|
||||||
nids = append(nids, e.EventNID)
|
nids[i] = entries[i].EventNID
|
||||||
}
|
}
|
||||||
js, err := json.Marshal(nids)
|
js, err := json.Marshal(nids)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -122,13 +121,13 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||||
|
|
||||||
results := make([][]types.EventNID, len(stateBlockNIDs))
|
results := make([][]types.EventNID, len(stateBlockNIDs))
|
||||||
i := 0
|
i := 0
|
||||||
for ; rows.Next(); i++ {
|
|
||||||
var stateBlockNID types.StateBlockNID
|
var stateBlockNID types.StateBlockNID
|
||||||
var result json.RawMessage
|
var result json.RawMessage
|
||||||
|
for ; rows.Next(); i++ {
|
||||||
if err = rows.Scan(&stateBlockNID, &result); err != nil {
|
if err = rows.Scan(&stateBlockNID, &result); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
r := []types.EventNID{}
|
var r []types.EventNID
|
||||||
if err = json.Unmarshal(result, &r); err != nil {
|
if err = json.Unmarshal(result, &r); err != nil {
|
||||||
return nil, fmt.Errorf("json.Unmarshal: %w", err)
|
return nil, fmt.Errorf("json.Unmarshal: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -142,35 +141,3 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||||
}
|
}
|
||||||
return results, err
|
return results, err
|
||||||
}
|
}
|
||||||
|
|
||||||
type stateKeyTupleSorter []types.StateKeyTuple
|
|
||||||
|
|
||||||
func (s stateKeyTupleSorter) Len() int { return len(s) }
|
|
||||||
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
|
||||||
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
||||||
|
|
||||||
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
|
||||||
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
|
|
||||||
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
|
||||||
return i < len(s) && s[i] == value
|
|
||||||
}
|
|
||||||
|
|
||||||
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
|
||||||
// Assumes that the list is sorted.
|
|
||||||
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
|
|
||||||
eventTypeNIDs = make([]int64, len(s))
|
|
||||||
eventStateKeyNIDs = make([]int64, len(s))
|
|
||||||
for i := range s {
|
|
||||||
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
|
||||||
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
|
||||||
}
|
|
||||||
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
|
||||||
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type int64Sorter []int64
|
|
||||||
|
|
||||||
func (s int64Sorter) Len() int { return len(s) }
|
|
||||||
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
|
||||||
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
||||||
|
|
|
@ -1,86 +0,0 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sort"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStateKeyTupleSorter(t *testing.T) {
|
|
||||||
input := stateKeyTupleSorter{
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
|
||||||
}
|
|
||||||
want := []types.StateKeyTuple{
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
|
||||||
}
|
|
||||||
doNotWant := []types.StateKeyTuple{
|
|
||||||
{EventTypeNID: 0, EventStateKeyNID: 0},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 3},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 1},
|
|
||||||
{EventTypeNID: 3, EventStateKeyNID: 1},
|
|
||||||
}
|
|
||||||
wantTypeNIDs := []int64{1, 2}
|
|
||||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
|
||||||
|
|
||||||
// Sort the input and check it's in the right order.
|
|
||||||
sort.Sort(input)
|
|
||||||
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
|
|
||||||
|
|
||||||
for i := range want {
|
|
||||||
if input[i] != want[i] {
|
|
||||||
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
if !input.contains(want[i]) {
|
|
||||||
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range doNotWant {
|
|
||||||
if input.contains(doNotWant[i]) {
|
|
||||||
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range wantTypeNIDs {
|
|
||||||
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
|
||||||
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range wantStateKeyNIDs {
|
|
||||||
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -68,12 +68,12 @@ type stateSnapshotStatements struct {
|
||||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createStateSnapshotTable(db *sql.DB) error {
|
func CreateStateSnapshotTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(stateSnapshotSchema)
|
_, err := db.Exec(stateSnapshotSchema)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||||
s := &stateSnapshotStatements{
|
s := &stateSnapshotStatements{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
|
@ -96,12 +96,10 @@ func (s *stateSnapshotStatements) InsertState(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||||
var id int64
|
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&stateNID)
|
||||||
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
stateNID = types.StateSnapshotNID(id)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,9 +125,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
|
||||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||||
i := 0
|
i := 0
|
||||||
|
var stateBlockNIDsJSON string
|
||||||
for ; rows.Next(); i++ {
|
for ; rows.Next(); i++ {
|
||||||
result := &results[i]
|
result := &results[i]
|
||||||
var stateBlockNIDsJSON string
|
|
||||||
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
|
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,19 +89,19 @@ func (d *Database) create(db *sql.DB) error {
|
||||||
if err := CreateEventsTable(db); err != nil {
|
if err := CreateEventsTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createRoomsTable(db); err != nil {
|
if err := CreateRoomsTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createStateBlockTable(db); err != nil {
|
if err := CreateStateBlockTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createStateSnapshotTable(db); err != nil {
|
if err := CreateStateSnapshotTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := CreatePrevEventsTable(db); err != nil {
|
if err := CreatePrevEventsTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createRoomAliasesTable(db); err != nil {
|
if err := CreateRoomAliasesTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := CreateInvitesTable(db); err != nil {
|
if err := CreateInvitesTable(db); err != nil {
|
||||||
|
@ -137,15 +137,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
rooms, err := prepareRoomsTable(db)
|
rooms, err := PrepareRoomsTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stateBlock, err := prepareStateBlockTable(db)
|
stateBlock, err := PrepareStateBlockTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stateSnapshot, err := prepareStateSnapshotTable(db)
|
stateSnapshot, err := PrepareStateSnapshotTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -153,7 +153,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
roomAliases, err := prepareRoomAliasesTable(db)
|
roomAliases, err := PrepareRoomAliasesTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,7 +72,7 @@ type Rooms interface {
|
||||||
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
||||||
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
|
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
|
||||||
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
|
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
|
||||||
SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
|
SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error)
|
||||||
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
|
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
|
||||||
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
|
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
|
||||||
}
|
}
|
||||||
|
|
96
roomserver/storage/tables/room_aliases_table_test.go
Normal file
96
roomserver/storage/tables/room_aliases_table_test.go
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateRoomAliasesTable(t *testing.T, dbType test.DBType) (tab tables.RoomAliases, close func()) {
|
||||||
|
t.Helper()
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
err = postgres.CreateRoomAliasesTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tab, err = postgres.PrepareRoomAliasesTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
err = sqlite3.CreateRoomAliasesTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tab, err = sqlite3.PrepareRoomAliasesTable(db)
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomAliasesTable(t *testing.T) {
|
||||||
|
alice := test.NewUser()
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
room2 := test.NewRoom(t, alice)
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, close := mustCreateRoomAliasesTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
alias, alias2, alias3 := "#alias:localhost", "#alias2:localhost", "#alias3:localhost"
|
||||||
|
// insert aliases
|
||||||
|
err := tab.InsertRoomAlias(ctx, nil, alias, room.ID, alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = tab.InsertRoomAlias(ctx, nil, alias2, room.ID, alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = tab.InsertRoomAlias(ctx, nil, alias3, room2.ID, alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// verify we can get the roomID for the alias
|
||||||
|
roomID, err := tab.SelectRoomIDFromAlias(ctx, nil, alias)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, room.ID, roomID)
|
||||||
|
|
||||||
|
// .. and the creator
|
||||||
|
creator, err := tab.SelectCreatorIDFromAlias(ctx, nil, alias)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, alice.ID, creator)
|
||||||
|
|
||||||
|
creator, err = tab.SelectCreatorIDFromAlias(ctx, nil, "#doesntexist:localhost")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "", creator)
|
||||||
|
|
||||||
|
roomID, err = tab.SelectRoomIDFromAlias(ctx, nil, "#doesntexist:localhost")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "", roomID)
|
||||||
|
|
||||||
|
// get all aliases for a room
|
||||||
|
aliases, err := tab.SelectAliasesFromRoomID(ctx, nil, room.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{alias, alias2}, aliases)
|
||||||
|
|
||||||
|
// delete an alias and verify it's deleted
|
||||||
|
err = tab.DeleteRoomAlias(ctx, nil, alias2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
aliases, err = tab.SelectAliasesFromRoomID(ctx, nil, room.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{alias}, aliases)
|
||||||
|
|
||||||
|
// deleting the same alias should be a no-op
|
||||||
|
err = tab.DeleteRoomAlias(ctx, nil, alias2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Delete non-existent alias should be a no-op
|
||||||
|
err = tab.DeleteRoomAlias(ctx, nil, "#doesntexist:localhost")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
128
roomserver/storage/tables/rooms_table_test.go
Normal file
128
roomserver/storage/tables/rooms_table_test.go
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateRoomsTable(t *testing.T, dbType test.DBType) (tab tables.Rooms, close func()) {
|
||||||
|
t.Helper()
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
err = postgres.CreateRoomsTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tab, err = postgres.PrepareRoomsTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
err = sqlite3.CreateRoomsTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tab, err = sqlite3.PrepareRoomsTable(db)
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomsTable(t *testing.T) {
|
||||||
|
alice := test.NewUser()
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, close := mustCreateRoomsTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
wantRoomNID, err := tab.InsertRoomNID(ctx, nil, room.ID, room.Version)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create dummy room
|
||||||
|
_, err = tab.InsertRoomNID(ctx, nil, util.RandomString(16), room.Version)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
gotRoomNID, err := tab.SelectRoomNID(ctx, nil, room.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, wantRoomNID, gotRoomNID)
|
||||||
|
|
||||||
|
// Ensure non existent roomNID errors
|
||||||
|
roomNID, err := tab.SelectRoomNID(ctx, nil, "!doesnotexist:localhost")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, types.RoomNID(0), roomNID)
|
||||||
|
|
||||||
|
roomInfo, err := tab.SelectRoomInfo(ctx, nil, room.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &types.RoomInfo{
|
||||||
|
RoomNID: wantRoomNID,
|
||||||
|
RoomVersion: room.Version,
|
||||||
|
StateSnapshotNID: 0,
|
||||||
|
IsStub: true, // there are no latestEventNIDs
|
||||||
|
}, roomInfo)
|
||||||
|
|
||||||
|
roomInfo, err = tab.SelectRoomInfo(ctx, nil, "!doesnotexist:localhost")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, roomInfo)
|
||||||
|
|
||||||
|
// There are no rooms with latestEventNIDs yet
|
||||||
|
roomIDs, err := tab.SelectRoomIDsWithEvents(ctx, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, len(roomIDs))
|
||||||
|
|
||||||
|
roomVersions, err := tab.SelectRoomVersionsForRoomNIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, roomVersions[wantRoomNID], room.Version)
|
||||||
|
// Room does not exist
|
||||||
|
_, ok := roomVersions[1337]
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
roomIDs, err = tab.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{room.ID}, roomIDs)
|
||||||
|
|
||||||
|
roomNIDs, err := tab.BulkSelectRoomNIDs(ctx, nil, []string{room.ID, "!doesnotexist:localhost"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []types.RoomNID{wantRoomNID}, roomNIDs)
|
||||||
|
|
||||||
|
wantEventNIDs := []types.EventNID{1, 2, 3}
|
||||||
|
lastEventSentNID := types.EventNID(3)
|
||||||
|
stateSnapshotNID := types.StateSnapshotNID(1)
|
||||||
|
// make the room "usable"
|
||||||
|
err = tab.UpdateLatestEventNIDs(ctx, nil, wantRoomNID, wantEventNIDs, lastEventSentNID, stateSnapshotNID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
roomInfo, err = tab.SelectRoomInfo(ctx, nil, room.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &types.RoomInfo{
|
||||||
|
RoomNID: wantRoomNID,
|
||||||
|
RoomVersion: room.Version,
|
||||||
|
StateSnapshotNID: 1,
|
||||||
|
IsStub: false,
|
||||||
|
}, roomInfo)
|
||||||
|
|
||||||
|
eventNIDs, snapshotNID, err := tab.SelectLatestEventNIDs(ctx, nil, wantRoomNID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, wantEventNIDs, eventNIDs)
|
||||||
|
assert.Equal(t, types.StateSnapshotNID(1), snapshotNID)
|
||||||
|
|
||||||
|
// Again, doesn't exist
|
||||||
|
_, _, err = tab.SelectLatestEventNIDs(ctx, nil, 1337)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
eventNIDs, eventNID, snapshotNID, err := tab.SelectLatestEventsNIDsForUpdate(ctx, nil, wantRoomNID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, wantEventNIDs, eventNIDs)
|
||||||
|
assert.Equal(t, types.EventNID(3), eventNID)
|
||||||
|
assert.Equal(t, types.StateSnapshotNID(1), snapshotNID)
|
||||||
|
})
|
||||||
|
}
|
92
roomserver/storage/tables/state_block_table_test.go
Normal file
92
roomserver/storage/tables/state_block_table_test.go
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateStateBlockTable(t *testing.T, dbType test.DBType) (tab tables.StateBlock, close func()) {
|
||||||
|
t.Helper()
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
err = postgres.CreateStateBlockTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tab, err = postgres.PrepareStateBlockTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
err = sqlite3.CreateStateBlockTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tab, err = sqlite3.PrepareStateBlockTable(db)
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateBlockTable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, close := mustCreateStateBlockTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
// generate some dummy data
|
||||||
|
var entries types.StateEntries
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
entry := types.StateEntry{
|
||||||
|
EventNID: types.EventNID(i),
|
||||||
|
}
|
||||||
|
entries = append(entries, entry)
|
||||||
|
}
|
||||||
|
stateBlockNID, err := tab.BulkInsertStateData(ctx, nil, entries)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, types.StateBlockNID(1), stateBlockNID)
|
||||||
|
|
||||||
|
// generate a different hash, to get a new StateBlockNID
|
||||||
|
var entries2 types.StateEntries
|
||||||
|
for i := 100; i < 300; i++ {
|
||||||
|
entry := types.StateEntry{
|
||||||
|
EventNID: types.EventNID(i),
|
||||||
|
}
|
||||||
|
entries2 = append(entries2, entry)
|
||||||
|
}
|
||||||
|
stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, types.StateBlockNID(2), stateBlockNID)
|
||||||
|
|
||||||
|
eventNIDs, err := tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 2})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, len(entries), len(eventNIDs[0]))
|
||||||
|
assert.Equal(t, len(entries2), len(eventNIDs[1]))
|
||||||
|
|
||||||
|
// try to get a StateBlockNID which does not exist
|
||||||
|
_, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{5})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// This should return an error, since we can only retrieve 1 StateBlock
|
||||||
|
_, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 5})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
for i := 0; i < 65555; i++ {
|
||||||
|
entry := types.StateEntry{
|
||||||
|
EventNID: types.EventNID(i),
|
||||||
|
}
|
||||||
|
entries2 = append(entries2, entry)
|
||||||
|
}
|
||||||
|
stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, types.StateBlockNID(3), stateBlockNID)
|
||||||
|
})
|
||||||
|
}
|
86
roomserver/storage/tables/state_snapshot_table_test.go
Normal file
86
roomserver/storage/tables/state_snapshot_table_test.go
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.StateSnapshot, close func()) {
|
||||||
|
t.Helper()
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
err = postgres.CreateStateSnapshotTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tab, err = postgres.PrepareStateSnapshotTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
err = sqlite3.CreateStateSnapshotTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tab, err = sqlite3.PrepareStateSnapshotTable(db)
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateSnapshotTable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, close := mustCreateStateSnapshotTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
// generate some dummy data
|
||||||
|
var stateBlockNIDs types.StateBlockNIDs
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i))
|
||||||
|
}
|
||||||
|
stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
|
||||||
|
|
||||||
|
// verify ON CONFLICT; Note: this updates the sequence!
|
||||||
|
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
|
||||||
|
|
||||||
|
// create a second snapshot
|
||||||
|
var stateBlockNIDs2 types.StateBlockNIDs
|
||||||
|
for i := 100; i < 150; i++ {
|
||||||
|
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence
|
||||||
|
assert.Equal(t, types.StateSnapshotNID(3), stateNID)
|
||||||
|
|
||||||
|
nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs))
|
||||||
|
assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs))
|
||||||
|
|
||||||
|
// check we get an error if the state snapshot does not exist
|
||||||
|
_, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// create a second snapshot
|
||||||
|
for i := 0; i < 65555; i++ {
|
||||||
|
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
|
||||||
|
}
|
||||||
|
_, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
"golang.org/x/crypto/blake2b"
|
"golang.org/x/crypto/blake2b"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -97,6 +98,38 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
|
||||||
return a.EventStateKeyNID < b.EventStateKeyNID
|
return a.EventStateKeyNID < b.EventStateKeyNID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StateKeyTupleSorter []StateKeyTuple
|
||||||
|
|
||||||
|
func (s StateKeyTupleSorter) Len() int { return len(s) }
|
||||||
|
func (s StateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
||||||
|
func (s StateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
||||||
|
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
||||||
|
func (s StateKeyTupleSorter) contains(value StateKeyTuple) bool {
|
||||||
|
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
||||||
|
return i < len(s) && s[i] == value
|
||||||
|
}
|
||||||
|
|
||||||
|
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
||||||
|
// Assumes that the list is sorted.
|
||||||
|
func (s StateKeyTupleSorter) TypesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
|
||||||
|
eventTypeNIDs = make([]int64, len(s))
|
||||||
|
eventStateKeyNIDs = make([]int64, len(s))
|
||||||
|
for i := range s {
|
||||||
|
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
||||||
|
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
||||||
|
}
|
||||||
|
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
||||||
|
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type int64Sorter []int64
|
||||||
|
|
||||||
|
func (s int64Sorter) Len() int { return len(s) }
|
||||||
|
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||||
|
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
||||||
// A StateEntry is an entry in the room state of a matrix room.
|
// A StateEntry is an entry in the room state of a matrix room.
|
||||||
type StateEntry struct {
|
type StateEntry struct {
|
||||||
StateKeyTuple
|
StateKeyTuple
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -24,3 +25,66 @@ func TestDeduplicateStateEntries(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStateKeyTupleSorter(t *testing.T) {
|
||||||
|
input := StateKeyTupleSorter{
|
||||||
|
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||||
|
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||||
|
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||||
|
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||||
|
}
|
||||||
|
want := []StateKeyTuple{
|
||||||
|
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||||
|
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||||
|
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||||
|
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||||
|
}
|
||||||
|
doNotWant := []StateKeyTuple{
|
||||||
|
{EventTypeNID: 0, EventStateKeyNID: 0},
|
||||||
|
{EventTypeNID: 1, EventStateKeyNID: 3},
|
||||||
|
{EventTypeNID: 2, EventStateKeyNID: 1},
|
||||||
|
{EventTypeNID: 3, EventStateKeyNID: 1},
|
||||||
|
}
|
||||||
|
wantTypeNIDs := []int64{1, 2}
|
||||||
|
wantStateKeyNIDs := []int64{1, 2, 4}
|
||||||
|
|
||||||
|
// Sort the input and check it's in the right order.
|
||||||
|
sort.Sort(input)
|
||||||
|
gotTypeNIDs, gotStateKeyNIDs := input.TypesAndStateKeysAsArrays()
|
||||||
|
|
||||||
|
for i := range want {
|
||||||
|
if input[i] != want[i] {
|
||||||
|
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
if !input.contains(want[i]) {
|
||||||
|
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range doNotWant {
|
||||||
|
if input.contains(doNotWant[i]) {
|
||||||
|
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
||||||
|
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range wantTypeNIDs {
|
||||||
|
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
||||||
|
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
||||||
|
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range wantStateKeyNIDs {
|
||||||
|
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
||||||
|
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue