mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-07 06:53:09 -06:00
Use type aliases for numeric IDs
This commit is contained in:
parent
bab3ca5f5f
commit
48e2edab0d
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
// checkAuthEvents checks that the event passes authentication checks
|
||||
// Returns the numeric IDs for the auth events.
|
||||
func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]int64, error) {
|
||||
func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]types.EventNID, error) {
|
||||
// Grab the numeric IDs for the supplied auth state events from the database.
|
||||
authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs)
|
||||
if err != nil {
|
||||
|
|
@ -31,7 +31,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv
|
|||
}
|
||||
|
||||
// Return the numeric IDs for the auth events.
|
||||
result := make([]int64, len(authStateEntries))
|
||||
result := make([]types.EventNID, len(authStateEntries))
|
||||
for i := range authStateEntries {
|
||||
result[i] = authStateEntries[i].EventNID
|
||||
}
|
||||
|
|
@ -39,7 +39,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv
|
|||
}
|
||||
|
||||
type authEvents struct {
|
||||
stateKeyNIDMap map[string]int64
|
||||
stateKeyNIDMap map[string]types.EventStateKeyNID
|
||||
state stateEntryMap
|
||||
events eventMap
|
||||
}
|
||||
|
|
@ -69,7 +69,7 @@ func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Even
|
|||
return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
|
||||
}
|
||||
|
||||
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserverlib.Event {
|
||||
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
|
||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID})
|
||||
if !ok {
|
||||
return nil
|
||||
|
|
@ -81,7 +81,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserve
|
|||
return &event.Event
|
||||
}
|
||||
|
||||
func (ae *authEvents) lookupEvent(typeNID int64, stateKey string) *gomatrixserverlib.Event {
|
||||
func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event {
|
||||
stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
|
||||
if !ok {
|
||||
return nil
|
||||
|
|
@ -113,7 +113,7 @@ func loadAuthEvents(
|
|||
|
||||
// Load the events we need.
|
||||
result.state = state
|
||||
var eventNIDs []int64
|
||||
var eventNIDs []types.EventNID
|
||||
keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed)
|
||||
for _, keyTuple := range keyTuplesNeeded {
|
||||
eventNID, ok := result.state.lookup(keyTuple)
|
||||
|
|
@ -128,7 +128,7 @@ func loadAuthEvents(
|
|||
}
|
||||
|
||||
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
|
||||
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
|
||||
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
|
||||
var keyTuples []types.StateKeyTuple
|
||||
if stateNeeded.Create {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
|
||||
|
|
@ -159,7 +159,7 @@ func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixs
|
|||
type stateEntryMap []types.StateEntry
|
||||
|
||||
// lookup an entry in the event map.
|
||||
func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok bool) {
|
||||
func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
|
||||
// Since the list is sorted we can implement this using binary search.
|
||||
// This is faster than using a hash map.
|
||||
// We don't have to worry about pathological cases because the keys are fixed
|
||||
|
|
@ -180,7 +180,7 @@ func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok
|
|||
type eventMap []types.Event
|
||||
|
||||
// lookup an entry in the event map.
|
||||
func (m eventMap) lookup(eventNID int64) (event *types.Event, ok bool) {
|
||||
func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
|
||||
// Since the list is sorted we can implement this using binary search.
|
||||
// This is faster than using a hash map.
|
||||
// We don't have to worry about pathological cases because the keys are fixed
|
||||
|
|
|
|||
|
|
@ -8,13 +8,18 @@ import (
|
|||
func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) {
|
||||
var list []types.StateEntry
|
||||
for i := int64(0); i < entries; i++ {
|
||||
list = append(list, types.StateEntry{types.StateKeyTuple{i, i}, i})
|
||||
list = append(list, types.StateEntry{types.StateKeyTuple{
|
||||
types.EventTypeNID(i),
|
||||
types.EventStateKeyNID(i),
|
||||
}, types.EventNID(i)})
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
entryMap := stateEntryMap(list)
|
||||
for j := int64(0); j < lookups; j++ {
|
||||
entryMap.lookup(types.StateKeyTuple{j, j})
|
||||
entryMap.lookup(types.StateKeyTuple{
|
||||
types.EventTypeNID(j), types.EventStateKeyNID(j),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -43,10 +48,10 @@ func TestStateEntryMap(t *testing.T) {
|
|||
})
|
||||
|
||||
testCases := []struct {
|
||||
inputTypeNID int64
|
||||
inputStateKey int64
|
||||
inputTypeNID types.EventTypeNID
|
||||
inputStateKey types.EventStateKeyNID
|
||||
wantOK bool
|
||||
wantEventNID int64
|
||||
wantEventNID types.EventNID
|
||||
}{
|
||||
// Check that tuples that in the array are in the map.
|
||||
{1, 1, true, 1},
|
||||
|
|
@ -80,7 +85,7 @@ func TestEventMap(t *testing.T) {
|
|||
})
|
||||
|
||||
testCases := []struct {
|
||||
inputEventNID int64
|
||||
inputEventNID types.EventNID
|
||||
wantOK bool
|
||||
wantEvent *types.Event
|
||||
}{
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import (
|
|||
// A RoomEventDatabase has the storage APIs needed to store a room event.
|
||||
type RoomEventDatabase interface {
|
||||
// Stores a matrix room event in the database
|
||||
StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) (roomNID int64, stateAtEvent types.StateAtEvent, err error)
|
||||
StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error)
|
||||
// Lookup the state entries for a list of string event IDs
|
||||
// Returns a sorted list of state entries.
|
||||
// Returns an error if the there is an error talking to the database
|
||||
|
|
@ -17,10 +17,10 @@ type RoomEventDatabase interface {
|
|||
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error)
|
||||
// Lookup the numeric IDs for a list of string event state keys.
|
||||
// Returns a map from string state key to numeric ID for the state key.
|
||||
EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error)
|
||||
EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||
// Lookup the Events for a list of numeric event IDs.
|
||||
// Returns a sorted list of events.
|
||||
Events(eventNIDs []int64) ([]types.Event, error)
|
||||
Events(eventNIDs []types.EventNID) ([]types.Event, error)
|
||||
// Lookup the state of a room at each event for a list of string event IDs.
|
||||
// Returns a sorted list of state at each event.
|
||||
// Returns an error if there is an error talking to the database
|
||||
|
|
@ -28,14 +28,14 @@ type RoomEventDatabase interface {
|
|||
StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error)
|
||||
// Lookup the numeric state data IDs for the each numeric state ID
|
||||
// The returned slice is sorted by numeric state ID.
|
||||
StateDataNIDs(stateNIDs []int64) ([]types.StateDataNIDList, error)
|
||||
StateDataNIDs(stateNIDs []types.StateNID) ([]types.StateDataNIDList, error)
|
||||
// Lookup the state data for each numeric state data ID
|
||||
// The returned slice is sorted by numeric state data ID.
|
||||
StateEntries(stateDataNIDs []int64) ([]types.StateEntryList, error)
|
||||
StateEntries(stateDataNIDs []types.StateDataNID) ([]types.StateEntryList, error)
|
||||
// Store the room state at an event in the database
|
||||
AddState(roomNID int64, stateDataNIDs []int64, state []types.StateEntry) (stateNID int64, err error)
|
||||
AddState(roomNID types.RoomNID, stateDataNIDs []types.StateDataNID, state []types.StateEntry) (types.StateNID, error)
|
||||
// Set the state at an event.
|
||||
SetState(eventNID, stateNID int64) error
|
||||
SetState(eventNID types.EventNID, stateNID types.StateNID) error
|
||||
}
|
||||
|
||||
func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import (
|
|||
)
|
||||
|
||||
func calculateAndStoreState(
|
||||
db RoomEventDatabase, event gomatrixserverlib.Event, roomNID int64, stateEventIDs []string,
|
||||
) (int64, error) {
|
||||
db RoomEventDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID, stateEventIDs []string,
|
||||
) (types.StateNID, error) {
|
||||
if stateEventIDs != nil {
|
||||
// 1) We've been told what the state at the event is.
|
||||
// Check that those state events are in the database and store the state.
|
||||
|
|
@ -50,7 +50,7 @@ func calculateAndStoreState(
|
|||
}
|
||||
// The previous event was a state event so we need to store a copy
|
||||
// of the previous state updated with that event.
|
||||
stateDataNIDLists, err := db.StateDataNIDs([]int64{prevState.BeforeStateNID})
|
||||
stateDataNIDLists, err := db.StateDataNIDs([]types.StateNID{prevState.BeforeStateNID})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
|
@ -70,23 +70,23 @@ func calculateAndStoreState(
|
|||
|
||||
const maxStateDataNIDs = 64
|
||||
|
||||
func calculateAndStoreStateMany(db RoomEventDatabase, roomNID int64, prevStates []types.StateAtEvent) (int64, error) {
|
||||
func calculateAndStoreStateMany(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateNID, error) {
|
||||
// Conflict resolution.
|
||||
// First stage: load the state datablocks for the prev events.
|
||||
stateNIDs := make([]int64, len(prevStates))
|
||||
stateNIDs := make([]types.StateNID, len(prevStates))
|
||||
for i, state := range prevStates {
|
||||
stateNIDs[i] = state.BeforeStateNID
|
||||
}
|
||||
stateDataNIDLists, err := db.StateDataNIDs(uniqueNIDs(stateNIDs))
|
||||
stateDataNIDLists, err := db.StateDataNIDs(uniqueStateNIDs(stateNIDs))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var stateDataNIDs []int64
|
||||
var stateDataNIDs []types.StateDataNID
|
||||
for _, list := range stateDataNIDLists {
|
||||
stateDataNIDs = append(stateDataNIDs, list.StateDataNIDs...)
|
||||
}
|
||||
stateEntryLists, err := db.StateEntries(uniqueNIDs(stateDataNIDs))
|
||||
stateEntryLists, err := db.StateEntries(uniqueStateDataNIDs(stateDataNIDs))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
|
@ -172,14 +172,9 @@ func duplicateStateKeys(a []types.StateEntry) []types.StateEntry {
|
|||
return result
|
||||
}
|
||||
|
||||
func uniqueNIDs(nids []int64) []int64 {
|
||||
sort.Sort(int64Sorter(nids))
|
||||
return nids[:unique(int64Sorter(nids))]
|
||||
}
|
||||
|
||||
type stateDataNIDListMap []types.StateDataNIDList
|
||||
|
||||
func (m stateDataNIDListMap) lookup(stateNID int64) (stateDataNIDs []int64, ok bool) {
|
||||
func (m stateDataNIDListMap) lookup(stateNID types.StateNID) (stateDataNIDs []types.StateDataNID, ok bool) {
|
||||
list := []types.StateDataNIDList(m)
|
||||
i := sort.Search(len(list), func(i int) bool {
|
||||
return list[i].StateNID >= stateNID
|
||||
|
|
@ -193,7 +188,7 @@ func (m stateDataNIDListMap) lookup(stateNID int64) (stateDataNIDs []int64, ok b
|
|||
|
||||
type stateEntryListMap []types.StateEntryList
|
||||
|
||||
func (m stateEntryListMap) lookup(stateDataNID int64) (stateEntries []types.StateEntry, ok bool) {
|
||||
func (m stateEntryListMap) lookup(stateDataNID types.StateDataNID) (stateEntries []types.StateEntry, ok bool) {
|
||||
list := []types.StateEntryList(m)
|
||||
i := sort.Search(len(list), func(i int) bool {
|
||||
return list[i].StateDataNID >= stateDataNID
|
||||
|
|
@ -219,11 +214,27 @@ func (s stateEntrySorter) Len() int { return len(s) }
|
|||
func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
||||
func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
type int64Sorter []int64
|
||||
type stateNIDSorter []types.StateNID
|
||||
|
||||
func (s int64Sorter) Len() int { return len(s) }
|
||||
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s stateNIDSorter) Len() int { return len(s) }
|
||||
func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||
func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
func uniqueStateNIDs(nids []types.StateNID) []types.StateNID {
|
||||
sort.Sort(stateNIDSorter(nids))
|
||||
return nids[:unique(stateNIDSorter(nids))]
|
||||
}
|
||||
|
||||
type stateDataNIDSorter []types.StateDataNID
|
||||
|
||||
func (s stateDataNIDSorter) Len() int { return len(s) }
|
||||
func (s stateDataNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||
func (s stateDataNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
func uniqueStateDataNIDs(nids []types.StateDataNID) []types.StateDataNID {
|
||||
sort.Sort(stateDataNIDSorter(nids))
|
||||
return nids[:unique(stateDataNIDSorter(nids))]
|
||||
}
|
||||
|
||||
// Remove duplicate items from a sorted list.
|
||||
// Takes the same interface as sort.Sort
|
||||
|
|
|
|||
|
|
@ -187,14 +187,16 @@ const insertEventTypeNIDSQL = "" +
|
|||
const selectEventTypeNIDSQL = "" +
|
||||
"SELECT event_type_nid FROM event_types WHERE event_type = $1"
|
||||
|
||||
func (s *statements) insertEventTypeNID(eventType string) (eventTypeNID int64, err error) {
|
||||
err = s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
|
||||
return
|
||||
func (s *statements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
||||
var eventTypeNID int64
|
||||
err := s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
|
||||
return types.EventTypeNID(eventTypeNID), err
|
||||
}
|
||||
|
||||
func (s *statements) selectEventTypeNID(eventType string) (eventTypeNID int64, err error) {
|
||||
err = s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
|
||||
return
|
||||
func (s *statements) selectEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
||||
var eventTypeNID int64
|
||||
err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
|
||||
return types.EventTypeNID(eventTypeNID), err
|
||||
}
|
||||
|
||||
func (s *statements) prepareEventStateKeys(db *sql.DB) (err error) {
|
||||
|
|
@ -251,31 +253,33 @@ const bulkSelectEventStateKeyNIDSQL = "" +
|
|||
"SELECT event_state_key, event_state_key_nid FROM event_state_keys" +
|
||||
" WHERE event_state_key = ANY($1)"
|
||||
|
||||
func (s *statements) insertEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) {
|
||||
err = s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
||||
return
|
||||
func (s *statements) insertEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) {
|
||||
var eventStateKeyNID int64
|
||||
err := s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||
}
|
||||
|
||||
func (s *statements) selectEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) {
|
||||
err = s.selectEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
||||
return
|
||||
func (s *statements) selectEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) {
|
||||
var eventStateKeyNID int64
|
||||
err := s.selectEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||
}
|
||||
|
||||
func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]int64, error) {
|
||||
func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
|
||||
rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make(map[string]int64, len(eventStateKeys))
|
||||
result := make(map[string]types.EventStateKeyNID, len(eventStateKeys))
|
||||
for rows.Next() {
|
||||
var stateKey string
|
||||
var stateKeyNID int64
|
||||
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[stateKey] = stateKeyNID
|
||||
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
|
@ -314,14 +318,16 @@ const insertRoomNIDSQL = "" +
|
|||
const selectRoomNIDSQL = "" +
|
||||
"SELECT room_nid FROM rooms WHERE room_id = $1"
|
||||
|
||||
func (s *statements) insertRoomNID(roomID string) (roomNID int64, err error) {
|
||||
err = s.insertRoomNIDStmt.QueryRow(roomID).Scan(&roomNID)
|
||||
return
|
||||
func (s *statements) insertRoomNID(roomID string) (types.RoomNID, error) {
|
||||
var roomNID int64
|
||||
err := s.insertRoomNIDStmt.QueryRow(roomID).Scan(&roomNID)
|
||||
return types.RoomNID(roomNID), err
|
||||
}
|
||||
|
||||
func (s *statements) selectRoomNID(roomID string) (roomNID int64, err error) {
|
||||
err = s.selectRoomNIDStmt.QueryRow(roomID).Scan(&roomNID)
|
||||
return
|
||||
func (s *statements) selectRoomNID(roomID string) (types.RoomNID, error) {
|
||||
var roomNID int64
|
||||
err := s.selectRoomNIDStmt.QueryRow(roomID).Scan(&roomNID)
|
||||
return types.RoomNID(roomNID), err
|
||||
}
|
||||
|
||||
const eventsSchema = `
|
||||
|
|
@ -404,16 +410,22 @@ func (s *statements) prepareEvents(db *sql.DB) (err error) {
|
|||
}
|
||||
|
||||
func (s *statements) insertEvent(
|
||||
roomNID, eventTypeNID, eventStateKeyNID int64,
|
||||
roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
|
||||
eventID string,
|
||||
referenceSHA256 []byte,
|
||||
authEventNIDs []int64,
|
||||
) (eventNID, stateNID int64, err error) {
|
||||
err = s.insertEventStmt.QueryRow(
|
||||
roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256,
|
||||
pq.Int64Array(authEventNIDs),
|
||||
authEventNIDs []types.EventNID,
|
||||
) (types.EventNID, types.StateNID, error) {
|
||||
nids := make([]int64, len(authEventNIDs))
|
||||
for i := range authEventNIDs {
|
||||
nids[i] = int64(authEventNIDs[i])
|
||||
}
|
||||
var eventNID int64
|
||||
var stateNID int64
|
||||
err := s.insertEventStmt.QueryRow(
|
||||
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256,
|
||||
pq.Int64Array(nids),
|
||||
).Scan(&eventNID, &stateNID)
|
||||
return
|
||||
return types.EventNID(eventNID), types.StateNID(stateNID), err
|
||||
}
|
||||
|
||||
func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) {
|
||||
|
|
@ -477,8 +489,8 @@ func (s *statements) bulkSelectStateAtEventByID(eventIDs []string) ([]types.Stat
|
|||
return results, err
|
||||
}
|
||||
|
||||
func (s *statements) updateEventState(eventNID, stateNID int64) error {
|
||||
_, err := s.updateEventStateStmt.Exec(eventNID, stateNID)
|
||||
func (s *statements) updateEventState(eventNID types.EventNID, stateNID types.StateNID) error {
|
||||
_, err := s.updateEventStateStmt.Exec(int64(eventNID), int64(stateNID))
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -525,18 +537,22 @@ const bulkSelectEventJSONSQL = "" +
|
|||
" WHERE event_nid = ANY($1)" +
|
||||
" ORDER BY event_nid ASC"
|
||||
|
||||
func (s *statements) insertEventJSON(eventNID int64, eventJSON []byte) error {
|
||||
_, err := s.insertEventJSONStmt.Exec(eventNID, eventJSON)
|
||||
func (s *statements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error {
|
||||
_, err := s.insertEventJSONStmt.Exec(int64(eventNID), eventJSON)
|
||||
return err
|
||||
}
|
||||
|
||||
type eventJSONPair struct {
|
||||
EventNID int64
|
||||
EventNID types.EventNID
|
||||
EventJSON []byte
|
||||
}
|
||||
|
||||
func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, error) {
|
||||
rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(eventNIDs))
|
||||
func (s *statements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) {
|
||||
nids := make([]int64, len(eventNIDs))
|
||||
for i := range eventNIDs {
|
||||
nids[i] = int64(eventNIDs[i])
|
||||
}
|
||||
rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(nids))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -549,9 +565,12 @@ func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, er
|
|||
results := make([]eventJSONPair, len(eventNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
if err := rows.Scan(&results[i].EventNID, &results[i].EventJSON); err != nil {
|
||||
result := &results[i]
|
||||
var eventNID int64
|
||||
if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.EventNID = types.EventNID(eventNID)
|
||||
}
|
||||
return results[:i], nil
|
||||
}
|
||||
|
|
@ -602,13 +621,21 @@ func (s *statements) prepareState(db *sql.DB) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *statements) insertState(roomNID int64, stateDataNIDs []int64) (stateNID int64, err error) {
|
||||
err = s.insertStateStmt.QueryRow(roomNID, pq.Int64Array(stateDataNIDs)).Scan(&stateNID)
|
||||
func (s *statements) insertState(roomNID types.RoomNID, stateDataNIDs []types.StateDataNID) (stateNID types.StateNID, err error) {
|
||||
nids := make([]int64, len(stateDataNIDs))
|
||||
for i := range stateDataNIDs {
|
||||
nids[i] = int64(stateDataNIDs[i])
|
||||
}
|
||||
err = s.insertStateStmt.QueryRow(int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *statements) bulkSelectStateDataNIDs(stateNIDs []int64) ([]types.StateDataNIDList, error) {
|
||||
rows, err := s.bulkSelectStateDataNIDsStmt.Query(pq.Int64Array(stateNIDs))
|
||||
func (s *statements) bulkSelectStateDataNIDs(stateNIDs []types.StateNID) ([]types.StateDataNIDList, error) {
|
||||
nids := make([]int64, len(stateNIDs))
|
||||
for i := range stateNIDs {
|
||||
nids[i] = int64(stateNIDs[i])
|
||||
}
|
||||
rows, err := s.bulkSelectStateDataNIDsStmt.Query(pq.Int64Array(nids))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -617,11 +644,14 @@ func (s *statements) bulkSelectStateDataNIDs(stateNIDs []int64) ([]types.StateDa
|
|||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
var stateDataNids pq.Int64Array
|
||||
if err := rows.Scan(&result.StateNID, &stateDataNids); err != nil {
|
||||
var stateDataNIDs pq.Int64Array
|
||||
if err := rows.Scan(&result.StateNID, &stateDataNIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.StateDataNIDs = stateDataNids
|
||||
result.StateDataNIDs = make([]types.StateDataNID, len(stateDataNIDs))
|
||||
for k := range stateDataNIDs {
|
||||
result.StateDataNIDs[k] = types.StateDataNID(stateDataNIDs[k])
|
||||
}
|
||||
}
|
||||
if i != len(stateNIDs) {
|
||||
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
|
||||
|
|
@ -680,13 +710,13 @@ func (s *statements) prepareStateData(db *sql.DB) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *statements) bulkInsertStateData(stateDataNID int64, entries []types.StateEntry) error {
|
||||
func (s *statements) bulkInsertStateData(stateDataNID types.StateDataNID, entries []types.StateEntry) error {
|
||||
for _, entry := range entries {
|
||||
_, err := s.insertStateDataStmt.Exec(
|
||||
stateDataNID,
|
||||
entry.EventTypeNID,
|
||||
entry.EventStateKeyNID,
|
||||
entry.EventNID,
|
||||
int64(stateDataNID),
|
||||
int64(entry.EventTypeNID),
|
||||
int64(entry.EventStateKeyNID),
|
||||
int64(entry.EventNID),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -695,13 +725,18 @@ func (s *statements) bulkInsertStateData(stateDataNID int64, entries []types.Sta
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *statements) selectNextStateDataNID() (stateDataNID int64, err error) {
|
||||
err = s.selectNextStateDataNIDStmt.QueryRow().Scan(&stateDataNID)
|
||||
return
|
||||
func (s *statements) selectNextStateDataNID() (types.StateDataNID, error) {
|
||||
var stateDataNID int64
|
||||
err := s.selectNextStateDataNIDStmt.QueryRow().Scan(&stateDataNID)
|
||||
return types.StateDataNID(stateDataNID), err
|
||||
}
|
||||
|
||||
func (s *statements) bulkSelectStateDataEntries(stateDataNIDs []int64) ([]types.StateEntryList, error) {
|
||||
rows, err := s.bulkSelectStateDataEntriesStmt.Query(pq.Int64Array(stateDataNIDs))
|
||||
func (s *statements) bulkSelectStateDataEntries(stateDataNIDs []types.StateDataNID) ([]types.StateEntryList, error) {
|
||||
nids := make([]int64, len(stateDataNIDs))
|
||||
for i := range stateDataNIDs {
|
||||
nids[i] = int64(stateDataNIDs[i])
|
||||
}
|
||||
rows, err := s.bulkSelectStateDataEntriesStmt.Query(pq.Int64Array(nids))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -712,19 +747,26 @@ func (s *statements) bulkSelectStateDataEntries(stateDataNIDs []int64) ([]types.
|
|||
var current *types.StateEntryList
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
var stateDataNID int64
|
||||
var entry types.StateEntry
|
||||
var (
|
||||
stateDataNID int64
|
||||
eventTypeNID int64
|
||||
eventStateKeyNID int64
|
||||
eventNID int64
|
||||
entry types.StateEntry
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&stateDataNID,
|
||||
&entry.EventTypeNID, &entry.EventStateKeyNID, &entry.EventNID,
|
||||
&stateDataNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if current == nil || stateDataNID != current.StateDataNID {
|
||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
||||
entry.EventNID = types.EventNID(eventNID)
|
||||
if current == nil || types.StateDataNID(stateDataNID) != current.StateDataNID {
|
||||
// The state entry row is for a different state data block to the current one.
|
||||
// So we start appending to the next entry in the list.
|
||||
current = &results[i]
|
||||
current.StateDataNID = stateDataNID
|
||||
current.StateDataNID = types.StateDataNID(stateDataNID)
|
||||
i++
|
||||
}
|
||||
current.StateEntries = append(current.StateEntries, entry)
|
||||
|
|
|
|||
|
|
@ -38,13 +38,13 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
|
|||
}
|
||||
|
||||
// StoreEvent implements input.EventDatabase
|
||||
func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) (int64, types.StateAtEvent, error) {
|
||||
func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) {
|
||||
var (
|
||||
roomNID int64
|
||||
eventTypeNID int64
|
||||
eventStateKeyNID int64
|
||||
eventNID int64
|
||||
stateNID int64
|
||||
roomNID types.RoomNID
|
||||
eventTypeNID types.EventTypeNID
|
||||
eventStateKeyNID types.EventStateKeyNID
|
||||
eventNID types.EventNID
|
||||
stateNID types.StateNID
|
||||
err error
|
||||
)
|
||||
|
||||
|
|
@ -92,7 +92,7 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (d *Database) assignRoomNID(roomID string) (int64, error) {
|
||||
func (d *Database) assignRoomNID(roomID string) (types.RoomNID, error) {
|
||||
// Check if we already have a numeric ID in the database.
|
||||
roomNID, err := d.statements.selectRoomNID(roomID)
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -105,7 +105,7 @@ func (d *Database) assignRoomNID(roomID string) (int64, error) {
|
|||
return roomNID, nil
|
||||
}
|
||||
|
||||
func (d *Database) assignEventTypeNID(eventType string) (int64, error) {
|
||||
func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
||||
// Check if we already have a numeric ID in the database.
|
||||
eventTypeNID, err := d.statements.selectEventTypeNID(eventType)
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -118,7 +118,7 @@ func (d *Database) assignEventTypeNID(eventType string) (int64, error) {
|
|||
return eventTypeNID, nil
|
||||
}
|
||||
|
||||
func (d *Database) assignStateKeyNID(eventStateKey string) (int64, error) {
|
||||
func (d *Database) assignStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) {
|
||||
// Check if we already have a numeric ID in the database.
|
||||
eventStateKeyNID, err := d.statements.selectEventStateKeyNID(eventStateKey)
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -137,12 +137,12 @@ func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntr
|
|||
}
|
||||
|
||||
// EventStateKeyNIDs implements input.EventDatabase
|
||||
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) {
|
||||
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
|
||||
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
|
||||
}
|
||||
|
||||
// Events implements input.EventDatabase
|
||||
func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) {
|
||||
func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) {
|
||||
eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -161,14 +161,14 @@ func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) {
|
|||
}
|
||||
|
||||
// AddState implements input.EventDatabase
|
||||
func (d *Database) AddState(roomNID int64, stateDataNIDs []int64, state []types.StateEntry) (stateNID int64, err error) {
|
||||
func (d *Database) AddState(roomNID types.RoomNID, stateDataNIDs []types.StateDataNID, state []types.StateEntry) (types.StateNID, error) {
|
||||
if len(state) > 0 {
|
||||
var stateDataNID int64
|
||||
if stateDataNID, err = d.statements.selectNextStateDataNID(); err != nil {
|
||||
return
|
||||
stateDataNID, err := d.statements.selectNextStateDataNID()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err = d.statements.bulkInsertStateData(stateDataNID, state); err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
stateDataNIDs = append(stateDataNIDs[:len(stateDataNIDs):len(stateDataNIDs)], stateDataNID)
|
||||
}
|
||||
|
|
@ -177,7 +177,7 @@ func (d *Database) AddState(roomNID int64, stateDataNIDs []int64, state []types.
|
|||
}
|
||||
|
||||
// SetState implements input.EventDatabase
|
||||
func (d *Database) SetState(eventNID, stateNID int64) error {
|
||||
func (d *Database) SetState(eventNID types.EventNID, stateNID types.StateNID) error {
|
||||
return d.statements.updateEventState(eventNID, stateNID)
|
||||
}
|
||||
|
||||
|
|
@ -187,11 +187,11 @@ func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, err
|
|||
}
|
||||
|
||||
// StateDataNIDs implements input.EventDatabase
|
||||
func (d *Database) StateDataNIDs(stateNIDs []int64) ([]types.StateDataNIDList, error) {
|
||||
func (d *Database) StateDataNIDs(stateNIDs []types.StateNID) ([]types.StateDataNIDList, error) {
|
||||
return d.statements.bulkSelectStateDataNIDs(stateNIDs)
|
||||
}
|
||||
|
||||
// StateEntries implements input.EventDatabase
|
||||
func (d *Database) StateEntries(stateDataNIDs []int64) ([]types.StateEntryList, error) {
|
||||
func (d *Database) StateEntries(stateDataNIDs []types.StateDataNID) ([]types.StateEntryList, error) {
|
||||
return d.statements.bulkSelectStateDataEntries(stateDataNIDs)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,13 +13,32 @@ type PartitionOffset struct {
|
|||
Offset int64
|
||||
}
|
||||
|
||||
// EventTypeNID is a numeric ID for an event type.
|
||||
type EventTypeNID int64
|
||||
|
||||
// EventStateKeyNID is a numeric ID for an event state_key.
|
||||
type EventStateKeyNID int64
|
||||
|
||||
// EventNID is a numeric ID for an event.
|
||||
type EventNID int64
|
||||
|
||||
// RoomNID is a numeric ID for a room.
|
||||
type RoomNID int64
|
||||
|
||||
// StateNID is a numeric ID for the state at an event.
|
||||
type StateNID int64
|
||||
|
||||
// StateDataNID is a numeric ID for a block of state data.
|
||||
// These blocks of state data are combined to form the actual state.
|
||||
type StateDataNID int64
|
||||
|
||||
// A StateKeyTuple is a pair of a numeric event type and a numeric state key.
|
||||
// It is used to lookup state entries.
|
||||
type StateKeyTuple struct {
|
||||
// The numeric ID for the event type.
|
||||
EventTypeNID int64
|
||||
EventTypeNID EventTypeNID
|
||||
// The numeric ID for the state key.
|
||||
EventStateKeyNID int64
|
||||
EventStateKeyNID EventStateKeyNID
|
||||
}
|
||||
|
||||
// LessThan returns true if this state key is less than the other state key.
|
||||
|
|
@ -35,7 +54,7 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
|
|||
type StateEntry struct {
|
||||
StateKeyTuple
|
||||
// The numeric ID for the event.
|
||||
EventNID int64
|
||||
EventNID EventNID
|
||||
}
|
||||
|
||||
// LessThan returns true if this state entry is less than the other state entry.
|
||||
|
|
@ -50,7 +69,7 @@ func (a StateEntry) LessThan(b StateEntry) bool {
|
|||
// StateAtEvent is the state before and after a matrix event.
|
||||
type StateAtEvent struct {
|
||||
// The state before the event.
|
||||
BeforeStateNID int64
|
||||
BeforeStateNID StateNID
|
||||
// The state entry for the event itself, allows us to calculate the state after the event.
|
||||
StateEntry
|
||||
}
|
||||
|
|
@ -58,7 +77,7 @@ type StateAtEvent struct {
|
|||
// An Event is a gomatrixserverlib.Event with the numeric event ID attached.
|
||||
// It is when performing bulk event lookup in the database.
|
||||
type Event struct {
|
||||
EventNID int64
|
||||
EventNID EventNID
|
||||
gomatrixserverlib.Event
|
||||
}
|
||||
|
||||
|
|
@ -86,12 +105,12 @@ const (
|
|||
|
||||
// StateDataNIDList is used to return the result of bulk StateDataNID lookups from the database.
|
||||
type StateDataNIDList struct {
|
||||
StateNID int64
|
||||
StateDataNIDs []int64
|
||||
StateNID StateNID
|
||||
StateDataNIDs []StateDataNID
|
||||
}
|
||||
|
||||
// StateEntryList is used to return the result of bulk state entry lookups from the database.
|
||||
type StateEntryList struct {
|
||||
StateDataNID int64
|
||||
StateDataNID StateDataNID
|
||||
StateEntries []StateEntry
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue