Refactor StoreEvent and create a new RoomDatabase interface (#2985)

This PR changes a few things:
- It pulls out the creation of several NIDs from the `StoreEvent`
function to make the functions more reusable
- Uses more caching when using those NIDs to avoid DB round trips
This commit is contained in:
Till 2023-02-24 09:40:20 +01:00 committed by GitHub
parent e6aa0955ff
commit ad07b169b8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 459 additions and 302 deletions

View file

@ -87,7 +87,7 @@ func main() {
} }
var eventEntries []types.Event var eventEntries []types.Event
eventEntries, err = roomserverDB.Events(ctx, eventNIDs) eventEntries, err = roomserverDB.Events(ctx, 0, eventNIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -145,7 +145,7 @@ func main() {
} }
fmt.Println("Fetching", len(eventNIDMap), "state events") fmt.Println("Fetching", len(eventNIDMap), "state events")
eventEntries, err := roomserverDB.Events(ctx, eventNIDs) eventEntries, err := roomserverDB.Events(ctx, 0, eventNIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -165,7 +165,7 @@ func main() {
} }
fmt.Println("Fetching", len(authEventIDs), "auth events") fmt.Println("Fetching", len(authEventIDs), "auth events")
authEventEntries, err := roomserverDB.EventsFromIDs(ctx, authEventIDs) authEventEntries, err := roomserverDB.EventsFromIDs(ctx, 0, authEventIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -7,6 +7,7 @@ import "github.com/matrix-org/dendrite/roomserver/types"
type EventStateKeyCache interface { type EventStateKeyCache interface {
GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool) GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool)
StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string) StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string)
GetEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, bool)
} }
func (c Caches) GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool) { func (c Caches) GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool) {
@ -15,4 +16,23 @@ func (c Caches) GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (strin
func (c Caches) StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string) { func (c Caches) StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string) {
c.RoomServerStateKeys.Set(eventStateKeyNID, eventStateKey) c.RoomServerStateKeys.Set(eventStateKeyNID, eventStateKey)
c.RoomServerStateKeyNIDs.Set(eventStateKey, eventStateKeyNID)
}
func (c Caches) GetEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, bool) {
return c.RoomServerStateKeyNIDs.Get(eventStateKey)
}
type EventTypeCache interface {
GetEventTypeKey(eventType string) (types.EventTypeNID, bool)
StoreEventTypeKey(eventTypeNID types.EventTypeNID, eventType string)
}
func (c Caches) StoreEventTypeKey(eventTypeNID types.EventTypeNID, eventType string) {
c.RoomServerEventTypeNIDs.Set(eventType, eventTypeNID)
c.RoomServerEventTypes.Set(eventTypeNID, eventType)
}
func (c Caches) GetEventTypeKey(eventType string) (types.EventTypeNID, bool) {
return c.RoomServerEventTypeNIDs.Get(eventType)
} }

View file

@ -9,19 +9,28 @@ type RoomServerCaches interface {
RoomVersionCache RoomVersionCache
RoomServerEventsCache RoomServerEventsCache
EventStateKeyCache EventStateKeyCache
EventTypeCache
} }
// RoomServerNIDsCache contains the subset of functions needed for // RoomServerNIDsCache contains the subset of functions needed for
// a roomserver NID cache. // a roomserver NID cache.
type RoomServerNIDsCache interface { type RoomServerNIDsCache interface {
GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool)
// StoreRoomServerRoomID stores roomNID -> roomID and roomID -> roomNID
StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string)
GetRoomServerRoomNID(roomID string) (types.RoomNID, bool)
} }
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
return c.RoomServerRoomIDs.Get(roomNID) return c.RoomServerRoomIDs.Get(roomNID)
} }
// StoreRoomServerRoomID stores roomNID -> roomID and roomID -> roomNID
func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) { func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) {
c.RoomServerRoomNIDs.Set(roomID, roomNID)
c.RoomServerRoomIDs.Set(roomNID, roomID) c.RoomServerRoomIDs.Set(roomNID, roomID)
} }
func (c Caches) GetRoomServerRoomNID(roomID string) (types.RoomNID, bool) {
return c.RoomServerRoomNIDs.Get(roomID)
}

View file

@ -28,7 +28,10 @@ type Caches struct {
RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID
RoomServerRoomIDs Cache[types.RoomNID, string] // room NID -> room ID RoomServerRoomIDs Cache[types.RoomNID, string] // room NID -> room ID
RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event
RoomServerStateKeys Cache[types.EventStateKeyNID, string] // event NID -> event state key RoomServerStateKeys Cache[types.EventStateKeyNID, string] // eventStateKey NID -> event state key
RoomServerStateKeyNIDs Cache[string, types.EventStateKeyNID] // event state key -> eventStateKey NID
RoomServerEventTypeNIDs Cache[string, types.EventTypeNID] // eventType -> eventType NID
RoomServerEventTypes Cache[types.EventTypeNID, string] // eventType NID -> eventType
FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU
FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU
SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response

View file

@ -40,6 +40,9 @@ const (
spaceSummaryRoomsCache spaceSummaryRoomsCache
lazyLoadingCache lazyLoadingCache
eventStateKeyCache eventStateKeyCache
eventTypeCache
eventTypeNIDCache
eventStateKeyNIDCache
) )
func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enablePrometheus bool) *Caches { func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enablePrometheus bool) *Caches {
@ -105,6 +108,21 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm
Prefix: eventStateKeyCache, Prefix: eventStateKeyCache,
MaxAge: maxAge, MaxAge: maxAge,
}, },
RoomServerStateKeyNIDs: &RistrettoCachePartition[string, types.EventStateKeyNID]{ // eventStateKey -> eventStateKey NID
cache: cache,
Prefix: eventStateKeyNIDCache,
MaxAge: maxAge,
},
RoomServerEventTypeNIDs: &RistrettoCachePartition[string, types.EventTypeNID]{ // eventType -> eventType NID
cache: cache,
Prefix: eventTypeCache,
MaxAge: maxAge,
},
RoomServerEventTypes: &RistrettoCachePartition[types.EventTypeNID, string]{ // eventType NID -> eventType
cache: cache,
Prefix: eventTypeNIDCache,
MaxAge: maxAge,
},
FederationPDUs: &RistrettoCostedCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ // queue NID -> PDU FederationPDUs: &RistrettoCostedCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ // queue NID -> PDU
&RistrettoCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ &RistrettoCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{
cache: cache, cache: cache,

View file

@ -30,26 +30,6 @@ import (
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
// RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API.
type RoomserverInternalAPIDatabase interface {
// Save a given room alias with the room ID it refers to.
// Returns an error if there was a problem talking to the database.
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error
// Look up the room ID a given alias refers to.
// Returns an error if there was a problem talking to the database.
GetRoomIDForAlias(ctx context.Context, alias string) (string, error)
// Look up all aliases referring to a given room ID.
// Returns an error if there was a problem talking to the database.
GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
// Remove a given room alias.
// Returns an error if there was a problem talking to the database.
RemoveRoomAlias(ctx context.Context, alias string) error
// Look up the room version for a given room.
GetRoomVersionForRoom(
ctx context.Context, roomID string,
) (gomatrixserverlib.RoomVersion, error)
}
// SetRoomAlias implements alias.RoomserverInternalAPI // SetRoomAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) SetRoomAlias( func (r *RoomserverInternalAPI) SetRoomAlias(
ctx context.Context, ctx context.Context,

View file

@ -155,7 +155,6 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.Unpeeker = &perform.Unpeeker{ r.Unpeeker = &perform.Unpeeker{
ServerName: r.ServerName, ServerName: r.ServerName,
Cfg: r.Cfg, Cfg: r.Cfg,
DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
Inputer: r.Inputer, Inputer: r.Inputer,
} }

View file

@ -31,7 +31,8 @@ import (
// the soft-fail bool. // the soft-fail bool.
func CheckForSoftFail( func CheckForSoftFail(
ctx context.Context, ctx context.Context,
db storage.Database, db storage.RoomDatabase,
roomInfo *types.RoomInfo,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
stateEventIDs []string, stateEventIDs []string,
) (bool, error) { ) (bool, error) {
@ -45,16 +46,6 @@ func CheckForSoftFail(
return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err) return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err)
} }
} else { } else {
// Work out if the room exists.
var roomInfo *types.RoomInfo
roomInfo, err = db.RoomInfo(ctx, event.RoomID())
if err != nil {
return false, fmt.Errorf("db.RoomNID: %w", err)
}
if roomInfo == nil || roomInfo.IsStub() {
return false, nil
}
// Then get the state entries for the current state snapshot. // Then get the state entries for the current state snapshot.
// We'll use this to check if the event is allowed right now. // We'll use this to check if the event is allowed right now.
roomState := state.NewStateResolution(db, roomInfo) roomState := state.NewStateResolution(db, roomInfo)
@ -76,7 +67,7 @@ func CheckForSoftFail(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries)
if err != nil { if err != nil {
return true, fmt.Errorf("loadAuthEvents: %w", err) return true, fmt.Errorf("loadAuthEvents: %w", err)
} }
@ -93,7 +84,8 @@ func CheckForSoftFail(
// Returns the numeric IDs for the auth events. // Returns the numeric IDs for the auth events.
func CheckAuthEvents( func CheckAuthEvents(
ctx context.Context, ctx context.Context,
db storage.Database, db storage.RoomDatabase,
roomNID types.RoomNID,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
authEventIDs []string, authEventIDs []string,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
@ -108,7 +100,7 @@ func CheckAuthEvents(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries)
if err != nil { if err != nil {
return nil, fmt.Errorf("loadAuthEvents: %w", err) return nil, fmt.Errorf("loadAuthEvents: %w", err)
} }
@ -201,6 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
func loadAuthEvents( func loadAuthEvents(
ctx context.Context, ctx context.Context,
db state.StateResolutionStorage, db state.StateResolutionStorage,
roomNID types.RoomNID,
needed gomatrixserverlib.StateNeeded, needed gomatrixserverlib.StateNeeded,
state []types.StateEntry, state []types.StateEntry,
) (result authEvents, err error) { ) (result authEvents, err error) {
@ -223,7 +216,7 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID) eventNIDs = append(eventNIDs, eventNID)
} }
} }
if result.events, err = db.Events(ctx, eventNIDs); err != nil { if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil {
return return
} }
roomID := "" roomID := ""

View file

@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
return false, err return false, err
} }
events, err := db.Events(ctx, eventNIDs) events, err := db.Events(ctx, info.RoomNID, eventNIDs)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -157,7 +157,7 @@ func IsInvitePending(
// only keep the "m.room.member" events with a "join" membership. These events are returned. // only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events. // Returns an error if there was an issue fetching the events.
func GetMembershipsAtState( func GetMembershipsAtState(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) { ) ([]types.Event, error) {
var eventNIDs types.EventNIDs var eventNIDs types.EventNIDs
@ -177,7 +177,7 @@ func GetMembershipsAtState(
util.Unique(eventNIDs) util.Unique(eventNIDs)
// Get all of the events in this state // Get all of the events in this state
stateEvents, err := db.Events(ctx, eventNIDs) stateEvents, err := db.Events(ctx, roomNID, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -220,16 +220,16 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState) return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
} }
func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info) roomState := state.NewStateResolution(db, info)
// Fetch the state as it was when this event was fired // Fetch the state as it was when this event was fired
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
} }
func LoadEvents( func LoadEvents(
ctx context.Context, db storage.Database, eventNIDs []types.EventNID, ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) { ) ([]*gomatrixserverlib.Event, error) {
stateEvents, err := db.Events(ctx, eventNIDs) stateEvents, err := db.Events(ctx, roomNID, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,13 +242,13 @@ func LoadEvents(
} }
func LoadStateEvents( func LoadStateEvents(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry,
) ([]*gomatrixserverlib.Event, error) { ) ([]*gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries)) eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries { for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID eventNIDs[i] = stateEntries[i].EventNID
} }
return LoadEvents(ctx, db, eventNIDs) return LoadEvents(ctx, db, roomNID, eventNIDs)
} }
func CheckServerAllowedToSeeEvent( func CheckServerAllowedToSeeEvent(
@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState(
return nil, nil return nil, nil
} }
return LoadStateEvents(ctx, db, filteredEntries) return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries)
} }
// TODO: Remove this when we have tests to assert correctness of this function // TODO: Remove this when we have tests to assert correctness of this function
@ -366,7 +366,7 @@ BFSLoop:
next = make([]string, 0) next = make([]string, 0)
} }
// Retrieve the events to process from the database. // Retrieve the events to process from the database.
events, err = db.EventsFromIDs(ctx, front) events, err = db.EventsFromIDs(ctx, info.RoomNID, front)
if err != nil { if err != nil {
return resultNIDs, redactEventIDs, err return resultNIDs, redactEventIDs, err
} }
@ -467,7 +467,7 @@ func QueryLatestEventsAndState(
return err return err
} }
stateEvents, err := LoadStateEvents(ctx, db, stateEntries) stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries)
if err != nil { if err != nil {
return err return err
} }

View file

@ -38,7 +38,18 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
var authNIDs []types.EventNID var authNIDs []types.EventNID
for _, x := range room.Events() { for _, x := range room.Events() {
evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false) roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap())
assert.NoError(t, err)
assert.Greater(t, roomNID, types.RoomNID(0))
eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type())
assert.NoError(t, err)
assert.Greater(t, eventTypeNID, types.EventTypeNID(0))
eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey())
assert.NoError(t, err)
evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false)
assert.NoError(t, err) assert.NoError(t, err)
authNIDs = append(authNIDs, evNID) authNIDs = append(authNIDs, evNID)
} }

View file

@ -76,7 +76,7 @@ type Inputer struct {
Cfg *config.RoomServer Cfg *config.RoomServer
Base *base.BaseDendrite Base *base.BaseDendrite
ProcessContext *process.ProcessContext ProcessContext *process.ProcessContext
DB storage.Database DB storage.RoomDatabase
NATSClient *nats.Conn NATSClient *nats.Conn
JetStream nats.JetStreamContext JetStream nats.JetStreamContext
Durable nats.SubOpt Durable nats.SubOpt

View file

@ -308,10 +308,10 @@ func (r *Inputer) processRoomEvent(
} }
var softfail bool var softfail bool
if input.Kind == api.KindNew { if input.Kind == api.KindNew && !isCreateEvent {
// Check that the event passes authentication checks based on the // Check that the event passes authentication checks based on the
// current room state. // current room state.
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs)
if err != nil { if err != nil {
logger.WithError(err).Warn("Error authing soft-failed event") logger.WithError(err).Warn("Error authing soft-failed event")
} }
@ -322,8 +322,8 @@ func (r *Inputer) processRoomEvent(
// bother doing this if the event was already rejected as it just ends up // bother doing this if the event was already rejected as it just ends up
// burning CPU time. // burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected { if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent {
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo.RoomNID, input, missingPrev)
if err != nil { if err != nil {
return fmt.Errorf("r.processStateBefore: %w", err) return fmt.Errorf("r.processStateBefore: %w", err)
} }
@ -332,8 +332,23 @@ func (r *Inputer) processRoomEvent(
} }
} }
roomNID, err := r.DB.GetOrCreateRoomNID(ctx, event)
if err != nil {
return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err)
}
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type())
if err != nil {
return fmt.Errorf("r.DB.GetOrCreateEventTypeNID: %w", err)
}
eventStateKeyNID, err := r.DB.GetOrCreateEventStateKeyNID(ctx, event.StateKey())
if err != nil {
return fmt.Errorf("r.DB.GetOrCreateEventStateKeyNID: %w", err)
}
// Store the event. // Store the event.
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
@ -474,6 +489,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse
// nolint:nakedret // nolint:nakedret
func (r *Inputer) processStateBefore( func (r *Inputer) processStateBefore(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID,
input *api.InputRoomEvent, input *api.InputRoomEvent,
missingPrev bool, missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
@ -489,7 +505,7 @@ func (r *Inputer) processStateBefore(
case input.HasState: case input.HasState:
// If we're overriding the state then we need to go and retrieve // If we're overriding the state then we need to go and retrieve
// them from the database. It's a hard error if they are missing. // them from the database. It's a hard error if they are missing.
stateEvents, err := r.DB.EventsFromIDs(ctx, input.StateEventIDs) stateEvents, err := r.DB.EventsFromIDs(ctx, roomNID, input.StateEventIDs)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err) return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err)
} }
@ -567,6 +583,7 @@ func (r *Inputer) processStateBefore(
// we've failed to retrieve the auth chain altogether (in which case // we've failed to retrieve the auth chain altogether (in which case
// an error is returned) or we've successfully retrieved them all and // an error is returned) or we've successfully retrieved them all and
// they are now in the database. // they are now in the database.
// nolint: gocyclo
func (r *Inputer) fetchAuthEvents( func (r *Inputer) fetchAuthEvents(
ctx context.Context, ctx context.Context,
logger *logrus.Entry, logger *logrus.Entry,
@ -587,7 +604,7 @@ func (r *Inputer) fetchAuthEvents(
} }
for _, authEventID := range authEventIDs { for _, authEventID := range authEventIDs {
authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo.RoomNID, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{} unknown[authEventID] = struct{}{}
continue continue
@ -673,8 +690,23 @@ nextAuthEvent:
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
} }
roomNID, err := r.DB.GetOrCreateRoomNID(ctx, authEvent)
if err != nil {
return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err)
}
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type())
if err != nil {
return fmt.Errorf("r.DB.GetOrCreateEventTypeNID: %w", err)
}
eventStateKeyNID, err := r.DB.GetOrCreateEventStateKeyNID(ctx, event.StateKey())
if err != nil {
return fmt.Errorf("r.DB.GetOrCreateEventStateKeyNID: %w", err)
}
// Finally, store the event in the database. // Finally, store the event in the database.
eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) eventNID, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
@ -750,7 +782,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event
return err return err
} }
memberEvents, err := r.DB.Events(ctx, membershipNIDs) memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, membershipNIDs)
if err != nil { if err != nil {
return err return err
} }

View file

@ -53,7 +53,7 @@ func (r *Inputer) updateMemberships(
// Load the event JSON so we can look up the "membership" key. // Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that // TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON? // key without having to load the entire event JSON?
events, err := updater.Events(ctx, eventNIDs) events, err := updater.Events(ctx, 0, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -43,7 +43,7 @@ type missingStateReq struct {
log *logrus.Entry log *logrus.Entry
virtualHost gomatrixserverlib.ServerName virtualHost gomatrixserverlib.ServerName
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
db storage.Database db storage.RoomDatabase
roomInfo *types.RoomInfo roomInfo *types.RoomInfo
inputer *Inputer inputer *Inputer
keys gomatrixserverlib.JSONVerifier keys gomatrixserverlib.JSONVerifier
@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
for _, entry := range stateEntries { for _, entry := range stateEntries {
stateEventNIDs = append(stateEventNIDs, entry.EventNID) stateEventNIDs = append(stateEventNIDs, entry.EventNID)
} }
stateEvents, err := t.db.Events(ctx, stateEventNIDs) stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomNID, stateEventNIDs)
if err != nil { if err != nil {
t.log.WithError(err).Warnf("failed to load state events locally") t.log.WithError(err).Warnf("failed to load state events locally")
return nil return nil
@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
missingEventList = append(missingEventList, evID) missingEventList = append(missingEventList, evID)
} }
t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events")
events, err := t.db.EventsFromIDs(ctx, missingEventList) events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList)
if err != nil { if err != nil {
return nil return nil
} }
@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
} }
t.haveEventsMutex.Unlock() t.haveEventsMutex.Unlock()
events, err := t.db.EventsFromIDs(ctx, missingEventList) events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList)
if err != nil { if err != nil {
return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err)
} }
@ -844,7 +844,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
if localFirst { if localFirst {
// fetch from the roomserver // fetch from the roomserver
events, err := t.db.EventsFromIDs(ctx, []string{missingEventID}) events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, []string{missingEventID})
if err != nil { if err != nil {
t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
} else if len(events) == 1 { } else if len(events) == 1 {

View file

@ -70,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil return nil
} }
memberEvents, err := r.DB.Events(ctx, memberNIDs) memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, memberNIDs)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,

View file

@ -86,7 +86,7 @@ func (r *Backfiller) PerformBackfill(
// Retrieve events from the list that was filled previously. If we fail to get // Retrieve events from the list that was filled previously. If we fail to get
// events from the database then attempt once to get them from federation instead. // events from the database then attempt once to get them from federation instead.
var loadedEvents []*gomatrixserverlib.Event var loadedEvents []*gomatrixserverlib.Event
loadedEvents, err = helpers.LoadEvents(ctx, r.DB, resultNIDs) loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
if err != nil { if err != nil {
if _, ok := err.(types.MissingEventError); ok { if _, ok := err.(types.MissingEventError); ok {
return r.backfillViaFederation(ctx, request, response) return r.backfillViaFederation(ctx, request, response)
@ -258,6 +258,7 @@ type backfillRequester struct {
eventIDToBeforeStateIDs map[string][]string eventIDToBeforeStateIDs map[string][]string
eventIDMap map[string]*gomatrixserverlib.Event eventIDMap map[string]*gomatrixserverlib.Event
historyVisiblity gomatrixserverlib.HistoryVisibility historyVisiblity gomatrixserverlib.HistoryVisibility
roomInfo types.RoomInfo
} }
func newBackfillRequester( func newBackfillRequester(
@ -454,14 +455,14 @@ FindSuccessor:
return nil return nil
} }
stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID]) stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID)
if err != nil { if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
return nil return nil
} }
// possibly return all joined servers depending on history visiblity // possibly return all joined servers depending on history visiblity
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost) memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost)
b.historyVisiblity = visibility b.historyVisiblity = visibility
if err != nil { if err != nil {
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
@ -472,7 +473,7 @@ FindSuccessor:
// Retrieve all "m.room.member" state events of "join" membership, which // Retrieve all "m.room.member" state events of "join" membership, which
// contains the list of users in the room before the event, therefore all // contains the list of users in the room before the event, therefore all
// the servers in it at that moment. // the servers in it at that moment.
memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, stateEntries, true) memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info.RoomNID, stateEntries, true)
if err != nil { if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
return nil return nil
@ -523,11 +524,15 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
} }
eventNIDs := make([]types.EventNID, len(nidMap)) eventNIDs := make([]types.EventNID, len(nidMap))
i := 0 i := 0
roomNID := b.roomInfo.RoomNID
for _, nid := range nidMap { for _, nid := range nidMap {
eventNIDs[i] = nid eventNIDs[i] = nid.EventNID
i++ i++
if roomNID == 0 {
roomNID = nid.RoomNID
} }
eventsWithNids, err := b.db.Events(ctx, eventNIDs) }
eventsWithNids, err := b.db.Events(ctx, roomNID, eventNIDs)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err return nil, err
@ -544,7 +549,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just // TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
// pull all events and then filter by that table. // pull all events and then filter by that table.
func joinEventsFromHistoryVisibility( func joinEventsFromHistoryVisibility(
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry, ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) {
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
@ -557,7 +562,7 @@ func joinEventsFromHistoryVisibility(
} }
// Get all of the events in this state // Get all of the events in this state
stateEvents, err := db.Events(ctx, eventNIDs) stateEvents, err := db.Events(ctx, roomInfo.RoomNID, eventNIDs)
if err != nil { if err != nil {
// even though the default should be shared, restricting the visibility to joined // even though the default should be shared, restricting the visibility to joined
// feels more secure here. // feels more secure here.
@ -570,21 +575,17 @@ func joinEventsFromHistoryVisibility(
// Can we see events in the room? // Can we see events in the room?
canSeeEvents := auth.IsServerAllowed(thisServer, true, events) canSeeEvents := auth.IsServerAllowed(thisServer, true, events)
visibility := gomatrixserverlib.HistoryVisibility(auth.HistoryVisibilityForRoom(events)) visibility := auth.HistoryVisibilityForRoom(events)
if !canSeeEvents { if !canSeeEvents {
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
return nil, visibility, nil return nil, visibility, nil
} }
// get joined members // get joined members
info, err := db.RoomInfo(ctx, roomID) joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, false)
if err != nil {
return nil, visibility, nil
}
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
if err != nil { if err != nil {
return nil, visibility, err return nil, visibility, err
} }
evs, err := db.Events(ctx, joinEventNIDs) evs, err := db.Events(ctx, roomInfo.RoomNID, joinEventNIDs)
return evs, visibility, err return evs, visibility, err
} }
@ -601,12 +602,31 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
authNids := make([]types.EventNID, len(nidMap)) authNids := make([]types.EventNID, len(nidMap))
i := 0 i := 0
for _, nid := range nidMap { for _, nid := range nidMap {
authNids[i] = nid authNids[i] = nid.EventNID
i++ i++
} }
roomNID, err = db.GetOrCreateRoomNID(ctx, ev.Unwrap())
if err != nil {
logrus.WithError(err).Error("failed to get or create roomNID")
continue
}
eventTypeNID, err := db.GetOrCreateEventTypeNID(ctx, ev.Type())
if err != nil {
logrus.WithError(err).Error("failed to get or create eventType NID")
continue
}
eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(ctx, ev.StateKey())
if err != nil {
logrus.WithError(err).Error("failed to get or create eventStateKey NID")
continue
}
var redactedEventID string var redactedEventID string
var redactionEvent *gomatrixserverlib.Event var redactionEvent *gomatrixserverlib.Event
eventNID, roomNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), authNids, false) eventNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), roomNID, eventTypeNID, eventStateKeyNID, authNids, false)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue continue

View file

@ -29,7 +29,7 @@ import (
) )
type InboundPeeker struct { type InboundPeeker struct {
DB storage.Database DB storage.RoomDatabase
Inputer *input.Inputer Inputer *input.Inputer
} }
@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil { if err != nil {
return err return err
} }
latestEvents, err := r.DB.EventsFromIDs(ctx, []string{latestEventRefs[0].EventID}) latestEvents, err := r.DB.EventsFromIDs(ctx, info.RoomNID, []string{latestEventRefs[0].EventID})
if err != nil { if err != nil {
return err return err
} }
@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil { if err != nil {
return err return err
} }
stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, stateEntries) stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
if err != nil { if err != nil {
return err return err
} }

View file

@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite(
// try and see if the user is allowed to make this invite. We can't do // try and see if the user is allowed to make this invite. We can't do
// this for invites coming in over federation - we have to take those on // this for invites coming in over federation - we have to take those on
// trust. // trust.
_, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) _, err = helpers.CheckAuthEvents(ctx, r.DB, info.RoomNID, event, event.AuthEventIDs())
if err != nil { if err != nil {
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event", "processInviteEvent.checkAuthEvents failed for event",
@ -291,7 +291,7 @@ func buildInviteStrippedState(
for _, stateNID := range stateEntries { for _, stateNID := range stateEntries {
stateNIDs = append(stateNIDs, stateNID.EventNID) stateNIDs = append(stateNIDs, stateNID.EventNID)
} }
stateEvents, err := db.Events(ctx, stateNIDs) stateEvents, err := db.Events(ctx, info.RoomNID, stateNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -22,7 +22,6 @@ import (
fsAPI "github.com/matrix-org/dendrite/federationapi/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -31,8 +30,6 @@ type Unpeeker struct {
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
Cfg *config.RoomServer Cfg *config.RoomServer
FSAPI fsAPI.RoomserverFederationAPI FSAPI fsAPI.RoomserverFederationAPI
DB storage.Database
Inputer *input.Inputer Inputer *input.Inputer
} }

View file

@ -102,7 +102,7 @@ func (r *Queryer) QueryStateAfterEvents(
return err return err
} }
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
if err != nil { if err != nil {
return err return err
} }
@ -138,17 +138,7 @@ func (r *Queryer) QueryEventsByID(
request *api.QueryEventsByIDRequest, request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse, response *api.QueryEventsByIDResponse,
) error { ) error {
eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs) events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs)
if err != nil {
return err
}
var eventNIDs []types.EventNID
for _, nid := range eventNIDMap {
eventNIDs = append(eventNIDs, nid)
}
events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs)
if err != nil { if err != nil {
return err return err
} }
@ -196,7 +186,7 @@ func (r *Queryer) QueryMembershipForUser(
response.IsInRoom = stillInRoom response.IsInRoom = stillInRoom
response.HasBeenInRoom = true response.HasBeenInRoom = true
evs, err := r.DB.Events(ctx, []types.EventNID{membershipEventNID}) evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID})
if err != nil { if err != nil {
return err return err
} }
@ -278,10 +268,10 @@ func (r *Queryer) QueryMembershipAtEvent(
// once. If we have more than one membership event, we need to get the state for each state entry. // once. If we have more than one membership event, we need to get the state for each state entry.
if canShortCircuit { if canShortCircuit {
if len(memberships) == 0 { if len(memberships) == 0 {
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
} }
} else { } else {
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
} }
if err != nil { if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err) return fmt.Errorf("unable to get memberships at state: %w", err)
@ -328,7 +318,7 @@ func (r *Queryer) QueryMembershipsForRoom(
} }
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
} }
events, err = r.DB.Events(ctx, eventNIDs) events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.Events: %w", err) return fmt.Errorf("r.DB.Events: %w", err)
} }
@ -367,14 +357,14 @@ func (r *Queryer) QueryMembershipsForRoom(
return err return err
} }
events, err = r.DB.Events(ctx, eventNIDs) events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
} else { } else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil { if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err return err
} }
events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly)
} }
if err != nil { if err != nil {
@ -425,7 +415,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
request *api.QueryServerAllowedToSeeEventRequest, request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse, response *api.QueryServerAllowedToSeeEventResponse,
) (err error) { ) (err error) {
events, err := r.DB.EventsFromIDs(ctx, []string{request.EventID}) events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID})
if err != nil { if err != nil {
return return
} }
@ -476,7 +466,7 @@ func (r *Queryer) QueryMissingEvents(
eventsToFilter[id] = true eventsToFilter[id] = true
} }
} }
events, err := r.DB.EventsFromIDs(ctx, front) events, err := r.DB.EventsFromIDs(ctx, 0, front)
if err != nil { if err != nil {
return err return err
} }
@ -496,7 +486,7 @@ func (r *Queryer) QueryMissingEvents(
return err return err
} }
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs) loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
if err != nil { if err != nil {
return err return err
} }
@ -621,11 +611,11 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
return nil, rejected, false, err return nil, rejected, false, err
} }
events, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries)
return events, rejected, false, err return events, rejected, false, err
} }
type eventsFromIDs func(context.Context, []string) ([]types.Event, error) type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error)
// GetAuthChain fetches the auth chain for the given auth events. An auth chain // GetAuthChain fetches the auth chain for the given auth events. An auth chain
// is the list of all events that are referenced in the auth_events section, and // is the list of all events that are referenced in the auth_events section, and
@ -643,7 +633,7 @@ func GetAuthChain(
for len(eventsToFetch) > 0 { for len(eventsToFetch) > 0 {
// Try to retrieve the events from the database. // Try to retrieve the events from the database.
events, err := fn(ctx, eventsToFetch) events, err := fn(ctx, 0, eventsToFetch)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -981,7 +971,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// For each of the joined users, let's see if we can get a valid // For each of the joined users, let's see if we can get a valid
// membership event. // membership event.
for _, joinNID := range joinNIDs { for _, joinNID := range joinNIDs {
events, err := r.DB.Events(ctx, []types.EventNID{joinNID}) events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID})
if err != nil || len(events) != 1 { if err != nil || len(events) != 1 {
continue continue
} }

View file

@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
} }
// EventsFromIDs implements RoomserverInternalAPIEventDB // EventsFromIDs implements RoomserverInternalAPIEventDB
func (db *getEventDB) EventsFromIDs(ctx context.Context, eventIDs []string) (res []types.Event, err error) { func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) {
for _, evID := range eventIDs { for _, evID := range eventIDs {
res = append(res, types.Event{ res = append(res, types.Event{
EventNID: 0, EventNID: 0,

View file

@ -23,8 +23,7 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
) )
// NewInternalAPI returns a concerete implementation of the internal API. Callers // NewInternalAPI returns a concrete implementation of the internal API.
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
base *base.BaseDendrite, base *base.BaseDendrite,
) api.RoomserverInternalAPI { ) api.RoomserverInternalAPI {

View file

@ -41,8 +41,8 @@ type StateResolutionStorage interface {
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
} }
type StateResolution struct { type StateResolution struct {
@ -975,7 +975,7 @@ func (v *StateResolution) resolveConflictsV2(
// Store the newly found auth events in the auth set for this event. // Store the newly found auth events in the auth set for this event.
var authEventMap map[string]types.StateEntry var authEventMap map[string]types.StateEntry
authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, conflictedEvent, knownAuthEvents) authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo.RoomNID, conflictedEvent, knownAuthEvents)
if err != nil { if err != nil {
return err return err
} }
@ -1091,7 +1091,7 @@ func (v *StateResolution) loadStateEvents(
eventNIDs = append(eventNIDs, entry.EventNID) eventNIDs = append(eventNIDs, entry.EventNID)
} }
} }
events, err := v.db.Events(ctx, eventNIDs) events, err := v.db.Events(ctx, v.roomInfo.RoomNID, eventNIDs)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -1120,7 +1120,7 @@ type authEventLoader struct {
// loadAuthEvents loads all of the auth events for a given event recursively, // loadAuthEvents loads all of the auth events for a given event recursively,
// along with a map that contains state entries for all of the auth events. // along with a map that contains state entries for all of the auth events.
func (l *authEventLoader) loadAuthEvents( func (l *authEventLoader) loadAuthEvents(
ctx context.Context, event *gomatrixserverlib.Event, eventMap map[string]types.Event, ctx context.Context, roomNID types.RoomNID, event *gomatrixserverlib.Event, eventMap map[string]types.Event,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
@ -1155,7 +1155,7 @@ func (l *authEventLoader) loadAuthEvents(
// If we need to get events from the database, go and fetch // If we need to get events from the database, go and fetch
// those now. // those now.
if len(l.lookupFromDB) > 0 { if len(l.lookupFromDB) > 0 {
eventsFromDB, err := l.v.db.EventsFromIDs(ctx, l.lookupFromDB) eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomNID, l.lookupFromDB)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
} }

View file

@ -69,15 +69,12 @@ type Database interface {
) ([]types.StateEntryList, error) ) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs. // Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events. // Returns a sorted list of events.
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string // Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
StoreEvent( StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID,
isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
// Look up the state entries for a list of string event IDs // Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database // Returns an error if the there is an error talking to the database
// Returns a types.MissingEventError if the event IDs aren't in the database. // Returns a types.MissingEventError if the event IDs aren't in the database.
@ -87,7 +84,7 @@ type Database interface {
EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
// Look up the numeric IDs for a list of events. // Look up the numeric IDs for a list of events.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error)
// Set the state at an event. FIXME TODO: "at" // Set the state at an event. FIXME TODO: "at"
SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
// Lookup the event IDs for a batch of event numeric IDs. // Lookup the event IDs for a batch of event numeric IDs.
@ -138,7 +135,7 @@ type Database interface {
// EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was
// not found. // not found.
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
// Publish or unpublish a room from the room directory. // Publish or unpublish a room from the room directory.
PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
// Returns a list of room IDs for rooms which are published. // Returns a list of room IDs for rooms which are published.
@ -182,4 +179,36 @@ type Database interface {
GetMembershipForHistoryVisibility( GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) ) (map[string]*gomatrixserverlib.HeaderedEvent, error)
GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
}
type RoomDatabase interface {
// RoomInfo returns room information for the given room ID, or nil if there is no room.
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
// IsEventRejected returns true if the event is known and rejected.
IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error)
MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
} }

View file

@ -140,10 +140,10 @@ const bulkSelectEventIDSQL = "" +
"SELECT event_nid, event_id FROM roomserver_events WHERE event_nid = ANY($1)" "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid = ANY($1)"
const bulkSelectEventNIDSQL = "" + const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)" "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id = ANY($1)"
const bulkSelectUnsentEventNIDSQL = "" + const bulkSelectUnsentEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE" "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE"
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
@ -520,20 +520,20 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, false) return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
} }
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID // BulkSelectEventNIDs returns a map from string event ID to numeric event ID
// only for events that haven't already been sent to the roomserver output. // only for events that haven't already been sent to the roomserver output.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, true) return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
} }
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventMetadata, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
if onlyUnsent { if onlyUnsent {
stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt) stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt)
@ -545,14 +545,18 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed")
results := make(map[string]types.EventNID, len(eventIDs)) results := make(map[string]types.EventMetadata, len(eventIDs))
var eventID string var eventID string
var eventNID int64 var eventNID int64
var roomNID int64
for rows.Next() { for rows.Next() {
if err = rows.Scan(&eventID, &eventNID); err != nil { if err = rows.Scan(&eventID, &eventNID, &roomNID); err != nil {
return nil, err return nil, err
} }
results[eventID] = types.EventNID(eventNID) results[eventID] = types.EventMetadata{
EventNID: types.EventNID(eventNID),
RoomNID: types.RoomNID(roomNID),
}
} }
return results, rows.Err() return results, rows.Err()
} }

View file

@ -116,10 +116,8 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent
}) })
} }
func (u *RoomUpdater) Events( func (u *RoomUpdater) Events(ctx context.Context, _ types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) {
ctx context.Context, eventNIDs []types.EventNID, return u.d.events(ctx, u.txn, u.roomInfo.RoomNID, eventNIDs)
) ([]types.Event, error) {
return u.d.events(ctx, u.txn, eventNIDs)
} }
func (u *RoomUpdater) SnapshotNIDFromEventID( func (u *RoomUpdater) SnapshotNIDFromEventID(
@ -197,12 +195,8 @@ func (u *RoomUpdater) StateAtEventIDs(
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
} }
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) return u.d.eventsFromIDs(ctx, u.txn, roomNID, eventIDs, NoFilter)
}
func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true)
} }
// IsReferenced implements types.RoomRecentEventsUpdater // IsReferenced implements types.RoomRecentEventsUpdater

View file

@ -9,6 +9,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
@ -111,16 +112,31 @@ func (d *Database) eventStateKeyNIDs(
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID) result := make(map[string]types.EventStateKeyNID)
eventStateKeys = util.UniqueStrings(eventStateKeys) eventStateKeys = util.UniqueStrings(eventStateKeys)
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) // first ask the cache about these keys
fetchEventStateKeys := make([]string, 0, len(eventStateKeys))
for _, eventStateKey := range eventStateKeys {
eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey)
if ok {
result[eventStateKey] = eventStateKeyNID
continue
}
fetchEventStateKeys = append(fetchEventStateKeys, eventStateKey)
}
if len(fetchEventStateKeys) > 0 {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, fetchEventStateKeys)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for eventStateKey, nid := range nids { for eventStateKey, nid := range nids {
result[eventStateKey] = nid result[eventStateKey] = nid
} }
}
// We received some nids, but are still missing some, work out which and create them // We received some nids, but are still missing some, work out which and create them
if len(eventStateKeys) > len(result) { if len(eventStateKeys) > len(result) {
var nid types.EventStateKeyNID var nid types.EventStateKeyNID
var err error
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
for _, eventStateKey := range eventStateKeys { for _, eventStateKey := range eventStateKeys {
if _, ok := result[eventStateKey]; ok { if _, ok := result[eventStateKey]; ok {
@ -262,7 +278,7 @@ func (d *Database) addState(
func (d *Database) EventNIDs( func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) { ) (map[string]types.EventMetadata, error) {
return d.eventNIDs(ctx, nil, eventIDs, NoFilter) return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
} }
@ -275,7 +291,7 @@ const (
func (d *Database) eventNIDs( func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter, ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
) (map[string]types.EventNID, error) { ) (map[string]types.EventMetadata, error) {
switch filter { switch filter {
case FilterUnsentOnly: case FilterUnsentOnly:
return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs) return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs)
@ -325,11 +341,11 @@ func (d *Database) EventIDs(
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
} }
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { func (d *Database) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) {
return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter) return d.eventsFromIDs(ctx, nil, roomNID, eventIDs, NoFilter)
} }
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter) nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
if err != nil { if err != nil {
return nil, err return nil, err
@ -337,10 +353,16 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []st
var nids []types.EventNID var nids []types.EventNID
for _, nid := range nidMap { for _, nid := range nidMap {
nids = append(nids, nid) nids = append(nids, nid.EventNID)
if roomNID != 0 && roomNID != nid.RoomNID {
logrus.Errorf("expected events from room %d, but also found %d", roomNID, nid.RoomNID)
}
if roomNID == 0 {
roomNID = nid.RoomNID
}
} }
return d.events(ctx, txn, nids) return d.events(ctx, txn, roomNID, nids)
} }
func (d *Database) LatestEventIDs( func (d *Database) LatestEventIDs(
@ -480,14 +502,18 @@ func (d *Database) GetInvitesForUser(
} }
func (d *Database) Events( func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID,
) ([]types.Event, error) { ) ([]types.Event, error) {
return d.events(ctx, nil, eventNIDs) return d.events(ctx, nil, roomNID, eventNIDs)
} }
func (d *Database) events( func (d *Database) events(
ctx context.Context, txn *sql.Tx, inputEventNIDs types.EventNIDs, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, inputEventNIDs types.EventNIDs,
) ([]types.Event, error) { ) ([]types.Event, error) {
if roomNID == 0 {
// No need to go further, as we won't find any events for this room.
return nil, nil
}
sort.Sort(inputEventNIDs) sort.Sort(inputEventNIDs)
events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs)) events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs))
eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs)) eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
@ -519,40 +545,34 @@ func (d *Database) events(
if err != nil { if err != nil {
return nil, err return nil, err
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs) eventIDs, err := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs)
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
var roomNIDs map[types.EventNID]types.RoomNID
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs) var roomVersion gomatrixserverlib.RoomVersion
var fetchRoomVersion bool
var ok bool
var roomID string
if roomID, ok = d.Cache.GetRoomServerRoomID(roomNID); ok {
roomVersion, ok = d.Cache.GetRoomVersion(roomID)
if !ok {
fetchRoomVersion = true
}
}
if roomVersion == "" || fetchRoomVersion {
var dbRoomVersions map[types.RoomNID]gomatrixserverlib.RoomVersion
dbRoomVersions, err = d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, []types.RoomNID{roomNID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
uniqueRoomNIDs := make(map[types.RoomNID]struct{}) if roomVersion, ok = dbRoomVersions[roomNID]; !ok {
for _, n := range roomNIDs { return nil, fmt.Errorf("unable to find roomversion for room %d", roomNID)
uniqueRoomNIDs[n] = struct{}{}
}
roomVersions := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
fetchNIDList := make([]types.RoomNID, 0, len(uniqueRoomNIDs))
for n := range uniqueRoomNIDs {
if roomID, ok := d.Cache.GetRoomServerRoomID(n); ok {
if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok {
roomVersions[n] = roomVersion
continue
} }
} }
fetchNIDList = append(fetchNIDList, n)
}
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList)
if err != nil {
return nil, err
}
for n, v := range dbRoomVersions {
roomVersions[n] = v
}
for _, eventJSON := range eventJSONs { for _, eventJSON := range eventJSONs {
roomNID := roomNIDs[eventJSON.EventNID]
roomVersion := roomVersions[roomNID]
events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion, eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion,
) )
@ -624,37 +644,8 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e
return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID)
} }
func (d *Database) StoreEvent( // GetOrCreateRoomNID gets or creates a new roomNID for the given event
ctx context.Context, event *gomatrixserverlib.Event, func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (roomNID types.RoomNID, err error) {
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected)
}
func (d *Database) storeEvent(
ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var (
roomNID types.RoomNID
eventTypeNID types.EventTypeNID
eventStateKeyNID types.EventStateKeyNID
eventNID types.EventNID
stateNID types.StateSnapshotNID
redactionEvent *gomatrixserverlib.Event
redactedEventID string
err error
)
var txn *sql.Tx
if updater != nil && updater.txn != nil {
txn = updater.txn
}
// First writer is with a database-provided transaction, so that NIDs are assigned
// globally outside of the updater context, to help avoid races.
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones.
// Get the default room version. If the client doesn't supply a room_version // Get the default room version. If the client doesn't supply a room_version
// then we will use our configured default to create the room. // then we will use our configured default to create the room.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
@ -663,34 +654,61 @@ func (d *Database) storeEvent(
// room. // room.
var roomVersion gomatrixserverlib.RoomVersion var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) return 0, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion)
if err != nil {
return err
}
return nil
})
return roomNID, err
} }
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
return fmt.Errorf("d.assignRoomNID: %w", err) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
} if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil {
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
return fmt.Errorf("d.assignEventTypeNID: %w", err) return fmt.Errorf("d.assignEventTypeNID: %w", err)
} }
return nil
})
return eventTypeNID, err
}
eventStateKey := event.StateKey() func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (eventStateKeyNID types.EventStateKeyNID, err error) {
// Assigned a numeric ID for the state_key if there is one present. if eventStateKey == nil {
// Otherwise set the numeric ID for the state_key to 0. return 0, nil
if eventStateKey != nil { }
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
return fmt.Errorf("d.assignStateKeyNID: %w", err) return fmt.Errorf("d.assignStateKeyNID: %w", err)
} }
}
return nil return nil
}) })
if err != nil { if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) return 0, err
} }
return eventStateKeyNID, nil
}
func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var (
eventNID types.EventNID
stateNID types.StateSnapshotNID
redactionEvent *gomatrixserverlib.Event
redactedEventID string
err error
)
// Second writer is using the database-provided transaction, probably from the // Second writer is using the database-provided transaction, probably from the
// room updater, for easy roll-back if required. // room updater, for easy roll-back if required.
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if eventNID, stateNID, err = d.EventsTable.InsertEvent( if eventNID, stateNID, err = d.EventsTable.InsertEvent(
ctx, ctx,
txn, txn,
@ -718,7 +736,7 @@ func (d *Database) storeEvent(
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
} }
if !isRejected { // ignore rejected redaction events if !isRejected { // ignore rejected redaction events
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, roomNID, eventNID, event)
if err != nil { if err != nil {
return fmt.Errorf("d.handleRedactions: %w", err) return fmt.Errorf("d.handleRedactions: %w", err)
} }
@ -726,7 +744,7 @@ func (d *Database) storeEvent(
return nil return nil
}) })
if err != nil { if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err)
} }
// We should attempt to update the previous events table with any // We should attempt to update the previous events table with any
@ -741,28 +759,28 @@ func (d *Database) storeEvent(
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`. // to do writes however then this will need to go inside `Writer.Do`.
succeeded := false succeeded := false
if updater == nil {
var roomInfo *types.RoomInfo var roomInfo *types.RoomInfo
roomInfo, err = d.roomInfo(ctx, txn, event.RoomID()) roomInfo, err = d.roomInfo(ctx, nil, event.RoomID())
if err != nil { if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
} }
if roomInfo == nil && len(prevEvents) > 0 { if roomInfo == nil && len(prevEvents) > 0 {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
} }
var updater *RoomUpdater
updater, err = d.GetRoomUpdater(ctx, roomInfo) updater, err = d.GetRoomUpdater(ctx, roomInfo)
if err != nil { if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
} }
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
}
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
} }
succeeded = true succeeded = true
} }
return eventNID, roomNID, types.StateAtEvent{ return eventNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID, BeforeStateSnapshotNID: stateNID,
StateEntry: types.StateEntry{ StateEntry: types.StateEntry{
StateKeyTuple: types.StateKeyTuple{ StateKeyTuple: types.StateKeyTuple{
@ -814,6 +832,10 @@ func (d *Database) MissingAuthPrevEvents(
func (d *Database) assignRoomNID( func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) { ) (types.RoomNID, error) {
roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID)
if ok {
return roomNID, nil
}
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID) roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -824,12 +846,20 @@ func (d *Database) assignRoomNID(
roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID) roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
} }
} }
return roomNID, err if err != nil {
return 0, err
}
d.Cache.StoreRoomServerRoomID(roomNID, roomID)
return roomNID, nil
} }
func (d *Database) assignEventTypeNID( func (d *Database) assignEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string, ctx context.Context, txn *sql.Tx, eventType string,
) (types.EventTypeNID, error) { ) (types.EventTypeNID, error) {
eventTypeNID, ok := d.Cache.GetEventTypeKey(eventType)
if ok {
return eventTypeNID, nil
}
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -840,12 +870,20 @@ func (d *Database) assignEventTypeNID(
eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
} }
} }
return eventTypeNID, err if err != nil {
return 0, err
}
d.Cache.StoreEventTypeKey(eventTypeNID, eventType)
return eventTypeNID, nil
} }
func (d *Database) assignStateKeyNID( func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string, ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey)
if ok {
return eventStateKeyNID, nil
}
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -856,6 +894,7 @@ func (d *Database) assignStateKeyNID(
eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
} }
} }
d.Cache.StoreEventStateKey(eventStateKeyNID, eventStateKey)
return eventStateKeyNID, err return eventStateKeyNID, err
} }
@ -899,7 +938,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
// //
// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. // Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction.
func (d *Database) handleRedactions( func (d *Database) handleRedactions(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event,
) (*gomatrixserverlib.Event, string, error) { ) (*gomatrixserverlib.Event, string, error) {
var err error var err error
isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil
@ -919,7 +958,7 @@ func (d *Database) handleRedactions(
} }
} }
redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, eventNID, event) redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, roomNID, eventNID, event)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err) return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err)
} }
@ -985,7 +1024,7 @@ func (d *Database) handleRedactions(
// loadRedactionPair returns both the redaction event and the redacted event, else nil. // loadRedactionPair returns both the redaction event and the redacted event, else nil.
func (d *Database) loadRedactionPair( func (d *Database) loadRedactionPair(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event,
) (*types.Event, *types.Event, bool, error) { ) (*types.Event, *types.Event, bool, error) {
var redactionEvent, redactedEvent *types.Event var redactionEvent, redactedEvent *types.Event
var info *tables.RedactionInfo var info *tables.RedactionInfo
@ -1017,9 +1056,9 @@ func (d *Database) loadRedactionPair(
} }
if isRedactionEvent { if isRedactionEvent {
redactedEvent = d.loadEvent(ctx, info.RedactsEventID) redactedEvent = d.loadEvent(ctx, roomNID, info.RedactsEventID)
} else { } else {
redactionEvent = d.loadEvent(ctx, info.RedactionEventID) redactionEvent = d.loadEvent(ctx, roomNID, info.RedactionEventID)
} }
return redactionEvent, redactedEvent, info.Validated, nil return redactionEvent, redactedEvent, info.Validated, nil
@ -1035,7 +1074,7 @@ func (d *Database) applyRedactions(events []types.Event) {
} }
// loadEvent loads a single event or returns nil on any problems/missing event // loadEvent loads a single event or returns nil on any problems/missing event
func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID string) *types.Event {
nids, err := d.EventNIDs(ctx, []string{eventID}) nids, err := d.EventNIDs(ctx, []string{eventID})
if err != nil { if err != nil {
return nil return nil
@ -1043,7 +1082,7 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
if len(nids) == 0 { if len(nids) == 0 {
return nil return nil
} }
evs, err := d.Events(ctx, []types.EventNID{nids[eventID]}) evs, err := d.Events(ctx, roomNID, []types.EventNID{nids[eventID].EventNID})
if err != nil { if err != nil {
return nil return nil
} }
@ -1470,14 +1509,20 @@ func (d *Database) PurgeRoom(ctx context.Context, roomID string) error {
func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error { func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
published, err := d.PublishedTable.SelectPublishedFromRoomID(ctx, txn, oldRoomID)
if err != nil {
return fmt.Errorf("failed to get published room: %w", err)
}
if published {
// un-publish old room // un-publish old room
if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil { if err = d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil {
return fmt.Errorf("failed to unpublish room: %w", err) return fmt.Errorf("failed to unpublish room: %w", err)
} }
// publish new room // publish new room
if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil { if err = d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil {
return fmt.Errorf("failed to publish room: %w", err) return fmt.Errorf("failed to publish room: %w", err)
} }
}
// Migrate any existing room aliases // Migrate any existing room aliases
aliases, err := d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, txn, oldRoomID) aliases, err := d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, txn, oldRoomID)

View file

@ -3,7 +3,9 @@ package shared_test
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -48,11 +50,14 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat
} }
assert.NoError(t, err) assert.NoError(t, err)
cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
return &shared.Database{ return &shared.Database{
DB: db, DB: db,
EventStateKeysTable: stateKeyTable, EventStateKeysTable: stateKeyTable,
MembershipTable: membershipTable, MembershipTable: membershipTable,
Writer: sqlutil.NewExclusiveWriter(), Writer: sqlutil.NewExclusiveWriter(),
Cache: cache,
}, func() { }, func() {
err := base.Close() err := base.Close()
assert.NoError(t, err) assert.NoError(t, err)

View file

@ -110,10 +110,10 @@ const bulkSelectEventIDSQL = "" +
"SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)"
const bulkSelectEventNIDSQL = "" + const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id IN ($1)"
const bulkSelectUnsentEventNIDSQL = "" + const bulkSelectUnsentEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)" "SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)"
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
@ -572,20 +572,20 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, false) return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
} }
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID // BulkSelectEventNIDs returns a map from string event ID to numeric event ID
// only for events that haven't already been sent to the roomserver output. // only for events that haven't already been sent to the roomserver output.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, true) return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
} }
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventMetadata, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs { for k, v := range eventIDs {
@ -609,14 +609,18 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed")
results := make(map[string]types.EventNID, len(eventIDs)) results := make(map[string]types.EventMetadata, len(eventIDs))
var eventID string var eventID string
var eventNID int64 var eventNID int64
var roomNID int64
for rows.Next() { for rows.Next() {
if err = rows.Scan(&eventID, &eventNID); err != nil { if err = rows.Scan(&eventID, &eventNID, &roomNID); err != nil {
return nil, err return nil, err
} }
results[eventID] = types.EventNID(eventNID) results[eventID] = types.EventMetadata{
EventNID: types.EventNID(eventNID),
RoomNID: types.RoomNID(roomNID),
}
} }
return results, nil return results, nil
} }

View file

@ -63,8 +63,8 @@ type Events interface {
BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error)
BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
SelectEventRejected(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventID string) (rejected bool, err error) SelectEventRejected(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventID string) (rejected bool, err error)

View file

@ -38,6 +38,11 @@ type EventNID int64
// RoomNID is a numeric ID for a room. // RoomNID is a numeric ID for a room.
type RoomNID int64 type RoomNID int64
type EventMetadata struct {
EventNID EventNID
RoomNID RoomNID
}
// StateSnapshotNID is a numeric ID for the state at an event. // StateSnapshotNID is a numeric ID for the state at an event.
type StateSnapshotNID int64 type StateSnapshotNID int64