mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-26 00:01:55 -06:00
Check that events pass authentication checks. (#4)
* Check that events pass authentication checks. Record the list of events that the event passes authentication checks against.
This commit is contained in:
parent
600f56b4b8
commit
fc4eb85379
|
@ -30,7 +30,13 @@ type InputRoomEvent struct {
|
||||||
Kind int
|
Kind int
|
||||||
// The event JSON for the event to add.
|
// The event JSON for the event to add.
|
||||||
Event []byte
|
Event []byte
|
||||||
|
// List of state event IDs that authenticate this event.
|
||||||
|
// These are likely derived from the "auth_events" JSON key of the event.
|
||||||
|
// But can be different because the "auth_events" key can be incomplete or wrong.
|
||||||
|
// For example many matrix events forget to reference the m.room.create event even though it is needed for auth.
|
||||||
|
// (since synapse allows this to happen we have to allow it as well.)
|
||||||
|
AuthEventIDs []string
|
||||||
// Optional list of state event IDs forming the state before this event.
|
// Optional list of state event IDs forming the state before this event.
|
||||||
// These state events must have already been persisted.
|
// These state events must have already been persisted.
|
||||||
State []string
|
StateEventIDs []string
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,12 +2,26 @@ package input
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"sort"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A RoomEventDatabase has the storage APIs needed to store a room event.
|
// A RoomEventDatabase has the storage APIs needed to store a room event.
|
||||||
type RoomEventDatabase interface {
|
type RoomEventDatabase interface {
|
||||||
StoreEvent(event gomatrixserverlib.Event) error
|
// Stores a matrix room event in the database
|
||||||
|
StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error
|
||||||
|
// Lookup the state entries for a list of string event IDs
|
||||||
|
// Returns a sorted list of state entries.
|
||||||
|
// Returns a error if the there is an error talking to the database
|
||||||
|
// or if the event IDs aren't in the database.
|
||||||
|
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error)
|
||||||
|
// Lookup the numeric IDs for a list of string event state keys.
|
||||||
|
// Returns a map from string state key to numeric ID for the state key.
|
||||||
|
EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error)
|
||||||
|
// Lookup the Events for a list of numeric event IDs.
|
||||||
|
// Returns a sorted list of events.
|
||||||
|
Events(eventNIDs []int64) ([]types.Event, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
||||||
|
@ -17,12 +31,16 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := db.StoreEvent(event); err != nil {
|
// Check that the event passes authentication checks and work out the numeric IDs for the auth events.
|
||||||
|
authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:
|
// Store the event
|
||||||
// * Check that the event passes authentication checks.
|
if err := db.StoreEvent(event, authEventNIDs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if input.Kind == api.KindOutlier {
|
if input.Kind == api.KindOutlier {
|
||||||
// For outliers we can stop after we've stored the event itself as it
|
// For outliers we can stop after we've stored the event itself as it
|
||||||
|
@ -44,3 +62,193 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
||||||
// - The changes to the current state of the room.
|
// - The changes to the current state of the room.
|
||||||
panic("Not implemented")
|
panic("Not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkAuthEvents checks that the event passes authentication checks
|
||||||
|
// Returns the numeric IDs for the auth events.
|
||||||
|
func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]int64, error) {
|
||||||
|
// Grab the numeric IDs for the supplied auth state events from the database.
|
||||||
|
authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// TODO: check for duplicate state keys here.
|
||||||
|
|
||||||
|
// Work out which of the state events we actually need.
|
||||||
|
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event})
|
||||||
|
|
||||||
|
// Load the actual auth events from the database.
|
||||||
|
authEvents, err := loadAuthEvents(db, stateNeeded, authStateEntries)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the event is allowed.
|
||||||
|
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the numeric IDs for the auth events.
|
||||||
|
result := make([]int64, len(authStateEntries))
|
||||||
|
for i := range authStateEntries {
|
||||||
|
result[i] = authStateEntries[i].EventNID
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type authEvents struct {
|
||||||
|
stateKeyNIDMap map[string]int64
|
||||||
|
state stateEntryMap
|
||||||
|
events eventMap
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create implements gomatrixserverlib.AuthEvents
|
||||||
|
func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) {
|
||||||
|
return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PowerLevels implements gomatrixserverlib.AuthEvents
|
||||||
|
func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) {
|
||||||
|
return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinRules implements gomatrixserverlib.AuthEvents
|
||||||
|
func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) {
|
||||||
|
return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Memmber implements gomatrixserverlib.AuthEvents
|
||||||
|
func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) {
|
||||||
|
return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ThirdPartyInvite implements gomatrixserverlib.AuthEvents
|
||||||
|
func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) {
|
||||||
|
return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserverlib.Event {
|
||||||
|
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
event, ok := ae.events.lookup(eventNID)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &event.Event
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ae *authEvents) lookupEvent(typeNID int64, stateKey string) *gomatrixserverlib.Event {
|
||||||
|
stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, stateKeyNID})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
event, ok := ae.events.lookup(eventNID)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &event.Event
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAuthEvents loads the events needed for authentication from the supplied room state.
|
||||||
|
func loadAuthEvents(
|
||||||
|
db RoomEventDatabase,
|
||||||
|
needed gomatrixserverlib.StateNeeded,
|
||||||
|
state []types.StateEntry,
|
||||||
|
) (result authEvents, err error) {
|
||||||
|
// Lookup the numeric IDs for the state keys needed for auth.
|
||||||
|
var neededStateKeys []string
|
||||||
|
neededStateKeys = append(neededStateKeys, needed.Member...)
|
||||||
|
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
|
||||||
|
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(neededStateKeys); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the events we need.
|
||||||
|
result.state = state
|
||||||
|
var eventNIDs []int64
|
||||||
|
keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed)
|
||||||
|
for _, keyTuple := range keyTuplesNeeded {
|
||||||
|
eventNID, ok := result.state.lookup(keyTuple)
|
||||||
|
if ok {
|
||||||
|
eventNIDs = append(eventNIDs, eventNID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if result.events, err = db.Events(eventNIDs); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
|
||||||
|
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
|
||||||
|
var keyTuples []types.StateKeyTuple
|
||||||
|
if stateNeeded.Create {
|
||||||
|
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
|
||||||
|
}
|
||||||
|
if stateNeeded.PowerLevels {
|
||||||
|
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID})
|
||||||
|
}
|
||||||
|
if stateNeeded.JoinRules {
|
||||||
|
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID})
|
||||||
|
}
|
||||||
|
for _, member := range stateNeeded.Member {
|
||||||
|
stateKeyNID, ok := stateKeyNIDMap[member]
|
||||||
|
if ok {
|
||||||
|
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, token := range stateNeeded.ThirdPartyInvite {
|
||||||
|
stateKeyNID, ok := stateKeyNIDMap[token]
|
||||||
|
if ok {
|
||||||
|
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return keyTuples
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map from event type, state key tuple to numeric event ID.
|
||||||
|
// Implemented using binary search on a sorted array.
|
||||||
|
type stateEntryMap []types.StateEntry
|
||||||
|
|
||||||
|
// lookup an entry in the event map.
|
||||||
|
func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok bool) {
|
||||||
|
// Since the list is sorted we can implement this using binary search.
|
||||||
|
// This is faster than using a hash map.
|
||||||
|
// We don't have to worry about pathological cases because the keys are fixed
|
||||||
|
// size and are controlled by us.
|
||||||
|
list := []types.StateEntry(m)
|
||||||
|
i := sort.Search(len(list), func(i int) bool {
|
||||||
|
return !list[i].StateKeyTuple.LessThan(stateKey)
|
||||||
|
})
|
||||||
|
if i < len(list) && list[i].StateKeyTuple == stateKey {
|
||||||
|
ok = true
|
||||||
|
eventNID = list[i].EventNID
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map from numeric event ID to event.
|
||||||
|
// Implemented using binary search on a sorted array.
|
||||||
|
type eventMap []types.Event
|
||||||
|
|
||||||
|
// lookup an entry in the event map.
|
||||||
|
func (m eventMap) lookup(eventNID int64) (event *types.Event, ok bool) {
|
||||||
|
// Since the list is sorted we can implement this using binary search.
|
||||||
|
// This is faster than using a hash map.
|
||||||
|
// We don't have to worry about pathological cases because the keys are fixed
|
||||||
|
// size are controlled by us.
|
||||||
|
list := []types.Event(m)
|
||||||
|
i := sort.Search(len(list), func(i int) bool {
|
||||||
|
return list[i].EventNID >= eventNID
|
||||||
|
})
|
||||||
|
if i < len(list) && list[i].EventNID == eventNID {
|
||||||
|
ok = true
|
||||||
|
event = &list[i]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
package input
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) {
|
||||||
|
var list []types.StateEntry
|
||||||
|
for i := int64(0); i < entries; i++ {
|
||||||
|
list = append(list, types.StateEntry{types.StateKeyTuple{i, i}, i})
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
entryMap := stateEntryMap(list)
|
||||||
|
for j := int64(0); j < lookups; j++ {
|
||||||
|
entryMap.lookup(types.StateKeyTuple{j, j})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStateEntryMap100Lookup10(b *testing.B) {
|
||||||
|
benchmarkStateEntryMapLookup(100, 10, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStateEntryMap1000Lookup100(b *testing.B) {
|
||||||
|
benchmarkStateEntryMapLookup(1000, 100, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStateEntryMap100Lookup100(b *testing.B) {
|
||||||
|
benchmarkStateEntryMapLookup(100, 100, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStateEntryMap1000Lookup10000(b *testing.B) {
|
||||||
|
benchmarkStateEntryMapLookup(1000, 10000, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateEntryMap(t *testing.T) {
|
||||||
|
entryMap := stateEntryMap([]types.StateEntry{
|
||||||
|
{types.StateKeyTuple{1, 1}, 1},
|
||||||
|
{types.StateKeyTuple{1, 3}, 2},
|
||||||
|
{types.StateKeyTuple{2, 1}, 3},
|
||||||
|
})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
inputTypeNID int64
|
||||||
|
inputStateKey int64
|
||||||
|
wantOK bool
|
||||||
|
wantEventNID int64
|
||||||
|
}{
|
||||||
|
// Check that tuples that in the array are in the map.
|
||||||
|
{1, 1, true, 1},
|
||||||
|
{1, 3, true, 2},
|
||||||
|
{2, 1, true, 3},
|
||||||
|
// Check that tuples that aren't in the array aren't in the map.
|
||||||
|
{0, 0, false, 0},
|
||||||
|
{1, 2, false, 0},
|
||||||
|
{3, 1, false, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
keyTuple := types.StateKeyTuple{testCase.inputTypeNID, testCase.inputStateKey}
|
||||||
|
gotEventNID, gotOK := entryMap.lookup(keyTuple)
|
||||||
|
if testCase.wantOK != gotOK {
|
||||||
|
t.Fatalf("stateEntryMap lookup(%v): want ok to be %v, got %v", keyTuple, testCase.wantOK, gotOK)
|
||||||
|
}
|
||||||
|
if testCase.wantEventNID != gotEventNID {
|
||||||
|
t.Fatalf("stateEntryMap lookup(%v): want eventNID to be %v, got %v", keyTuple, testCase.wantEventNID, gotEventNID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventMap(t *testing.T) {
|
||||||
|
events := eventMap([]types.Event{
|
||||||
|
{EventNID: 1},
|
||||||
|
{EventNID: 2},
|
||||||
|
{EventNID: 3},
|
||||||
|
{EventNID: 5},
|
||||||
|
{EventNID: 8},
|
||||||
|
})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
inputEventNID int64
|
||||||
|
wantOK bool
|
||||||
|
wantEvent *types.Event
|
||||||
|
}{
|
||||||
|
// Check that the IDs that are in the array are in the map.
|
||||||
|
{1, true, &events[0]},
|
||||||
|
{2, true, &events[1]},
|
||||||
|
{3, true, &events[2]},
|
||||||
|
{5, true, &events[3]},
|
||||||
|
{8, true, &events[4]},
|
||||||
|
// Check that tuples that aren't in the array aren't in the map.
|
||||||
|
{0, false, nil},
|
||||||
|
{4, false, nil},
|
||||||
|
{6, false, nil},
|
||||||
|
{7, false, nil},
|
||||||
|
{9, false, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
gotEvent, gotOK := events.lookup(testCase.inputEventNID)
|
||||||
|
if testCase.wantOK != gotOK {
|
||||||
|
t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if testCase.wantEvent != gotEvent {
|
||||||
|
t.Fatalf("eventMap lookup(%v): want event to be %v, got %v", testCase.inputEventNID, testCase.wantEvent, gotEvent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -2,20 +2,25 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type statements struct {
|
type statements struct {
|
||||||
selectPartitionOffsetsStmt *sql.Stmt
|
selectPartitionOffsetsStmt *sql.Stmt
|
||||||
upsertPartitionOffsetStmt *sql.Stmt
|
upsertPartitionOffsetStmt *sql.Stmt
|
||||||
insertEventTypeNIDStmt *sql.Stmt
|
insertEventTypeNIDStmt *sql.Stmt
|
||||||
selectEventTypeNIDStmt *sql.Stmt
|
selectEventTypeNIDStmt *sql.Stmt
|
||||||
insertEventStateKeyNIDStmt *sql.Stmt
|
insertEventStateKeyNIDStmt *sql.Stmt
|
||||||
selectEventStateKeyNIDStmt *sql.Stmt
|
selectEventStateKeyNIDStmt *sql.Stmt
|
||||||
insertRoomNIDStmt *sql.Stmt
|
bulkSelectEventStateKeyNIDStmt *sql.Stmt
|
||||||
selectRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
insertEventStmt *sql.Stmt
|
selectRoomNIDStmt *sql.Stmt
|
||||||
insertEventJSONStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
|
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||||
|
insertEventJSONStmt *sql.Stmt
|
||||||
|
bulkSelectEventJSONStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statements) prepare(db *sql.DB) error {
|
func (s *statements) prepare(db *sql.DB) error {
|
||||||
|
@ -196,6 +201,9 @@ func (s *statements) prepareEventStateKeys(db *sql.DB) (err error) {
|
||||||
if s.selectEventStateKeyNIDStmt, err = db.Prepare(selectEventStateKeyNIDSQL); err != nil {
|
if s.selectEventStateKeyNIDStmt, err = db.Prepare(selectEventStateKeyNIDSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.bulkSelectEventStateKeyNIDStmt, err = db.Prepare(bulkSelectEventStateKeyNIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,6 +238,12 @@ const insertEventStateKeyNIDSQL = "" +
|
||||||
const selectEventStateKeyNIDSQL = "" +
|
const selectEventStateKeyNIDSQL = "" +
|
||||||
"SELECT event_state_key_nid FROM event_state_keys WHERE event_state_key = $1"
|
"SELECT event_state_key_nid FROM event_state_keys WHERE event_state_key = $1"
|
||||||
|
|
||||||
|
// Bulk lookup from string state key to numeric ID for that state key.
|
||||||
|
// Takes an array of strings as the query parameter.
|
||||||
|
const bulkSelectEventStateKeyNIDSQL = "" +
|
||||||
|
"SELECT event_state_key, event_state_key_nid FROM event_state_keys" +
|
||||||
|
" WHERE event_state_key = ANY($1)"
|
||||||
|
|
||||||
func (s *statements) insertEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) {
|
func (s *statements) insertEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) {
|
||||||
err = s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
err = s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
||||||
return
|
return
|
||||||
|
@ -240,6 +254,25 @@ func (s *statements) selectEventStateKeyNID(eventStateKey string) (eventStateKey
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]int64, error) {
|
||||||
|
rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
result := make(map[string]int64, len(eventStateKeys))
|
||||||
|
for rows.Next() {
|
||||||
|
var stateKey string
|
||||||
|
var stateKeyNID int64
|
||||||
|
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[stateKey] = stateKeyNID
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *statements) prepareRooms(db *sql.DB) (err error) {
|
func (s *statements) prepareRooms(db *sql.DB) (err error) {
|
||||||
_, err = db.Exec(roomsSchema)
|
_, err = db.Exec(roomsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -307,17 +340,27 @@ CREATE TABLE IF NOT EXISTS events (
|
||||||
event_id TEXT NOT NULL CONSTRAINT event_id_unique UNIQUE,
|
event_id TEXT NOT NULL CONSTRAINT event_id_unique UNIQUE,
|
||||||
-- The sha256 reference hash for the event.
|
-- The sha256 reference hash for the event.
|
||||||
-- Needed for setting reference hashes when sending new events.
|
-- Needed for setting reference hashes when sending new events.
|
||||||
reference_sha256 BYTEA NOT NULL
|
reference_sha256 BYTEA NOT NULL,
|
||||||
|
-- A list of numeric IDs for events that can authenticate this event.
|
||||||
|
auth_event_nids BIGINT[] NOT NULL,
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertEventSQL = "" +
|
const insertEventSQL = "" +
|
||||||
"INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256)" +
|
"INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids)" +
|
||||||
" VALUES ($1, $2, $3, $4, $5)" +
|
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
" ON CONFLICT ON CONSTRAINT event_id_unique" +
|
" ON CONFLICT ON CONSTRAINT event_id_unique" +
|
||||||
" DO UPDATE SET event_id = $1" +
|
" DO UPDATE SET event_id = $1" +
|
||||||
" RETURNING event_nid"
|
" RETURNING event_nid"
|
||||||
|
|
||||||
|
// Bulk lookup of events by string ID.
|
||||||
|
// Sort by the numeric IDs for event type and state key.
|
||||||
|
// This means we can use binary search to lookup entries by type and state key.
|
||||||
|
const bulkSelectStateEventByIDSQL = "" +
|
||||||
|
"SELECT event_type_nid, event_state_key_nid, event_nid FROM events" +
|
||||||
|
" WHERE event_id = ANY($1)" +
|
||||||
|
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||||
|
|
||||||
func (s *statements) prepareEvents(db *sql.DB) (err error) {
|
func (s *statements) prepareEvents(db *sql.DB) (err error) {
|
||||||
_, err = db.Exec(eventsSchema)
|
_, err = db.Exec(eventsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -326,6 +369,9 @@ func (s *statements) prepareEvents(db *sql.DB) (err error) {
|
||||||
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
|
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,13 +379,48 @@ func (s *statements) insertEvent(
|
||||||
roomNID, eventTypeNID, eventStateKeyNID int64,
|
roomNID, eventTypeNID, eventStateKeyNID int64,
|
||||||
eventID string,
|
eventID string,
|
||||||
referenceSHA256 []byte,
|
referenceSHA256 []byte,
|
||||||
|
authEventNIDs []int64,
|
||||||
) (eventNID int64, err error) {
|
) (eventNID int64, err error) {
|
||||||
err = s.insertEventStmt.QueryRow(
|
err = s.insertEventStmt.QueryRow(
|
||||||
roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256,
|
roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256,
|
||||||
|
pq.Int64Array(authEventNIDs),
|
||||||
).Scan(&eventNID)
|
).Scan(&eventNID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) {
|
||||||
|
rows, err := s.bulkSelectStateEventByIDStmt.Query(pq.StringArray(eventIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
// We know that we will only get as many results as event IDs
|
||||||
|
// because of the unique constraint on event IDs.
|
||||||
|
// So we can allocate an array of the correct size now.
|
||||||
|
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
|
||||||
|
results := make([]types.StateEntry, len(eventIDs))
|
||||||
|
i := 0
|
||||||
|
for ; rows.Next(); i++ {
|
||||||
|
result := &results[i]
|
||||||
|
if err = rows.Scan(
|
||||||
|
&result.EventNID,
|
||||||
|
&result.EventTypeNID,
|
||||||
|
&result.EventStateKeyNID,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i != len(eventIDs) {
|
||||||
|
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
||||||
|
// We don't know which ones were missing because we don't return the string IDs in the query.
|
||||||
|
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
|
||||||
|
// If this turns out to be impossible and we do need the debug information here, it would be better
|
||||||
|
// to do it as a separate query rather than slowing down/complicating the common case.
|
||||||
|
return nil, fmt.Errorf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs))
|
||||||
|
}
|
||||||
|
return results, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *statements) prepareEventJSON(db *sql.DB) (err error) {
|
func (s *statements) prepareEventJSON(db *sql.DB) (err error) {
|
||||||
_, err = db.Exec(eventJSONSchema)
|
_, err = db.Exec(eventJSONSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -348,6 +429,9 @@ func (s *statements) prepareEventJSON(db *sql.DB) (err error) {
|
||||||
if s.insertEventJSONStmt, err = db.Prepare(insertEventJSONSQL); err != nil {
|
if s.insertEventJSONStmt, err = db.Prepare(insertEventJSONSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.bulkSelectEventJSONStmt, err = db.Prepare(bulkSelectEventJSONSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -372,7 +456,41 @@ const insertEventJSONSQL = "" +
|
||||||
"INSERT INTO event_json (event_nid, event_json) VALUES ($1, $2)" +
|
"INSERT INTO event_json (event_nid, event_json) VALUES ($1, $2)" +
|
||||||
" ON CONFLICT DO NOTHING"
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
|
// Bulk event JSON lookup by numeric event ID.
|
||||||
|
// Sort by the numeric event ID.
|
||||||
|
// This means that we can use binary search to lookup by numeric event ID.
|
||||||
|
const bulkSelectEventJSONSQL = "" +
|
||||||
|
"SELECT event_nid, event_json FROM event_json" +
|
||||||
|
" WHERE event_nid = ANY($1)" +
|
||||||
|
" ORDER BY event_nid ASC"
|
||||||
|
|
||||||
func (s *statements) insertEventJSON(eventNID int64, eventJSON []byte) error {
|
func (s *statements) insertEventJSON(eventNID int64, eventJSON []byte) error {
|
||||||
_, err := s.insertEventJSONStmt.Exec(eventNID, eventJSON)
|
_, err := s.insertEventJSONStmt.Exec(eventNID, eventJSON)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type eventJSONPair struct {
|
||||||
|
EventNID int64
|
||||||
|
EventJSON []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, error) {
|
||||||
|
rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(eventNIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
// We know that we will only get as many results as event NIDs
|
||||||
|
// because of the unique constraint on event NIDs.
|
||||||
|
// So we can allocate an array of the correct size now.
|
||||||
|
// We might get fewer results than NIDs so we adjust the length of the slice before returning it.
|
||||||
|
results := make([]eventJSONPair, len(eventNIDs))
|
||||||
|
i := 0
|
||||||
|
for ; rows.Next(); i++ {
|
||||||
|
if err := rows.Scan(&results[i].EventNID, &results[i].EventJSON); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results[:i], nil
|
||||||
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
|
||||||
}
|
}
|
||||||
|
|
||||||
// StoreEvent implements input.EventDatabase
|
// StoreEvent implements input.EventDatabase
|
||||||
func (d *Database) StoreEvent(event gomatrixserverlib.Event) error {
|
func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error {
|
||||||
var (
|
var (
|
||||||
roomNID int64
|
roomNID int64
|
||||||
eventTypeNID int64
|
eventTypeNID int64
|
||||||
|
@ -70,6 +70,7 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event) error {
|
||||||
eventStateKeyNID,
|
eventStateKeyNID,
|
||||||
event.EventID(),
|
event.EventID(),
|
||||||
event.EventReference().EventSHA256,
|
event.EventReference().EventSHA256,
|
||||||
|
authEventNIDs,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -115,3 +116,32 @@ func (d *Database) assignStateKeyNID(eventStateKey string) (int64, error) {
|
||||||
}
|
}
|
||||||
return eventStateKeyNID, nil
|
return eventStateKeyNID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StateEntriesForEventIDs implements input.EventDatabase
|
||||||
|
func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) {
|
||||||
|
return d.statements.bulkSelectStateEventByID(eventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventStateKeyNIDs implements input.EventDatabase
|
||||||
|
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) {
|
||||||
|
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Events implements input.EventDatabase
|
||||||
|
func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) {
|
||||||
|
eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results := make([]types.Event, len(eventJSONs))
|
||||||
|
for i, eventJSON := range eventJSONs {
|
||||||
|
result := &results[i]
|
||||||
|
result.EventNID = eventJSON.EventNID
|
||||||
|
// TODO: Use NewEventFromTrustedJSON for efficiency
|
||||||
|
result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
// Package types provides the types that are used internally within the roomserver.
|
// Package types provides the types that are used internally within the roomserver.
|
||||||
package types
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
// A PartitionOffset is the offset into a partition of the input log.
|
// A PartitionOffset is the offset into a partition of the input log.
|
||||||
type PartitionOffset struct {
|
type PartitionOffset struct {
|
||||||
// The ID of the partition.
|
// The ID of the partition.
|
||||||
|
@ -8,3 +12,66 @@ type PartitionOffset struct {
|
||||||
// The offset into the partition.
|
// The offset into the partition.
|
||||||
Offset int64
|
Offset int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A StateKeyTuple is a pair of a numeric event type and a numeric state key.
|
||||||
|
// It is used to lookup state entries.
|
||||||
|
type StateKeyTuple struct {
|
||||||
|
// The numeric ID for the event type.
|
||||||
|
EventTypeNID int64
|
||||||
|
// The numeric ID for the state key.
|
||||||
|
EventStateKeyNID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// LessThan returns true if this state key is less than the other state key.
|
||||||
|
// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries.
|
||||||
|
func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
|
||||||
|
if a.EventTypeNID != b.EventTypeNID {
|
||||||
|
return a.EventTypeNID < b.EventTypeNID
|
||||||
|
}
|
||||||
|
return a.EventStateKeyNID < b.EventStateKeyNID
|
||||||
|
}
|
||||||
|
|
||||||
|
// A StateEntry is an entry in the room state of a matrix room.
|
||||||
|
type StateEntry struct {
|
||||||
|
StateKeyTuple
|
||||||
|
// The numeric ID for the event.
|
||||||
|
EventNID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// LessThan returns true if this state entry is less than the other state entry.
|
||||||
|
// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries.
|
||||||
|
func (a StateEntry) LessThan(b StateEntry) bool {
|
||||||
|
if a.StateKeyTuple != b.StateKeyTuple {
|
||||||
|
return a.StateKeyTuple.LessThan(b.StateKeyTuple)
|
||||||
|
}
|
||||||
|
return a.EventNID < b.EventNID
|
||||||
|
}
|
||||||
|
|
||||||
|
// An Event is a gomatrixserverlib.Event with the numeric event ID attached.
|
||||||
|
// It is when performing bulk event lookup in the database.
|
||||||
|
type Event struct {
|
||||||
|
EventNID int64
|
||||||
|
gomatrixserverlib.Event
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MRoomCreateNID is the numeric ID for the "m.room.create" event type.
|
||||||
|
MRoomCreateNID = 1
|
||||||
|
// MRoomPowerLevelsNID is the numeric ID for the "m.room.power_levels" event type.
|
||||||
|
MRoomPowerLevelsNID = 2
|
||||||
|
// MRoomJoinRulesNID is the numeric ID for the "m.room.join_rules" event type.
|
||||||
|
MRoomJoinRulesNID = 3
|
||||||
|
// MRoomThirdPartyInviteNID is the numeric ID for the "m.room.third_party_invite" event type.
|
||||||
|
MRoomThirdPartyInviteNID = 4
|
||||||
|
// MRoomMemberNID is the numeric ID for the "m.room.member" event type.
|
||||||
|
MRoomMemberNID = 5
|
||||||
|
// MRoomRedactionNID is the numeric ID for the "m.room.redaction" event type.
|
||||||
|
MRoomRedactionNID = 6
|
||||||
|
// MRoomHistoryVisibilityNID is the numeric ID for the "m.room.history_visibility" event type.
|
||||||
|
MRoomHistoryVisibilityNID = 7
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// EmptyStateKeyNID is the numeric ID for the empty state key.
|
||||||
|
EmptyStateKeyNID = 1
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in a new issue