diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 86d226ce7..1302bbf27 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -264,9 +264,9 @@ func (s *eventStatements) BulkSelectStateEventByNID( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) + tuples := types.StateKeyTupleSorter(stateKeyTuples) sort.Sort(tuples) - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays() stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt) rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) if err != nil { diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 6f8f9e1b5..73b04c92c 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -19,7 +19,6 @@ import ( "context" "database/sql" "fmt" - "sort" "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" @@ -141,35 +140,3 @@ func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { } return pq.Int64Array(nids) } - -type stateKeyTupleSorter []types.StateKeyTuple - -func (s stateKeyTupleSorter) Len() int { return len(s) } -func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } -func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -// Check whether a tuple is in the list. Assumes that the list is sorted. -func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { - i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) - return i < len(s) && s[i] == value -} - -// List the unique eventTypeNIDs and eventStateKeyNIDs. -// Assumes that the list is sorted. -func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) { - eventTypeNIDs = make(pq.Int64Array, len(s)) - eventStateKeyNIDs = make(pq.Int64Array, len(s)) - for i := range s { - eventTypeNIDs[i] = int64(s[i].EventTypeNID) - eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) - } - eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] - eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] - return -} - -type int64Sorter []int64 - -func (s int64Sorter) Len() int { return len(s) } -func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } -func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/postgres/state_block_table_test.go b/roomserver/storage/postgres/state_block_table_test.go deleted file mode 100644 index a0e2ec952..000000000 --- a/roomserver/storage/postgres/state_block_table_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "sort" - "testing" - - "github.com/matrix-org/dendrite/roomserver/types" -) - -func TestStateKeyTupleSorter(t *testing.T) { - input := stateKeyTupleSorter{ - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 1}, - } - want := []types.StateKeyTuple{ - {EventTypeNID: 1, EventStateKeyNID: 1}, - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - } - doNotWant := []types.StateKeyTuple{ - {EventTypeNID: 0, EventStateKeyNID: 0}, - {EventTypeNID: 1, EventStateKeyNID: 3}, - {EventTypeNID: 2, EventStateKeyNID: 1}, - {EventTypeNID: 3, EventStateKeyNID: 1}, - } - wantTypeNIDs := []int64{1, 2} - wantStateKeyNIDs := []int64{1, 2, 4} - - // Sort the input and check it's in the right order. - sort.Sort(input) - gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() - - for i := range want { - if input[i] != want[i] { - t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) - } - - if !input.contains(want[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) - } - } - - for i := range doNotWant { - if input.contains(doNotWant[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) - } - } - - if len(wantTypeNIDs) != len(gotTypeNIDs) { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - - for i := range wantTypeNIDs { - if wantTypeNIDs[i] != gotTypeNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } - - if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { - t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) - } - - for i := range wantStateKeyNIDs { - if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } -} diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index feb06150a..1dda34c36 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -247,9 +247,9 @@ func (s *eventStatements) BulkSelectStateEventByNID( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) + tuples := types.StateKeyTupleSorter(stateKeyTuples) sort.Sort(tuples) - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays() params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray)) selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) for _, v := range eventNIDs { diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 3c829cdcd..5f9b646bd 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "encoding/json" "fmt" - "sort" "strings" "github.com/matrix-org/dendrite/internal" @@ -142,35 +141,3 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( } return results, err } - -type stateKeyTupleSorter []types.StateKeyTuple - -func (s stateKeyTupleSorter) Len() int { return len(s) } -func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } -func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -// Check whether a tuple is in the list. Assumes that the list is sorted. -func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { - i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) - return i < len(s) && s[i] == value -} - -// List the unique eventTypeNIDs and eventStateKeyNIDs. -// Assumes that the list is sorted. -func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) { - eventTypeNIDs = make([]int64, len(s)) - eventStateKeyNIDs = make([]int64, len(s)) - for i := range s { - eventTypeNIDs[i] = int64(s[i].EventTypeNID) - eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) - } - eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] - eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] - return -} - -type int64Sorter []int64 - -func (s int64Sorter) Len() int { return len(s) } -func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } -func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/state_block_table_test.go b/roomserver/storage/sqlite3/state_block_table_test.go deleted file mode 100644 index 98439f5c0..000000000 --- a/roomserver/storage/sqlite3/state_block_table_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "sort" - "testing" - - "github.com/matrix-org/dendrite/roomserver/types" -) - -func TestStateKeyTupleSorter(t *testing.T) { - input := stateKeyTupleSorter{ - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 1}, - } - want := []types.StateKeyTuple{ - {EventTypeNID: 1, EventStateKeyNID: 1}, - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - } - doNotWant := []types.StateKeyTuple{ - {EventTypeNID: 0, EventStateKeyNID: 0}, - {EventTypeNID: 1, EventStateKeyNID: 3}, - {EventTypeNID: 2, EventStateKeyNID: 1}, - {EventTypeNID: 3, EventStateKeyNID: 1}, - } - wantTypeNIDs := []int64{1, 2} - wantStateKeyNIDs := []int64{1, 2, 4} - - // Sort the input and check it's in the right order. - sort.Sort(input) - gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() - - for i := range want { - if input[i] != want[i] { - t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) - } - - if !input.contains(want[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) - } - } - - for i := range doNotWant { - if input.contains(doNotWant[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) - } - } - - if len(wantTypeNIDs) != len(gotTypeNIDs) { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - - for i := range wantTypeNIDs { - if wantTypeNIDs[i] != gotTypeNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } - - if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { - t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) - } - - for i := range wantStateKeyNIDs { - if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } -} diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 65fbee04e..3f0c4d45e 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -19,7 +19,9 @@ import ( "encoding/json" "sort" + "github.com/lib/pq" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "golang.org/x/crypto/blake2b" ) @@ -96,6 +98,38 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool { return a.EventStateKeyNID < b.EventStateKeyNID } +type StateKeyTupleSorter []StateKeyTuple + +func (s StateKeyTupleSorter) Len() int { return len(s) } +func (s StateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s StateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Check whether a tuple is in the list. Assumes that the list is sorted. +func (s StateKeyTupleSorter) contains(value StateKeyTuple) bool { + i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) + return i < len(s) && s[i] == value +} + +// List the unique eventTypeNIDs and eventStateKeyNIDs. +// Assumes that the list is sorted. +func (s StateKeyTupleSorter) TypesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) { + eventTypeNIDs = make(pq.Int64Array, len(s)) + eventStateKeyNIDs = make(pq.Int64Array, len(s)) + for i := range s { + eventTypeNIDs[i] = int64(s[i].EventTypeNID) + eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) + } + eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] + eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] + return +} + +type int64Sorter []int64 + +func (s int64Sorter) Len() int { return len(s) } +func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } +func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + // A StateEntry is an entry in the room state of a matrix room. type StateEntry struct { StateKeyTuple diff --git a/roomserver/types/types_test.go b/roomserver/types/types_test.go index b1e84b821..a26b80f74 100644 --- a/roomserver/types/types_test.go +++ b/roomserver/types/types_test.go @@ -1,6 +1,7 @@ package types import ( + "sort" "testing" ) @@ -24,3 +25,66 @@ func TestDeduplicateStateEntries(t *testing.T) { } } } + +func TestStateKeyTupleSorter(t *testing.T) { + input := StateKeyTupleSorter{ + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 1}, + } + want := []StateKeyTuple{ + {EventTypeNID: 1, EventStateKeyNID: 1}, + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + } + doNotWant := []StateKeyTuple{ + {EventTypeNID: 0, EventStateKeyNID: 0}, + {EventTypeNID: 1, EventStateKeyNID: 3}, + {EventTypeNID: 2, EventStateKeyNID: 1}, + {EventTypeNID: 3, EventStateKeyNID: 1}, + } + wantTypeNIDs := []int64{1, 2} + wantStateKeyNIDs := []int64{1, 2, 4} + + // Sort the input and check it's in the right order. + sort.Sort(input) + gotTypeNIDs, gotStateKeyNIDs := input.TypesAndStateKeysAsArrays() + + for i := range want { + if input[i] != want[i] { + t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) + } + + if !input.contains(want[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) + } + } + + for i := range doNotWant { + if input.contains(doNotWant[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) + } + } + + if len(wantTypeNIDs) != len(gotTypeNIDs) { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + + for i := range wantTypeNIDs { + if wantTypeNIDs[i] != gotTypeNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } + + if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { + t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) + } + + for i := range wantStateKeyNIDs { + if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } +}