diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go b/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go index bb15750b7..7dcaca915 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go @@ -8,7 +8,7 @@ import ( // checkAuthEvents checks that the event passes authentication checks // Returns the numeric IDs for the auth events. -func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]int64, error) { +func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]types.EventNID, error) { // Grab the numeric IDs for the supplied auth state events from the database. authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs) if err != nil { @@ -31,7 +31,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv } // Return the numeric IDs for the auth events. - result := make([]int64, len(authStateEntries)) + result := make([]types.EventNID, len(authStateEntries)) for i := range authStateEntries { result[i] = authStateEntries[i].EventNID } @@ -39,7 +39,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv } type authEvents struct { - stateKeyNIDMap map[string]int64 + stateKeyNIDMap map[string]types.EventStateKeyNID state stateEntryMap events eventMap } @@ -69,7 +69,7 @@ func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Even return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil } -func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserverlib.Event { +func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event { eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID}) if !ok { return nil @@ -81,7 +81,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserve return &event.Event } -func (ae *authEvents) lookupEvent(typeNID int64, stateKey string) *gomatrixserverlib.Event { +func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event { stateKeyNID, ok := ae.stateKeyNIDMap[stateKey] if !ok { return nil @@ -113,7 +113,7 @@ func loadAuthEvents( // Load the events we need. result.state = state - var eventNIDs []int64 + var eventNIDs []types.EventNID keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed) for _, keyTuple := range keyTuplesNeeded { eventNID, ok := result.state.lookup(keyTuple) @@ -128,7 +128,7 @@ func loadAuthEvents( } // stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. -func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { +func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { var keyTuples []types.StateKeyTuple if stateNeeded.Create { keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID}) @@ -159,7 +159,7 @@ func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixs type stateEntryMap []types.StateEntry // lookup an entry in the event map. -func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok bool) { +func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { // Since the list is sorted we can implement this using binary search. // This is faster than using a hash map. // We don't have to worry about pathological cases because the keys are fixed @@ -180,7 +180,7 @@ func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok type eventMap []types.Event // lookup an entry in the event map. -func (m eventMap) lookup(eventNID int64) (event *types.Event, ok bool) { +func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { // Since the list is sorted we can implement this using binary search. // This is faster than using a hash map. // We don't have to worry about pathological cases because the keys are fixed diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/authevents_test.go b/src/github.com/matrix-org/dendrite/roomserver/input/authevents_test.go index aba1de092..69be65d78 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/authevents_test.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/authevents_test.go @@ -8,13 +8,18 @@ import ( func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) { var list []types.StateEntry for i := int64(0); i < entries; i++ { - list = append(list, types.StateEntry{types.StateKeyTuple{i, i}, i}) + list = append(list, types.StateEntry{types.StateKeyTuple{ + types.EventTypeNID(i), + types.EventStateKeyNID(i), + }, types.EventNID(i)}) } for i := 0; i < b.N; i++ { entryMap := stateEntryMap(list) for j := int64(0); j < lookups; j++ { - entryMap.lookup(types.StateKeyTuple{j, j}) + entryMap.lookup(types.StateKeyTuple{ + types.EventTypeNID(j), types.EventStateKeyNID(j), + }) } } } @@ -43,10 +48,10 @@ func TestStateEntryMap(t *testing.T) { }) testCases := []struct { - inputTypeNID int64 - inputStateKey int64 + inputTypeNID types.EventTypeNID + inputStateKey types.EventStateKeyNID wantOK bool - wantEventNID int64 + wantEventNID types.EventNID }{ // Check that tuples that in the array are in the map. {1, 1, true, 1}, @@ -80,7 +85,7 @@ func TestEventMap(t *testing.T) { }) testCases := []struct { - inputEventNID int64 + inputEventNID types.EventNID wantOK bool wantEvent *types.Event }{ diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index 56369150d..576530339 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -9,7 +9,7 @@ import ( // A RoomEventDatabase has the storage APIs needed to store a room event. type RoomEventDatabase interface { // Stores a matrix room event in the database - StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) (roomNID int64, stateAtEvent types.StateAtEvent, err error) + StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) // Lookup the state entries for a list of string event IDs // Returns a sorted list of state entries. // Returns an error if the there is an error talking to the database @@ -17,10 +17,10 @@ type RoomEventDatabase interface { StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) // Lookup the numeric IDs for a list of string event state keys. // Returns a map from string state key to numeric ID for the state key. - EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) + EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) // Lookup the Events for a list of numeric event IDs. // Returns a sorted list of events. - Events(eventNIDs []int64) ([]types.Event, error) + Events(eventNIDs []types.EventNID) ([]types.Event, error) // Lookup the state of a room at each event for a list of string event IDs. // Returns a sorted list of state at each event. // Returns an error if there is an error talking to the database @@ -28,14 +28,14 @@ type RoomEventDatabase interface { StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) // Lookup the numeric state data IDs for the each numeric state ID // The returned slice is sorted by numeric state ID. - StateDataNIDs(stateNIDs []int64) ([]types.StateDataNIDList, error) + StateDataNIDs(stateNIDs []types.StateNID) ([]types.StateDataNIDList, error) // Lookup the state data for each numeric state data ID // The returned slice is sorted by numeric state data ID. - StateEntries(stateDataNIDs []int64) ([]types.StateEntryList, error) + StateEntries(stateDataNIDs []types.StateDataNID) ([]types.StateEntryList, error) // Store the room state at an event in the database - AddState(roomNID int64, stateDataNIDs []int64, state []types.StateEntry) (stateNID int64, err error) + AddState(roomNID types.RoomNID, stateDataNIDs []types.StateDataNID, state []types.StateEntry) (types.StateNID, error) // Set the state at an event. - SetState(eventNID, stateNID int64) error + SetState(eventNID types.EventNID, stateNID types.StateNID) error } func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/state.go b/src/github.com/matrix-org/dendrite/roomserver/input/state.go index b8cf80550..8a39df140 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/state.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/state.go @@ -8,8 +8,8 @@ import ( ) func calculateAndStoreState( - db RoomEventDatabase, event gomatrixserverlib.Event, roomNID int64, stateEventIDs []string, -) (int64, error) { + db RoomEventDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID, stateEventIDs []string, +) (types.StateNID, error) { if stateEventIDs != nil { // 1) We've been told what the state at the event is. // Check that those state events are in the database and store the state. @@ -50,7 +50,7 @@ func calculateAndStoreState( } // The previous event was a state event so we need to store a copy // of the previous state updated with that event. - stateDataNIDLists, err := db.StateDataNIDs([]int64{prevState.BeforeStateNID}) + stateDataNIDLists, err := db.StateDataNIDs([]types.StateNID{prevState.BeforeStateNID}) if err != nil { return 0, err } @@ -70,23 +70,23 @@ func calculateAndStoreState( const maxStateDataNIDs = 64 -func calculateAndStoreStateMany(db RoomEventDatabase, roomNID int64, prevStates []types.StateAtEvent) (int64, error) { +func calculateAndStoreStateMany(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateNID, error) { // Conflict resolution. // First stage: load the state datablocks for the prev events. - stateNIDs := make([]int64, len(prevStates)) + stateNIDs := make([]types.StateNID, len(prevStates)) for i, state := range prevStates { stateNIDs[i] = state.BeforeStateNID } - stateDataNIDLists, err := db.StateDataNIDs(uniqueNIDs(stateNIDs)) + stateDataNIDLists, err := db.StateDataNIDs(uniqueStateNIDs(stateNIDs)) if err != nil { return 0, err } - var stateDataNIDs []int64 + var stateDataNIDs []types.StateDataNID for _, list := range stateDataNIDLists { stateDataNIDs = append(stateDataNIDs, list.StateDataNIDs...) } - stateEntryLists, err := db.StateEntries(uniqueNIDs(stateDataNIDs)) + stateEntryLists, err := db.StateEntries(uniqueStateDataNIDs(stateDataNIDs)) if err != nil { return 0, err } @@ -172,14 +172,9 @@ func duplicateStateKeys(a []types.StateEntry) []types.StateEntry { return result } -func uniqueNIDs(nids []int64) []int64 { - sort.Sort(int64Sorter(nids)) - return nids[:unique(int64Sorter(nids))] -} - type stateDataNIDListMap []types.StateDataNIDList -func (m stateDataNIDListMap) lookup(stateNID int64) (stateDataNIDs []int64, ok bool) { +func (m stateDataNIDListMap) lookup(stateNID types.StateNID) (stateDataNIDs []types.StateDataNID, ok bool) { list := []types.StateDataNIDList(m) i := sort.Search(len(list), func(i int) bool { return list[i].StateNID >= stateNID @@ -193,7 +188,7 @@ func (m stateDataNIDListMap) lookup(stateNID int64) (stateDataNIDs []int64, ok b type stateEntryListMap []types.StateEntryList -func (m stateEntryListMap) lookup(stateDataNID int64) (stateEntries []types.StateEntry, ok bool) { +func (m stateEntryListMap) lookup(stateDataNID types.StateDataNID) (stateEntries []types.StateEntry, ok bool) { list := []types.StateEntryList(m) i := sort.Search(len(list), func(i int) bool { return list[i].StateDataNID >= stateDataNID @@ -219,11 +214,27 @@ func (s stateEntrySorter) Len() int { return len(s) } func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -type int64Sorter []int64 +type stateNIDSorter []types.StateNID -func (s int64Sorter) Len() int { return len(s) } -func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } -func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s stateNIDSorter) Len() int { return len(s) } +func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func uniqueStateNIDs(nids []types.StateNID) []types.StateNID { + sort.Sort(stateNIDSorter(nids)) + return nids[:unique(stateNIDSorter(nids))] +} + +type stateDataNIDSorter []types.StateDataNID + +func (s stateDataNIDSorter) Len() int { return len(s) } +func (s stateDataNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s stateDataNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func uniqueStateDataNIDs(nids []types.StateDataNID) []types.StateDataNID { + sort.Sort(stateDataNIDSorter(nids)) + return nids[:unique(stateDataNIDSorter(nids))] +} // Remove duplicate items from a sorted list. // Takes the same interface as sort.Sort diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go b/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go index c03793ecb..025f46dd0 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go @@ -187,14 +187,16 @@ const insertEventTypeNIDSQL = "" + const selectEventTypeNIDSQL = "" + "SELECT event_type_nid FROM event_types WHERE event_type = $1" -func (s *statements) insertEventTypeNID(eventType string) (eventTypeNID int64, err error) { - err = s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) - return +func (s *statements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) { + var eventTypeNID int64 + err := s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) + return types.EventTypeNID(eventTypeNID), err } -func (s *statements) selectEventTypeNID(eventType string) (eventTypeNID int64, err error) { - err = s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) - return +func (s *statements) selectEventTypeNID(eventType string) (types.EventTypeNID, error) { + var eventTypeNID int64 + err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) + return types.EventTypeNID(eventTypeNID), err } func (s *statements) prepareEventStateKeys(db *sql.DB) (err error) { @@ -251,31 +253,33 @@ const bulkSelectEventStateKeyNIDSQL = "" + "SELECT event_state_key, event_state_key_nid FROM event_state_keys" + " WHERE event_state_key = ANY($1)" -func (s *statements) insertEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) { - err = s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) - return +func (s *statements) insertEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + err := s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) + return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *statements) selectEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) { - err = s.selectEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) - return +func (s *statements) selectEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + err := s.selectEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) + return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]int64, error) { +func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) { rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys)) if err != nil { return nil, err } defer rows.Close() - result := make(map[string]int64, len(eventStateKeys)) + result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) for rows.Next() { var stateKey string var stateKeyNID int64 if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { return nil, err } - result[stateKey] = stateKeyNID + result[stateKey] = types.EventStateKeyNID(stateKeyNID) } return result, nil } @@ -314,14 +318,16 @@ const insertRoomNIDSQL = "" + const selectRoomNIDSQL = "" + "SELECT room_nid FROM rooms WHERE room_id = $1" -func (s *statements) insertRoomNID(roomID string) (roomNID int64, err error) { - err = s.insertRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) - return +func (s *statements) insertRoomNID(roomID string) (types.RoomNID, error) { + var roomNID int64 + err := s.insertRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err } -func (s *statements) selectRoomNID(roomID string) (roomNID int64, err error) { - err = s.selectRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) - return +func (s *statements) selectRoomNID(roomID string) (types.RoomNID, error) { + var roomNID int64 + err := s.selectRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err } const eventsSchema = ` @@ -404,16 +410,22 @@ func (s *statements) prepareEvents(db *sql.DB) (err error) { } func (s *statements) insertEvent( - roomNID, eventTypeNID, eventStateKeyNID int64, + roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, eventID string, referenceSHA256 []byte, - authEventNIDs []int64, -) (eventNID, stateNID int64, err error) { - err = s.insertEventStmt.QueryRow( - roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256, - pq.Int64Array(authEventNIDs), + authEventNIDs []types.EventNID, +) (types.EventNID, types.StateNID, error) { + nids := make([]int64, len(authEventNIDs)) + for i := range authEventNIDs { + nids[i] = int64(authEventNIDs[i]) + } + var eventNID int64 + var stateNID int64 + err := s.insertEventStmt.QueryRow( + int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, + pq.Int64Array(nids), ).Scan(&eventNID, &stateNID) - return + return types.EventNID(eventNID), types.StateNID(stateNID), err } func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) { @@ -477,8 +489,8 @@ func (s *statements) bulkSelectStateAtEventByID(eventIDs []string) ([]types.Stat return results, err } -func (s *statements) updateEventState(eventNID, stateNID int64) error { - _, err := s.updateEventStateStmt.Exec(eventNID, stateNID) +func (s *statements) updateEventState(eventNID types.EventNID, stateNID types.StateNID) error { + _, err := s.updateEventStateStmt.Exec(int64(eventNID), int64(stateNID)) return err } @@ -525,18 +537,22 @@ const bulkSelectEventJSONSQL = "" + " WHERE event_nid = ANY($1)" + " ORDER BY event_nid ASC" -func (s *statements) insertEventJSON(eventNID int64, eventJSON []byte) error { - _, err := s.insertEventJSONStmt.Exec(eventNID, eventJSON) +func (s *statements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error { + _, err := s.insertEventJSONStmt.Exec(int64(eventNID), eventJSON) return err } type eventJSONPair struct { - EventNID int64 + EventNID types.EventNID EventJSON []byte } -func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, error) { - rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(eventNIDs)) +func (s *statements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) { + nids := make([]int64, len(eventNIDs)) + for i := range eventNIDs { + nids[i] = int64(eventNIDs[i]) + } + rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(nids)) if err != nil { return nil, err } @@ -549,9 +565,12 @@ func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, er results := make([]eventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { - if err := rows.Scan(&results[i].EventNID, &results[i].EventJSON); err != nil { + result := &results[i] + var eventNID int64 + if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { return nil, err } + result.EventNID = types.EventNID(eventNID) } return results[:i], nil } @@ -602,13 +621,21 @@ func (s *statements) prepareState(db *sql.DB) (err error) { return } -func (s *statements) insertState(roomNID int64, stateDataNIDs []int64) (stateNID int64, err error) { - err = s.insertStateStmt.QueryRow(roomNID, pq.Int64Array(stateDataNIDs)).Scan(&stateNID) +func (s *statements) insertState(roomNID types.RoomNID, stateDataNIDs []types.StateDataNID) (stateNID types.StateNID, err error) { + nids := make([]int64, len(stateDataNIDs)) + for i := range stateDataNIDs { + nids[i] = int64(stateDataNIDs[i]) + } + err = s.insertStateStmt.QueryRow(int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) return } -func (s *statements) bulkSelectStateDataNIDs(stateNIDs []int64) ([]types.StateDataNIDList, error) { - rows, err := s.bulkSelectStateDataNIDsStmt.Query(pq.Int64Array(stateNIDs)) +func (s *statements) bulkSelectStateDataNIDs(stateNIDs []types.StateNID) ([]types.StateDataNIDList, error) { + nids := make([]int64, len(stateNIDs)) + for i := range stateNIDs { + nids[i] = int64(stateNIDs[i]) + } + rows, err := s.bulkSelectStateDataNIDsStmt.Query(pq.Int64Array(nids)) if err != nil { return nil, err } @@ -617,11 +644,14 @@ func (s *statements) bulkSelectStateDataNIDs(stateNIDs []int64) ([]types.StateDa i := 0 for ; rows.Next(); i++ { result := &results[i] - var stateDataNids pq.Int64Array - if err := rows.Scan(&result.StateNID, &stateDataNids); err != nil { + var stateDataNIDs pq.Int64Array + if err := rows.Scan(&result.StateNID, &stateDataNIDs); err != nil { return nil, err } - result.StateDataNIDs = stateDataNids + result.StateDataNIDs = make([]types.StateDataNID, len(stateDataNIDs)) + for k := range stateDataNIDs { + result.StateDataNIDs[k] = types.StateDataNID(stateDataNIDs[k]) + } } if i != len(stateNIDs) { return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) @@ -680,13 +710,13 @@ func (s *statements) prepareStateData(db *sql.DB) (err error) { return } -func (s *statements) bulkInsertStateData(stateDataNID int64, entries []types.StateEntry) error { +func (s *statements) bulkInsertStateData(stateDataNID types.StateDataNID, entries []types.StateEntry) error { for _, entry := range entries { _, err := s.insertStateDataStmt.Exec( - stateDataNID, - entry.EventTypeNID, - entry.EventStateKeyNID, - entry.EventNID, + int64(stateDataNID), + int64(entry.EventTypeNID), + int64(entry.EventStateKeyNID), + int64(entry.EventNID), ) if err != nil { return err @@ -695,13 +725,18 @@ func (s *statements) bulkInsertStateData(stateDataNID int64, entries []types.Sta return nil } -func (s *statements) selectNextStateDataNID() (stateDataNID int64, err error) { - err = s.selectNextStateDataNIDStmt.QueryRow().Scan(&stateDataNID) - return +func (s *statements) selectNextStateDataNID() (types.StateDataNID, error) { + var stateDataNID int64 + err := s.selectNextStateDataNIDStmt.QueryRow().Scan(&stateDataNID) + return types.StateDataNID(stateDataNID), err } -func (s *statements) bulkSelectStateDataEntries(stateDataNIDs []int64) ([]types.StateEntryList, error) { - rows, err := s.bulkSelectStateDataEntriesStmt.Query(pq.Int64Array(stateDataNIDs)) +func (s *statements) bulkSelectStateDataEntries(stateDataNIDs []types.StateDataNID) ([]types.StateEntryList, error) { + nids := make([]int64, len(stateDataNIDs)) + for i := range stateDataNIDs { + nids[i] = int64(stateDataNIDs[i]) + } + rows, err := s.bulkSelectStateDataEntriesStmt.Query(pq.Int64Array(nids)) if err != nil { return nil, err } @@ -712,19 +747,26 @@ func (s *statements) bulkSelectStateDataEntries(stateDataNIDs []int64) ([]types. var current *types.StateEntryList i := 0 for rows.Next() { - var stateDataNID int64 - var entry types.StateEntry + var ( + stateDataNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) if err := rows.Scan( - &stateDataNID, - &entry.EventTypeNID, &entry.EventStateKeyNID, &entry.EventNID, + &stateDataNID, &eventTypeNID, &eventStateKeyNID, &eventNID, ); err != nil { return nil, err } - if current == nil || stateDataNID != current.StateDataNID { + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + if current == nil || types.StateDataNID(stateDataNID) != current.StateDataNID { // The state entry row is for a different state data block to the current one. // So we start appending to the next entry in the list. current = &results[i] - current.StateDataNID = stateDataNID + current.StateDataNID = types.StateDataNID(stateDataNID) i++ } current.StateEntries = append(current.StateEntries, entry) diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index 7d46a9698..93ea2e579 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -38,13 +38,13 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6 } // StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) (int64, types.StateAtEvent, error) { +func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) { var ( - roomNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - stateNID int64 + roomNID types.RoomNID + eventTypeNID types.EventTypeNID + eventStateKeyNID types.EventStateKeyNID + eventNID types.EventNID + stateNID types.StateNID err error ) @@ -92,7 +92,7 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int }, nil } -func (d *Database) assignRoomNID(roomID string) (int64, error) { +func (d *Database) assignRoomNID(roomID string) (types.RoomNID, error) { // Check if we already have a numeric ID in the database. roomNID, err := d.statements.selectRoomNID(roomID) if err == sql.ErrNoRows { @@ -105,7 +105,7 @@ func (d *Database) assignRoomNID(roomID string) (int64, error) { return roomNID, nil } -func (d *Database) assignEventTypeNID(eventType string) (int64, error) { +func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, error) { // Check if we already have a numeric ID in the database. eventTypeNID, err := d.statements.selectEventTypeNID(eventType) if err == sql.ErrNoRows { @@ -118,7 +118,7 @@ func (d *Database) assignEventTypeNID(eventType string) (int64, error) { return eventTypeNID, nil } -func (d *Database) assignStateKeyNID(eventStateKey string) (int64, error) { +func (d *Database) assignStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) { // Check if we already have a numeric ID in the database. eventStateKeyNID, err := d.statements.selectEventStateKeyNID(eventStateKey) if err == sql.ErrNoRows { @@ -137,12 +137,12 @@ func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntr } // EventStateKeyNIDs implements input.EventDatabase -func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) { +func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) { return d.statements.bulkSelectEventStateKeyNID(eventStateKeys) } // Events implements input.EventDatabase -func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) { +func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) { eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs) if err != nil { return nil, err @@ -161,14 +161,14 @@ func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) { } // AddState implements input.EventDatabase -func (d *Database) AddState(roomNID int64, stateDataNIDs []int64, state []types.StateEntry) (stateNID int64, err error) { +func (d *Database) AddState(roomNID types.RoomNID, stateDataNIDs []types.StateDataNID, state []types.StateEntry) (types.StateNID, error) { if len(state) > 0 { - var stateDataNID int64 - if stateDataNID, err = d.statements.selectNextStateDataNID(); err != nil { - return + stateDataNID, err := d.statements.selectNextStateDataNID() + if err != nil { + return 0, err } if err = d.statements.bulkInsertStateData(stateDataNID, state); err != nil { - return + return 0, err } stateDataNIDs = append(stateDataNIDs[:len(stateDataNIDs):len(stateDataNIDs)], stateDataNID) } @@ -177,7 +177,7 @@ func (d *Database) AddState(roomNID int64, stateDataNIDs []int64, state []types. } // SetState implements input.EventDatabase -func (d *Database) SetState(eventNID, stateNID int64) error { +func (d *Database) SetState(eventNID types.EventNID, stateNID types.StateNID) error { return d.statements.updateEventState(eventNID, stateNID) } @@ -187,11 +187,11 @@ func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, err } // StateDataNIDs implements input.EventDatabase -func (d *Database) StateDataNIDs(stateNIDs []int64) ([]types.StateDataNIDList, error) { +func (d *Database) StateDataNIDs(stateNIDs []types.StateNID) ([]types.StateDataNIDList, error) { return d.statements.bulkSelectStateDataNIDs(stateNIDs) } // StateEntries implements input.EventDatabase -func (d *Database) StateEntries(stateDataNIDs []int64) ([]types.StateEntryList, error) { +func (d *Database) StateEntries(stateDataNIDs []types.StateDataNID) ([]types.StateEntryList, error) { return d.statements.bulkSelectStateDataEntries(stateDataNIDs) } diff --git a/src/github.com/matrix-org/dendrite/roomserver/types/types.go b/src/github.com/matrix-org/dendrite/roomserver/types/types.go index 096547a6f..5cdf59695 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/types/types.go +++ b/src/github.com/matrix-org/dendrite/roomserver/types/types.go @@ -13,13 +13,32 @@ type PartitionOffset struct { Offset int64 } +// EventTypeNID is a numeric ID for an event type. +type EventTypeNID int64 + +// EventStateKeyNID is a numeric ID for an event state_key. +type EventStateKeyNID int64 + +// EventNID is a numeric ID for an event. +type EventNID int64 + +// RoomNID is a numeric ID for a room. +type RoomNID int64 + +// StateNID is a numeric ID for the state at an event. +type StateNID int64 + +// StateDataNID is a numeric ID for a block of state data. +// These blocks of state data are combined to form the actual state. +type StateDataNID int64 + // A StateKeyTuple is a pair of a numeric event type and a numeric state key. // It is used to lookup state entries. type StateKeyTuple struct { // The numeric ID for the event type. - EventTypeNID int64 + EventTypeNID EventTypeNID // The numeric ID for the state key. - EventStateKeyNID int64 + EventStateKeyNID EventStateKeyNID } // LessThan returns true if this state key is less than the other state key. @@ -35,7 +54,7 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool { type StateEntry struct { StateKeyTuple // The numeric ID for the event. - EventNID int64 + EventNID EventNID } // LessThan returns true if this state entry is less than the other state entry. @@ -50,7 +69,7 @@ func (a StateEntry) LessThan(b StateEntry) bool { // StateAtEvent is the state before and after a matrix event. type StateAtEvent struct { // The state before the event. - BeforeStateNID int64 + BeforeStateNID StateNID // The state entry for the event itself, allows us to calculate the state after the event. StateEntry } @@ -58,7 +77,7 @@ type StateAtEvent struct { // An Event is a gomatrixserverlib.Event with the numeric event ID attached. // It is when performing bulk event lookup in the database. type Event struct { - EventNID int64 + EventNID EventNID gomatrixserverlib.Event } @@ -86,12 +105,12 @@ const ( // StateDataNIDList is used to return the result of bulk StateDataNID lookups from the database. type StateDataNIDList struct { - StateNID int64 - StateDataNIDs []int64 + StateNID StateNID + StateDataNIDs []StateDataNID } // StateEntryList is used to return the result of bulk state entry lookups from the database. type StateEntryList struct { - StateDataNID int64 + StateDataNID StateDataNID StateEntries []StateEntry }