mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-22 14:21:55 -06:00
Include the requested current state alongside the latest events in the query API. (#30)
* Return the requested portions of current state in the query API * Use Unique from github.com/matrix-org/util * rewrite bulkSelectFilteredStateBlockEntries to use append for clarity * Add test for stateKeyTupleSorter * Replace current with a new StateEntryList rather than individually setting the fields
This commit is contained in:
parent
e82090e277
commit
e667f17e14
|
@ -16,9 +16,6 @@ type RoomEventDatabase interface {
|
||||||
// Returns an error if the there is an error talking to the database
|
// Returns an error if the there is an error talking to the database
|
||||||
// or if the event IDs aren't in the database.
|
// or if the event IDs aren't in the database.
|
||||||
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error)
|
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error)
|
||||||
// Lookup the numeric IDs for a list of string event state keys.
|
|
||||||
// Returns a map from string state key to numeric ID for the state key.
|
|
||||||
EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
|
||||||
// Lookup the Events for a list of numeric event IDs.
|
// Lookup the Events for a list of numeric event IDs.
|
||||||
// Returns a sorted list of events.
|
// Returns a sorted list of events.
|
||||||
Events(eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
|
|
|
@ -3,6 +3,7 @@ package query
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
@ -12,13 +13,17 @@ import (
|
||||||
|
|
||||||
// RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API.
|
// RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API.
|
||||||
type RoomserverQueryAPIDatabase interface {
|
type RoomserverQueryAPIDatabase interface {
|
||||||
|
state.RoomStateDatabase
|
||||||
// Lookup the numeric ID for the room.
|
// Lookup the numeric ID for the room.
|
||||||
// Returns 0 if the room doesn't exists.
|
// Returns 0 if the room doesn't exists.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
RoomNID(roomID string) (types.RoomNID, error)
|
RoomNID(roomID string) (types.RoomNID, error)
|
||||||
// Lookup event references for the latest events in the room.
|
// Lookup event references for the latest events in the room and the current state snapshot.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, error)
|
LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, error)
|
||||||
|
// Lookup the Events for a list of numeric event IDs.
|
||||||
|
// Returns a list of events sorted by numeric event ID.
|
||||||
|
Events(eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoomserverQueryAPI is an implementation of RoomserverQueryAPI
|
// RoomserverQueryAPI is an implementation of RoomserverQueryAPI
|
||||||
|
@ -40,9 +45,33 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
response.RoomExists = true
|
response.RoomExists = true
|
||||||
response.LatestEvents, err = r.DB.LatestEventIDs(roomNID)
|
var currentStateSnapshotNID types.StateSnapshotNID
|
||||||
// TODO: look up the current state.
|
response.LatestEvents, currentStateSnapshotNID, err = r.DB.LatestEventIDs(roomNID)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup the currrent state for the requested tuples.
|
||||||
|
stateEntries, err := state.LoadStateAtSnapshotForStringTuples(r.DB, currentStateSnapshotNID, request.StateToFetch)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
eventNIDs := make([]types.EventNID, len(stateEntries))
|
||||||
|
for i := range stateEntries {
|
||||||
|
eventNIDs[i] = stateEntries[i].EventNID
|
||||||
|
}
|
||||||
|
|
||||||
|
stateEvents, err := r.DB.Events(eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
response.StateEvents = make([]gomatrixserverlib.Event, len(stateEvents))
|
||||||
|
for i := range stateEvents {
|
||||||
|
response.StateEvents[i] = stateEvents[i].Event
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux.
|
// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux.
|
||||||
|
|
|
@ -365,7 +365,12 @@ func main() {
|
||||||
testRoomserver(input, want, func(q api.RoomserverQueryAPI) {
|
testRoomserver(input, want, func(q api.RoomserverQueryAPI) {
|
||||||
var response api.QueryLatestEventsAndStateResponse
|
var response api.QueryLatestEventsAndStateResponse
|
||||||
if err := q.QueryLatestEventsAndState(
|
if err := q.QueryLatestEventsAndState(
|
||||||
&api.QueryLatestEventsAndStateRequest{RoomID: "!HCXfdvrfksxuYnIFiJ:matrix.org"},
|
&api.QueryLatestEventsAndStateRequest{
|
||||||
|
RoomID: "!HCXfdvrfksxuYnIFiJ:matrix.org",
|
||||||
|
StateToFetch: []api.StateKeyTuple{
|
||||||
|
{"m.room.member", "@richvdh:matrix.org"},
|
||||||
|
},
|
||||||
|
},
|
||||||
&response,
|
&response,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -376,6 +381,9 @@ func main() {
|
||||||
if len(response.LatestEvents) != 1 || response.LatestEvents[0].EventID != "$1463671339126270PnVwC:matrix.org" {
|
if len(response.LatestEvents) != 1 || response.LatestEvents[0].EventID != "$1463671339126270PnVwC:matrix.org" {
|
||||||
panic(fmt.Errorf(`Wanted "$1463671339126270PnVwC:matrix.org" to be the latest event got %#v`, response.LatestEvents))
|
panic(fmt.Errorf(`Wanted "$1463671339126270PnVwC:matrix.org" to be the latest event got %#v`, response.LatestEvents))
|
||||||
}
|
}
|
||||||
|
if len(response.StateEvents) != 1 || response.StateEvents[0].EventID() != "$1463671339126270PnVwC:matrix.org" {
|
||||||
|
panic(fmt.Errorf(`Wanted "$1463671339126270PnVwC:matrix.org" to be the state event got %#v`, response.StateEvents))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
fmt.Println("==PASSED==", os.Args[0])
|
fmt.Println("==PASSED==", os.Args[0])
|
||||||
|
|
|
@ -4,6 +4,7 @@ package state
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"sort"
|
"sort"
|
||||||
|
@ -11,12 +12,25 @@ import (
|
||||||
|
|
||||||
// A RoomStateDatabase has the storage APIs needed to load state from the database
|
// A RoomStateDatabase has the storage APIs needed to load state from the database
|
||||||
type RoomStateDatabase interface {
|
type RoomStateDatabase interface {
|
||||||
|
// Lookup the numeric IDs for a list of string event types.
|
||||||
|
// Returns a map from string event type to numeric ID for the event type.
|
||||||
|
EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||||
|
// Lookup the numeric IDs for a list of string event state keys.
|
||||||
|
// Returns a map from string state key to numeric ID for the state key.
|
||||||
|
EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||||
// Lookup the numeric state data IDs for each numeric state snapshot ID
|
// Lookup the numeric state data IDs for each numeric state snapshot ID
|
||||||
// The returned slice is sorted by numeric state snapshot ID.
|
// The returned slice is sorted by numeric state snapshot ID.
|
||||||
StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||||
// Lookup the state data for each numeric state data ID
|
// Lookup the state data for each numeric state data ID
|
||||||
// The returned slice is sorted by numeric state data ID.
|
// The returned slice is sorted by numeric state data ID.
|
||||||
StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
||||||
|
// Lookup the state data for the state key tuples for each numeric state block ID
|
||||||
|
// This is used to fetch a subset of the room state at a snapshot.
|
||||||
|
// If a block doesn't contain any of the requested tuples then it can be discarded from the result.
|
||||||
|
// The returned slice is sorted by numeric state block ID.
|
||||||
|
StateEntriesForTuples(stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) (
|
||||||
|
[]types.StateEntryList, error,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadStateAtSnapshot loads the full state of a room at a particular snapshot.
|
// LoadStateAtSnapshot loads the full state of a room at a particular snapshot.
|
||||||
|
@ -27,6 +41,7 @@ func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
||||||
stateBlockNIDList := stateBlockNIDLists[0]
|
stateBlockNIDList := stateBlockNIDLists[0]
|
||||||
|
|
||||||
stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs)
|
stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs)
|
||||||
|
@ -35,7 +50,7 @@ func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID)
|
||||||
}
|
}
|
||||||
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
||||||
|
|
||||||
// Combined all the state entries for this snapshot.
|
// Combine all the state entries for this snapshot.
|
||||||
// The order of state block NIDs in the list tells us the order to combine them in.
|
// The order of state block NIDs in the list tells us the order to combine them in.
|
||||||
var fullState []types.StateEntry
|
var fullState []types.StateEntry
|
||||||
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
|
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
|
||||||
|
@ -98,7 +113,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
|
||||||
panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID))
|
panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Combined all the state entries for this snapshot.
|
// Combine all the state entries for this snapshot.
|
||||||
// The order of state block NIDs in the list tells us the order to combine them in.
|
// The order of state block NIDs in the list tells us the order to combine them in.
|
||||||
var fullState []types.StateEntry
|
var fullState []types.StateEntry
|
||||||
for _, stateBlockNID := range stateBlockNIDs {
|
for _, stateBlockNID := range stateBlockNIDs {
|
||||||
|
@ -182,6 +197,100 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs
|
||||||
|
// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []api.StateKeyTuple) ([]types.StateKeyTuple, error) {
|
||||||
|
eventTypes := make([]string, len(stringTuples))
|
||||||
|
stateKeys := make([]string, len(stringTuples))
|
||||||
|
for i := range stringTuples {
|
||||||
|
eventTypes[i] = stringTuples[i].EventType
|
||||||
|
stateKeys[i] = stringTuples[i].EventStateKey
|
||||||
|
}
|
||||||
|
eventTypes = util.UniqueStrings(eventTypes)
|
||||||
|
eventTypeMap, err := db.EventTypeNIDs(eventTypes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateKeys = util.UniqueStrings(stateKeys)
|
||||||
|
stateKeyMap, err := db.EventStateKeyNIDs(stateKeys)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []types.StateKeyTuple
|
||||||
|
for _, stringTuple := range stringTuples {
|
||||||
|
var numericTuple types.StateKeyTuple
|
||||||
|
var ok1, ok2 bool
|
||||||
|
numericTuple.EventTypeNID, ok1 = eventTypeMap[stringTuple.EventType]
|
||||||
|
numericTuple.EventStateKeyNID, ok2 = stateKeyMap[stringTuple.EventStateKey]
|
||||||
|
// Discard the tuple if there wasn't a numeric ID for either the event type or the state key.
|
||||||
|
if ok1 && ok2 {
|
||||||
|
result = append(result, numericTuple)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadStateAtSnapshotForStringTuples loads the state for a list of event type and state key pairs at a snapshot.
|
||||||
|
// This is used when we only want to load a subset of the room state at a snapshot.
|
||||||
|
// If there is no entry for a given event type and state key pair then it will be discarded.
|
||||||
|
// This is typically the state before an event or the current state of a room.
|
||||||
|
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||||
|
func LoadStateAtSnapshotForStringTuples(
|
||||||
|
db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []api.StateKeyTuple,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return loadStateAtSnapshotForNumericTuples(db, stateNID, numericTuples)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadStateAtSnapshotForNumericTuples loads the state for a list of event type and state key pairs at a snapshot.
|
||||||
|
// This is used when we only want to load a subset of the room state at a snapshot.
|
||||||
|
// If there is no entry for a given event type and state key pair then it will be discarded.
|
||||||
|
// This is typically the state before an event or the current state of a room.
|
||||||
|
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||||
|
func loadStateAtSnapshotForNumericTuples(
|
||||||
|
db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
||||||
|
stateBlockNIDList := stateBlockNIDLists[0]
|
||||||
|
|
||||||
|
stateEntryLists, err := db.StateEntriesForTuples(stateBlockNIDList.StateBlockNIDs, stateKeyTuples)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
||||||
|
|
||||||
|
// Combine all the state entries for this snapshot.
|
||||||
|
// The order of state block NIDs in the list tells us the order to combine them in.
|
||||||
|
var fullState []types.StateEntry
|
||||||
|
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
|
||||||
|
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
||||||
|
if !ok {
|
||||||
|
// If the block is missing from the map it means that none of its entries matched a requested tuple.
|
||||||
|
// This can happen if the block doesn't contain an update for one of the requested tuples.
|
||||||
|
// If none of the requested tuples are in the block then it can be safely skipped.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fullState = append(fullState, entries...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stable sort so that the most recent entry for each state key stays
|
||||||
|
// remains later in the list than the older entries for the same state key.
|
||||||
|
sort.Stable(stateEntryByStateKeySorter(fullState))
|
||||||
|
// Unique returns the last entry and hence the most recent entry for each state key.
|
||||||
|
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
|
||||||
|
return fullState, nil
|
||||||
|
}
|
||||||
|
|
||||||
type stateBlockNIDListMap []types.StateBlockNIDList
|
type stateBlockNIDListMap []types.StateBlockNIDList
|
||||||
|
|
||||||
func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) {
|
func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -66,9 +67,16 @@ const insertEventTypeNIDSQL = "" +
|
||||||
const selectEventTypeNIDSQL = "" +
|
const selectEventTypeNIDSQL = "" +
|
||||||
"SELECT event_type_nid FROM event_types WHERE event_type = $1"
|
"SELECT event_type_nid FROM event_types WHERE event_type = $1"
|
||||||
|
|
||||||
|
// Bulk lookup from string event type to numeric ID for that event type.
|
||||||
|
// Takes an array of strings as the query parameter.
|
||||||
|
const bulkSelectEventTypeNIDSQL = "" +
|
||||||
|
"SELECT event_type, event_type_nid FROM event_types" +
|
||||||
|
" WHERE event_type = ANY($1)"
|
||||||
|
|
||||||
type eventTypeStatements struct {
|
type eventTypeStatements struct {
|
||||||
insertEventTypeNIDStmt *sql.Stmt
|
insertEventTypeNIDStmt *sql.Stmt
|
||||||
selectEventTypeNIDStmt *sql.Stmt
|
selectEventTypeNIDStmt *sql.Stmt
|
||||||
|
bulkSelectEventTypeNIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
|
func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
@ -80,6 +88,7 @@ func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
|
||||||
return statementList{
|
return statementList{
|
||||||
{&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
|
{&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
|
||||||
{&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
|
{&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
|
||||||
|
{&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL},
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,3 +103,22 @@ func (s *eventTypeStatements) selectEventTypeNID(eventType string) (types.EventT
|
||||||
err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
|
err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
|
||||||
return types.EventTypeNID(eventTypeNID), err
|
return types.EventTypeNID(eventTypeNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *eventTypeStatements) bulkSelectEventTypeNID(eventTypes []string) (map[string]types.EventTypeNID, error) {
|
||||||
|
rows, err := s.bulkSelectEventTypeNIDStmt.Query(pq.StringArray(eventTypes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
result := make(map[string]types.EventTypeNID, len(eventTypes))
|
||||||
|
for rows.Next() {
|
||||||
|
var eventType string
|
||||||
|
var eventTypeNID int64
|
||||||
|
if err := rows.Scan(&eventType, &eventTypeNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[eventType] = types.EventTypeNID(eventTypeNID)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ const selectRoomNIDSQL = "" +
|
||||||
"SELECT room_nid FROM rooms WHERE room_id = $1"
|
"SELECT room_nid FROM rooms WHERE room_id = $1"
|
||||||
|
|
||||||
const selectLatestEventNIDsSQL = "" +
|
const selectLatestEventNIDsSQL = "" +
|
||||||
"SELECT latest_event_nids FROM rooms WHERE room_nid = $1"
|
"SELECT latest_event_nids, state_snapshot_nid FROM rooms WHERE room_nid = $1"
|
||||||
|
|
||||||
const selectLatestEventNIDsForUpdateSQL = "" +
|
const selectLatestEventNIDsForUpdateSQL = "" +
|
||||||
"SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM rooms WHERE room_nid = $1 FOR UPDATE"
|
"SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM rooms WHERE room_nid = $1 FOR UPDATE"
|
||||||
|
@ -77,17 +77,18 @@ func (s *roomStatements) selectRoomNID(roomID string) (types.RoomNID, error) {
|
||||||
return types.RoomNID(roomNID), err
|
return types.RoomNID(roomNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, error) {
|
func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) {
|
||||||
var nids pq.Int64Array
|
var nids pq.Int64Array
|
||||||
err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids)
|
var stateSnapshotNID int64
|
||||||
|
err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids, &stateSnapshotNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
eventNIDs := make([]types.EventNID, len(nids))
|
eventNIDs := make([]types.EventNID, len(nids))
|
||||||
for i := range nids {
|
for i := range nids {
|
||||||
eventNIDs[i] = types.EventNID(nids[i])
|
eventNIDs[i] = types.EventNID(nids[i])
|
||||||
}
|
}
|
||||||
return eventNIDs, nil
|
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) (
|
func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) (
|
||||||
|
|
|
@ -5,6 +5,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"sort"
|
||||||
)
|
)
|
||||||
|
|
||||||
const stateDataSchema = `
|
const stateDataSchema = `
|
||||||
|
@ -35,21 +37,35 @@ const insertStateDataSQL = "" +
|
||||||
const selectNextStateBlockNIDSQL = "" +
|
const selectNextStateBlockNIDSQL = "" +
|
||||||
"SELECT nextval('state_block_nid_seq')"
|
"SELECT nextval('state_block_nid_seq')"
|
||||||
|
|
||||||
// Bulk state lookup by numeric event ID.
|
// Bulk state lookup by numeric state block ID.
|
||||||
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
|
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
|
||||||
// This means that all the entries for a given state_block_nid will appear
|
// This means that all the entries for a given state_block_nid will appear
|
||||||
// together in the list and those entries will sorted by event_type_nid
|
// together in the list and those entries will sorted by event_type_nid
|
||||||
// and event_state_key_nid. This property makes it easier to merge two
|
// and event_state_key_nid. This property makes it easier to merge two
|
||||||
// state data blocks together.
|
// state data blocks together.
|
||||||
const bulkSelectStateDataEntriesSQL = "" +
|
const bulkSelectStateBlockEntriesSQL = "" +
|
||||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||||
" FROM state_block WHERE state_block_nid = ANY($1)" +
|
" FROM state_block WHERE state_block_nid = ANY($1)" +
|
||||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||||
|
|
||||||
|
// Bulk state lookup by numeric state block ID.
|
||||||
|
// Filters the rows in each block to the requested types and state keys.
|
||||||
|
// We would like to restrict to particular type state key pairs but we are
|
||||||
|
// restricted by the query language to pull the cross product of a list
|
||||||
|
// of types and a list state_keys. So we have to filter the result in the
|
||||||
|
// application to restrict it to the list of event types and state keys we
|
||||||
|
// actually wanted.
|
||||||
|
const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
||||||
|
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||||
|
" FROM state_block WHERE state_block_nid = ANY($1)" +
|
||||||
|
" AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" +
|
||||||
|
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||||
|
|
||||||
type stateBlockStatements struct {
|
type stateBlockStatements struct {
|
||||||
insertStateDataStmt *sql.Stmt
|
insertStateDataStmt *sql.Stmt
|
||||||
selectNextStateBlockNIDStmt *sql.Stmt
|
selectNextStateBlockNIDStmt *sql.Stmt
|
||||||
bulkSelectStateDataEntriesStmt *sql.Stmt
|
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||||
|
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
|
func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
@ -61,7 +77,8 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
|
||||||
return statementList{
|
return statementList{
|
||||||
{&s.insertStateDataStmt, insertStateDataSQL},
|
{&s.insertStateDataStmt, insertStateDataSQL},
|
||||||
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
|
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
|
||||||
{&s.bulkSelectStateDataEntriesStmt, bulkSelectStateDataEntriesSQL},
|
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
|
||||||
|
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,12 +103,12 @@ func (s *stateBlockStatements) selectNextStateBlockNID() (types.StateBlockNID, e
|
||||||
return types.StateBlockNID(stateBlockNID), err
|
return types.StateBlockNID(stateBlockNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) bulkSelectStateDataEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
|
func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
|
||||||
nids := make([]int64, len(stateBlockNIDs))
|
nids := make([]int64, len(stateBlockNIDs))
|
||||||
for i := range stateBlockNIDs {
|
for i := range stateBlockNIDs {
|
||||||
nids[i] = int64(stateBlockNIDs[i])
|
nids[i] = int64(stateBlockNIDs[i])
|
||||||
}
|
}
|
||||||
rows, err := s.bulkSelectStateDataEntriesStmt.Query(pq.Int64Array(nids))
|
rows, err := s.bulkSelectStateBlockEntriesStmt.Query(pq.Int64Array(nids))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -131,3 +148,103 @@ func (s *stateBlockStatements) bulkSelectStateDataEntries(stateBlockNIDs []types
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
|
||||||
|
stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple,
|
||||||
|
) ([]types.StateEntryList, error) {
|
||||||
|
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||||
|
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
|
||||||
|
sort.Sort(tuples)
|
||||||
|
|
||||||
|
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||||
|
rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.Query(
|
||||||
|
stateBlockNIDsAsArray(stateBlockNIDs), eventTypeNIDArray, eventStateKeyNIDArray,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var results []types.StateEntryList
|
||||||
|
var current types.StateEntryList
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
stateBlockNID int64
|
||||||
|
eventTypeNID int64
|
||||||
|
eventStateKeyNID int64
|
||||||
|
eventNID int64
|
||||||
|
entry types.StateEntry
|
||||||
|
)
|
||||||
|
if err := rows.Scan(
|
||||||
|
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
||||||
|
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
||||||
|
entry.EventNID = types.EventNID(eventNID)
|
||||||
|
|
||||||
|
// We can use binary search here because we sorted the tuples earlier
|
||||||
|
if !tuples.contains(entry.StateKeyTuple) {
|
||||||
|
// The select will return the cross product of types and state keys.
|
||||||
|
// So we need to check if type of the entry is in the list.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
||||||
|
// The state entry row is for a different state data block to the current one.
|
||||||
|
// So we append the current entry to the results and start adding to a new one.
|
||||||
|
// The first time through the loop current will be empty.
|
||||||
|
if current.StateEntries != nil {
|
||||||
|
results = append(results, current)
|
||||||
|
}
|
||||||
|
current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
|
||||||
|
}
|
||||||
|
current.StateEntries = append(current.StateEntries, entry)
|
||||||
|
}
|
||||||
|
// Add the last entry to the list if it is not empty.
|
||||||
|
if current.StateEntries != nil {
|
||||||
|
results = append(results, current)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
|
||||||
|
nids := make([]int64, len(stateBlockNIDs))
|
||||||
|
for i := range stateBlockNIDs {
|
||||||
|
nids[i] = int64(stateBlockNIDs[i])
|
||||||
|
}
|
||||||
|
return pq.Int64Array(nids)
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateKeyTupleSorter []types.StateKeyTuple
|
||||||
|
|
||||||
|
func (s stateKeyTupleSorter) Len() int { return len(s) }
|
||||||
|
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
||||||
|
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
||||||
|
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
||||||
|
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
|
||||||
|
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
||||||
|
return i < len(s) && s[i] == value
|
||||||
|
}
|
||||||
|
|
||||||
|
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
||||||
|
// Assumes that the list is sorted.
|
||||||
|
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
|
||||||
|
eventTypeNIDs = make(pq.Int64Array, len(s))
|
||||||
|
eventStateKeyNIDs = make(pq.Int64Array, len(s))
|
||||||
|
for i := range s {
|
||||||
|
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
||||||
|
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
||||||
|
}
|
||||||
|
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
||||||
|
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type int64Sorter []int64
|
||||||
|
|
||||||
|
func (s int64Sorter) Len() int { return len(s) }
|
||||||
|
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||||
|
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStateKeyTupleSorter(t *testing.T) {
|
||||||
|
input := stateKeyTupleSorter{
|
||||||
|
{1, 2},
|
||||||
|
{1, 4},
|
||||||
|
{2, 2},
|
||||||
|
{1, 1},
|
||||||
|
}
|
||||||
|
want := []types.StateKeyTuple{
|
||||||
|
{1, 1},
|
||||||
|
{1, 2},
|
||||||
|
{1, 4},
|
||||||
|
{2, 2},
|
||||||
|
}
|
||||||
|
doNotWant := []types.StateKeyTuple{
|
||||||
|
{0, 0},
|
||||||
|
{1, 3},
|
||||||
|
{2, 1},
|
||||||
|
{3, 1},
|
||||||
|
}
|
||||||
|
wantTypeNIDs := []int64{1, 2}
|
||||||
|
wantStateKeyNIDs := []int64{1, 2, 4}
|
||||||
|
|
||||||
|
// Sort the input and check it's in the right order.
|
||||||
|
sort.Sort(input)
|
||||||
|
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
|
||||||
|
|
||||||
|
for i := range want {
|
||||||
|
if input[i] != want[i] {
|
||||||
|
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
if !input.contains(want[i]) {
|
||||||
|
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range doNotWant {
|
||||||
|
if input.contains(doNotWant[i]) {
|
||||||
|
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
||||||
|
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range wantTypeNIDs {
|
||||||
|
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
||||||
|
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
||||||
|
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range wantStateKeyNIDs {
|
||||||
|
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
||||||
|
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -145,7 +145,12 @@ func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntr
|
||||||
return d.statements.bulkSelectStateEventByID(eventIDs)
|
return d.statements.bulkSelectStateEventByID(eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventStateKeyNIDs implements input.EventDatabase
|
// EventTypeNIDs implements state.RoomStateDatabase
|
||||||
|
func (d *Database) EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error) {
|
||||||
|
return d.statements.bulkSelectEventTypeNID(eventTypes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventStateKeyNIDs implements state.RoomStateDatabase
|
||||||
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
|
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
|
||||||
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
|
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
|
||||||
}
|
}
|
||||||
|
@ -195,14 +200,14 @@ func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, err
|
||||||
return d.statements.bulkSelectStateAtEventByID(eventIDs)
|
return d.statements.bulkSelectStateAtEventByID(eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateBlockNIDs implements input.EventDatabase
|
// StateBlockNIDs implements state.RoomStateDatabase
|
||||||
func (d *Database) StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) {
|
func (d *Database) StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) {
|
||||||
return d.statements.bulkSelectStateBlockNIDs(stateNIDs)
|
return d.statements.bulkSelectStateBlockNIDs(stateNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateEntries implements input.EventDatabase
|
// StateEntries implements state.RoomStateDatabase
|
||||||
func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
|
func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
|
||||||
return d.statements.bulkSelectStateDataEntries(stateBlockNIDs)
|
return d.statements.bulkSelectStateBlockEntries(stateBlockNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventIDs implements input.RoomEventDatabase
|
// EventIDs implements input.RoomEventDatabase
|
||||||
|
@ -324,10 +329,21 @@ func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// LatestEventIDs implements query.RoomserverQueryAPIDB
|
// LatestEventIDs implements query.RoomserverQueryAPIDB
|
||||||
func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, error) {
|
func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, error) {
|
||||||
eventNIDs, err := d.statements.selectLatestEventNIDs(roomNID)
|
eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
return d.statements.bulkSelectEventReference(eventNIDs)
|
references, err := d.statements.bulkSelectEventReference(eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return references, currentStateSnapshotNID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StateEntriesForTuples implements state.RoomStateDatabase
|
||||||
|
func (d *Database) StateEntriesForTuples(
|
||||||
|
stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple,
|
||||||
|
) ([]types.StateEntryList, error) {
|
||||||
|
return d.statements.bulkSelectFilteredStateBlockEntries(stateBlockNIDs, stateKeyTuples)
|
||||||
}
|
}
|
||||||
|
|
2
vendor/manifest
vendored
2
vendor/manifest
vendored
|
@ -206,4 +206,4 @@
|
||||||
"branch": "master"
|
"branch": "master"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
Loading…
Reference in a new issue