Refactor StoreEvent, add MaybeRedactEvent, create a EventDatabase

This commit is contained in:
Till Faelligen 2023-02-28 12:42:06 +01:00
parent eddf31f915
commit e47ad14a18
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
32 changed files with 439 additions and 375 deletions

View file

@ -122,6 +122,7 @@ func (s *OutputRoomEventConsumer) onMessage(
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
newEventID := output.NewRoomEvent.Event.EventID() newEventID := output.NewRoomEvent.Event.EventID()
eventsReq := &api.QueryEventsByIDRequest{ eventsReq := &api.QueryEventsByIDRequest{
RoomID: output.NewRoomEvent.Event.RoomID(),
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
} }
eventsRes := &api.QueryEventsByIDResponse{} eventsRes := &api.QueryEventsByIDResponse{}

View file

@ -57,7 +57,7 @@ func SendRedaction(
} }
} }
ev := roomserverAPI.GetEvent(req.Context(), rsAPI, eventID) ev := roomserverAPI.GetEvent(req.Context(), rsAPI, roomID, eventID)
if ev == nil { if ev == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,

View file

@ -62,9 +62,10 @@ func main() {
panic(err) panic(err)
} }
stateres := state.NewStateResolution(roomserverDB, &types.RoomInfo{ roomInfo := &types.RoomInfo{
RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
}) }
stateres := state.NewStateResolution(roomserverDB, roomInfo)
if *difference { if *difference {
if len(snapshotNIDs) != 2 { if len(snapshotNIDs) != 2 {
@ -87,7 +88,7 @@ func main() {
} }
var eventEntries []types.Event var eventEntries []types.Event
eventEntries, err = roomserverDB.Events(ctx, 0, eventNIDs) eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -145,7 +146,7 @@ func main() {
} }
fmt.Println("Fetching", len(eventNIDMap), "state events") fmt.Println("Fetching", len(eventNIDMap), "state events")
eventEntries, err := roomserverDB.Events(ctx, 0, eventNIDs) eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -165,7 +166,7 @@ func main() {
} }
fmt.Println("Fetching", len(authEventIDs), "auth events") fmt.Println("Fetching", len(authEventIDs), "auth events")
authEventEntries, err := roomserverDB.EventsFromIDs(ctx, 0, authEventIDs) authEventEntries, err := roomserverDB.EventsFromIDs(ctx, roomInfo, authEventIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -173,6 +173,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
// Finally, work out if there are any more events missing. // Finally, work out if there are any more events missing.
if len(missingEventIDs) > 0 { if len(missingEventIDs) > 0 {
eventsReq := &api.QueryEventsByIDRequest{ eventsReq := &api.QueryEventsByIDRequest{
RoomID: ore.Event.RoomID(),
EventIDs: missingEventIDs, EventIDs: missingEventIDs,
} }
eventsRes := &api.QueryEventsByIDResponse{} eventsRes := &api.QueryEventsByIDResponse{}
@ -483,7 +484,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents(
// At this point the missing events are neither the event itself nor are // At this point the missing events are neither the event itself nor are
// they present in our local database. Our only option is to fetch them // they present in our local database. Our only option is to fetch them
// from the roomserver using the query API. // from the roomserver using the query API.
eventReq := api.QueryEventsByIDRequest{EventIDs: missing} eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID()}
var eventResp api.QueryEventsByIDResponse var eventResp api.QueryEventsByIDResponse
if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil { if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil {
return nil, err return nil, err

View file

@ -36,7 +36,7 @@ func GetEventAuth(
return *err return *err
} }
event, resErr := fetchEvent(ctx, rsAPI, eventID) event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }

View file

@ -38,7 +38,7 @@ func GetEvent(
if err != nil { if err != nil {
return *err return *err
} }
event, err := fetchEvent(ctx, rsAPI, eventID) event, err := fetchEvent(ctx, rsAPI, "", eventID)
if err != nil { if err != nil {
return *err return *err
} }
@ -83,11 +83,11 @@ func allowedToSeeEvent(
} }
// fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found. // fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found.
func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) {
var eventsResponse api.QueryEventsByIDResponse var eventsResponse api.QueryEventsByIDResponse
err := rsAPI.QueryEventsByID( err := rsAPI.QueryEventsByID(
ctx, ctx,
&api.QueryEventsByIDRequest{EventIDs: []string{eventID}}, &api.QueryEventsByIDRequest{EventIDs: []string{eventID}, RoomID: roomID},
&eventsResponse, &eventsResponse,
) )
if err != nil { if err != nil {

View file

@ -107,7 +107,7 @@ func getState(
return nil, nil, err return nil, nil, err
} }
event, resErr := fetchEvent(ctx, rsAPI, eventID) event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID)
if resErr != nil { if resErr != nil {
return nil, nil, resErr return nil, nil, resErr
} }

View file

@ -16,7 +16,9 @@
// Hooks can only be run in monolith mode. // Hooks can only be run in monolith mode.
package hooks package hooks
import "sync" import (
"sync"
)
const ( const (
// KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent

View file

@ -86,6 +86,7 @@ type QueryStateAfterEventsResponse struct {
// QueryEventsByIDRequest is a request to QueryEventsByID // QueryEventsByIDRequest is a request to QueryEventsByID
type QueryEventsByIDRequest struct { type QueryEventsByIDRequest struct {
RoomID string `json:"room_id"`
// The event IDs to look up. // The event IDs to look up.
EventIDs []string `json:"event_ids"` EventIDs []string `json:"event_ids"`
} }

View file

@ -108,9 +108,10 @@ func SendInputRoomEvents(
} }
// GetEvent returns the event or nil, even on errors. // GetEvent returns the event or nil, even on errors.
func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, eventID string) *gomatrixserverlib.HeaderedEvent { func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *gomatrixserverlib.HeaderedEvent {
var res QueryEventsByIDResponse var res QueryEventsByIDResponse
err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{ err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{
RoomID: roomID,
EventIDs: []string{eventID}, EventIDs: []string{eventID},
}, &res) }, &res)
if err != nil { if err != nil {

View file

@ -67,7 +67,7 @@ func CheckForSoftFail(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
if err != nil { if err != nil {
return true, fmt.Errorf("loadAuthEvents: %w", err) return true, fmt.Errorf("loadAuthEvents: %w", err)
} }
@ -85,7 +85,7 @@ func CheckForSoftFail(
func CheckAuthEvents( func CheckAuthEvents(
ctx context.Context, ctx context.Context,
db storage.RoomDatabase, db storage.RoomDatabase,
roomNID types.RoomNID, roomInfo *types.RoomInfo,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
authEventIDs []string, authEventIDs []string,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
@ -100,7 +100,7 @@ func CheckAuthEvents(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
if err != nil { if err != nil {
return nil, fmt.Errorf("loadAuthEvents: %w", err) return nil, fmt.Errorf("loadAuthEvents: %w", err)
} }
@ -193,7 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
func loadAuthEvents( func loadAuthEvents(
ctx context.Context, ctx context.Context,
db state.StateResolutionStorage, db state.StateResolutionStorage,
roomNID types.RoomNID, roomInfo *types.RoomInfo,
needed gomatrixserverlib.StateNeeded, needed gomatrixserverlib.StateNeeded,
state []types.StateEntry, state []types.StateEntry,
) (result authEvents, err error) { ) (result authEvents, err error) {
@ -216,7 +216,7 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID) eventNIDs = append(eventNIDs, eventNID)
} }
} }
if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil { if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil {
return return
} }
roomID := "" roomID := ""

View file

@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
return false, err return false, err
} }
events, err := db.Events(ctx, info.RoomNID, eventNIDs) events, err := db.Events(ctx, info, eventNIDs)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -157,7 +157,7 @@ func IsInvitePending(
// only keep the "m.room.member" events with a "join" membership. These events are returned. // only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events. // Returns an error if there was an issue fetching the events.
func GetMembershipsAtState( func GetMembershipsAtState(
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool, ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) { ) ([]types.Event, error) {
var eventNIDs types.EventNIDs var eventNIDs types.EventNIDs
@ -177,7 +177,7 @@ func GetMembershipsAtState(
util.Unique(eventNIDs) util.Unique(eventNIDs)
// Get all of the events in this state // Get all of the events in this state
stateEvents, err := db.Events(ctx, roomNID, eventNIDs) stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -227,9 +227,9 @@ func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types
} }
func LoadEvents( func LoadEvents(
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID, ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) { ) ([]*gomatrixserverlib.Event, error) {
stateEvents, err := db.Events(ctx, roomNID, eventNIDs) stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,13 +242,13 @@ func LoadEvents(
} }
func LoadStateEvents( func LoadStateEvents(
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
) ([]*gomatrixserverlib.Event, error) { ) ([]*gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries)) eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries { for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID eventNIDs[i] = stateEntries[i].EventNID
} }
return LoadEvents(ctx, db, roomNID, eventNIDs) return LoadEvents(ctx, db, roomInfo, eventNIDs)
} }
func CheckServerAllowedToSeeEvent( func CheckServerAllowedToSeeEvent(
@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState(
return nil, nil return nil, nil
} }
return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries) return LoadStateEvents(ctx, db, info, filteredEntries)
} }
// TODO: Remove this when we have tests to assert correctness of this function // TODO: Remove this when we have tests to assert correctness of this function
@ -366,7 +366,7 @@ BFSLoop:
next = make([]string, 0) next = make([]string, 0)
} }
// Retrieve the events to process from the database. // Retrieve the events to process from the database.
events, err = db.EventsFromIDs(ctx, info.RoomNID, front) events, err = db.EventsFromIDs(ctx, info, front)
if err != nil { if err != nil {
return resultNIDs, redactEventIDs, err return resultNIDs, redactEventIDs, err
} }
@ -467,7 +467,7 @@ func QueryLatestEventsAndState(
return err return err
} }
stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries) stateEvents, err := LoadStateEvents(ctx, db, roomInfo, stateEntries)
if err != nil { if err != nil {
return err return err
} }

View file

@ -38,7 +38,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
var authNIDs []types.EventNID var authNIDs []types.EventNID
for _, x := range room.Events() { for _, x := range room.Events() {
roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap()) roomNID, roomInfo, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap())
assert.NoError(t, err) assert.NoError(t, err)
assert.Greater(t, roomNID, types.RoomNID(0)) assert.Greater(t, roomNID, types.RoomNID(0))
@ -49,7 +49,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey()) eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey())
assert.NoError(t, err) assert.NoError(t, err)
evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false) evNID, _, err := db.StoreEvent(context.Background(), x.Event, roomInfo, eventTypeNID, eventStateKeyNID, authNIDs, false)
assert.NoError(t, err) assert.NoError(t, err)
authNIDs = append(authNIDs, evNID) authNIDs = append(authNIDs, evNID)
} }

View file

@ -274,8 +274,10 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then // Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted. // we consider the event to be "rejected" — it will still be persisted.
redactAllowed := true
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
isRejected = true isRejected = true
redactAllowed = false
rejectionErr = err rejectionErr = err
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
} }
@ -323,7 +325,7 @@ func (r *Inputer) processRoomEvent(
// burning CPU time. // burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent { if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent {
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo.RoomNID, input, missingPrev) historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo, input, missingPrev)
if err != nil { if err != nil {
return fmt.Errorf("r.processStateBefore: %w", err) return fmt.Errorf("r.processStateBefore: %w", err)
} }
@ -332,10 +334,12 @@ func (r *Inputer) processRoomEvent(
} }
} }
roomNID, err := r.DB.GetOrCreateRoomNID(ctx, event) if roomInfo == nil {
_, roomInfo, err = r.DB.GetOrCreateRoomNID(ctx, event)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err)
} }
}
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type()) eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type())
if err != nil { if err != nil {
@ -348,15 +352,24 @@ func (r *Inputer) processRoomEvent(
} }
// Store the event. // Store the event.
_, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) eventNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
// if storing this event results in it being redacted then do so. // if storing this event results in it being redacted then do so.
if !isRejected && redactedEventID == event.EventID() { var (
if err = eventutil.RedactEvent(redactionEvent, event); err != nil { redactedEventID string
return fmt.Errorf("eventutil.RedactEvent: %w", rerr) redactionEvent *gomatrixserverlib.Event
redactedEvent *gomatrixserverlib.Event
)
if !isRejected && !isCreateEvent {
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, redactAllowed)
if err != nil {
return err
}
if redactedEvent != nil {
redactedEventID = redactedEvent.EventID()
} }
} }
@ -489,7 +502,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse
// nolint:nakedret // nolint:nakedret
func (r *Inputer) processStateBefore( func (r *Inputer) processStateBefore(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, roomInfo *types.RoomInfo,
input *api.InputRoomEvent, input *api.InputRoomEvent,
missingPrev bool, missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
@ -505,7 +518,7 @@ func (r *Inputer) processStateBefore(
case input.HasState: case input.HasState:
// If we're overriding the state then we need to go and retrieve // If we're overriding the state then we need to go and retrieve
// them from the database. It's a hard error if they are missing. // them from the database. It's a hard error if they are missing.
stateEvents, err := r.DB.EventsFromIDs(ctx, roomNID, input.StateEventIDs) stateEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, input.StateEventIDs)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err) return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err)
} }
@ -604,7 +617,7 @@ func (r *Inputer) fetchAuthEvents(
} }
for _, authEventID := range authEventIDs { for _, authEventID := range authEventIDs {
authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo.RoomNID, []string{authEventID}) authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{} unknown[authEventID] = struct{}{}
continue continue
@ -690,10 +703,12 @@ nextAuthEvent:
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
} }
roomNID, err := r.DB.GetOrCreateRoomNID(ctx, authEvent) if roomInfo == nil {
_, roomInfo, err = r.DB.GetOrCreateRoomNID(ctx, authEvent)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err)
} }
}
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type()) eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type())
if err != nil { if err != nil {
@ -706,7 +721,7 @@ nextAuthEvent:
} }
// Finally, store the event in the database. // Finally, store the event in the database.
eventNID, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) eventNID, _, err := r.DB.StoreEvent(ctx, authEvent, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
@ -782,7 +797,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event
return err return err
} }
memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, membershipNIDs) memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs)
if err != nil { if err != nil {
return err return err
} }

View file

@ -53,7 +53,7 @@ func (r *Inputer) updateMemberships(
// Load the event JSON so we can look up the "membership" key. // Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that // TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON? // key without having to load the entire event JSON?
events, err := updater.Events(ctx, 0, eventNIDs) events, err := updater.Events(ctx, nil, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
for _, entry := range stateEntries { for _, entry := range stateEntries {
stateEventNIDs = append(stateEventNIDs, entry.EventNID) stateEventNIDs = append(stateEventNIDs, entry.EventNID)
} }
stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomNID, stateEventNIDs) stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs)
if err != nil { if err != nil {
t.log.WithError(err).Warnf("failed to load state events locally") t.log.WithError(err).Warnf("failed to load state events locally")
return nil return nil
@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
missingEventList = append(missingEventList, evID) missingEventList = append(missingEventList, evID)
} }
t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events")
events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList)
if err != nil { if err != nil {
return nil return nil
} }
@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
} }
t.haveEventsMutex.Unlock() t.haveEventsMutex.Unlock()
events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList)
if err != nil { if err != nil {
return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err)
} }
@ -844,7 +844,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
if localFirst { if localFirst {
// fetch from the roomserver // fetch from the roomserver
events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, []string{missingEventID}) events, err := t.db.EventsFromIDs(ctx, t.roomInfo, []string{missingEventID})
if err != nil { if err != nil {
t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
} else if len(events) == 1 { } else if len(events) == 1 {

View file

@ -70,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil return nil
} }
memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, memberNIDs) memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,

View file

@ -23,7 +23,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/auth"
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
@ -86,7 +85,7 @@ func (r *Backfiller) PerformBackfill(
// Retrieve events from the list that was filled previously. If we fail to get // Retrieve events from the list that was filled previously. If we fail to get
// events from the database then attempt once to get them from federation instead. // events from the database then attempt once to get them from federation instead.
var loadedEvents []*gomatrixserverlib.Event var loadedEvents []*gomatrixserverlib.Event
loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
if err != nil { if err != nil {
if _, ok := err.(types.MissingEventError); ok { if _, ok := err.(types.MissingEventError); ok {
return r.backfillViaFederation(ctx, request, response) return r.backfillViaFederation(ctx, request, response)
@ -473,7 +472,7 @@ FindSuccessor:
// Retrieve all "m.room.member" state events of "join" membership, which // Retrieve all "m.room.member" state events of "join" membership, which
// contains the list of users in the room before the event, therefore all // contains the list of users in the room before the event, therefore all
// the servers in it at that moment. // the servers in it at that moment.
memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info.RoomNID, stateEntries, true) memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info, stateEntries, true)
if err != nil { if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
return nil return nil
@ -532,7 +531,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
roomNID = nid.RoomNID roomNID = nid.RoomNID
} }
} }
eventsWithNids, err := b.db.Events(ctx, roomNID, eventNIDs) eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err return nil, err
@ -562,7 +561,7 @@ func joinEventsFromHistoryVisibility(
} }
// Get all of the events in this state // Get all of the events in this state
stateEvents, err := db.Events(ctx, roomInfo.RoomNID, eventNIDs) stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if err != nil { if err != nil {
// even though the default should be shared, restricting the visibility to joined // even though the default should be shared, restricting the visibility to joined
// feels more secure here. // feels more secure here.
@ -585,7 +584,7 @@ func joinEventsFromHistoryVisibility(
if err != nil { if err != nil {
return nil, visibility, err return nil, visibility, err
} }
evs, err := db.Events(ctx, roomInfo.RoomNID, joinEventNIDs) evs, err := db.Events(ctx, roomInfo, joinEventNIDs)
return evs, visibility, err return evs, visibility, err
} }
@ -606,7 +605,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
i++ i++
} }
roomNID, err = db.GetOrCreateRoomNID(ctx, ev.Unwrap()) _, roomInfo, err := db.GetOrCreateRoomNID(ctx, ev.Unwrap())
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to get or create roomNID") logrus.WithError(err).Error("failed to get or create roomNID")
continue continue
@ -624,23 +623,21 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
continue continue
} }
var redactedEventID string eventNID, _, err = db.StoreEvent(ctx, ev.Unwrap(), roomInfo, eventTypeNID, eventStateKeyNID, authNids, false)
var redactionEvent *gomatrixserverlib.Event
eventNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), roomNID, eventTypeNID, eventStateKeyNID, authNids, false)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue continue
} }
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Unwrap(), true)
if err != nil {
continue
}
// If storing this event results in it being redacted, then do so. // If storing this event results in it being redacted, then do so.
// It's also possible for this event to be a redaction which results in another event being // It's also possible for this event to be a redaction which results in another event being
// redacted, which we don't care about since we aren't returning it in this backfill. // redacted, which we don't care about since we aren't returning it in this backfill.
if redactedEventID == ev.EventID() { if redactedEvent != nil && redactedEvent.EventID() == ev.EventID() {
eventToRedact := ev.Unwrap() ev = redactedEvent.Headered(ev.RoomVersion)
if err := eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
continue
}
ev = eventToRedact.Headered(ev.RoomVersion)
events[j] = ev events[j] = ev
} }
backfilledEventMap[ev.EventID()] = types.Event{ backfilledEventMap[ev.EventID()] = types.Event{

View file

@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil { if err != nil {
return err return err
} }
latestEvents, err := r.DB.EventsFromIDs(ctx, info.RoomNID, []string{latestEventRefs[0].EventID}) latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0].EventID})
if err != nil { if err != nil {
return err return err
} }
@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil { if err != nil {
return err return err
} }
stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
if err != nil { if err != nil {
return err return err
} }
@ -100,7 +100,7 @@ func (r *InboundPeeker) PerformInboundPeek(
} }
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil { if err != nil {
return err return err
} }

View file

@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite(
// try and see if the user is allowed to make this invite. We can't do // try and see if the user is allowed to make this invite. We can't do
// this for invites coming in over federation - we have to take those on // this for invites coming in over federation - we have to take those on
// trust. // trust.
_, err = helpers.CheckAuthEvents(ctx, r.DB, info.RoomNID, event, event.AuthEventIDs()) _, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs())
if err != nil { if err != nil {
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event", "processInviteEvent.checkAuthEvents failed for event",
@ -291,7 +291,7 @@ func buildInviteStrippedState(
for _, stateNID := range stateEntries { for _, stateNID := range stateEntries {
stateNIDs = append(stateNIDs, stateNID.EventNID) stateNIDs = append(stateNIDs, stateNID.EventNID)
} }
stateEvents, err := db.Events(ctx, info.RoomNID, stateNIDs) stateEvents, err := db.Events(ctx, info, stateNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -102,7 +102,7 @@ func (r *Queryer) QueryStateAfterEvents(
return err return err
} }
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
if err != nil { if err != nil {
return err return err
} }
@ -114,7 +114,7 @@ func (r *Queryer) QueryStateAfterEvents(
} }
authEventIDs = util.UniqueStrings(authEventIDs) authEventIDs = util.UniqueStrings(authEventIDs)
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil { if err != nil {
return fmt.Errorf("getAuthChain: %w", err) return fmt.Errorf("getAuthChain: %w", err)
} }
@ -138,18 +138,39 @@ func (r *Queryer) QueryEventsByID(
request *api.QueryEventsByIDRequest, request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse, response *api.QueryEventsByIDResponse,
) error { ) error {
events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs) if len(request.EventIDs) == 0 {
return nil
}
var err error
// We didn't receive a room ID, we need to fetch it first before we can continue.
// This happens for e.g. ` /_matrix/federation/v1/event/{eventId}`
var roomInfo *types.RoomInfo
if request.RoomID == "" {
var eventNIDs map[string]types.EventMetadata
eventNIDs, err = r.DB.EventNIDs(ctx, []string{request.EventIDs[0]})
if err != nil {
return err
}
if len(eventNIDs) == 0 {
return nil
}
roomInfo, err = r.DB.RoomInfoByNID(ctx, eventNIDs[request.EventIDs[0]].RoomNID)
} else {
roomInfo, err = r.DB.RoomInfo(ctx, request.RoomID)
}
if err != nil {
return err
}
if roomInfo == nil {
return nil
}
events, err := r.DB.EventsFromIDs(ctx, roomInfo, request.EventIDs)
if err != nil { if err != nil {
return err return err
} }
for _, event := range events { for _, event := range events {
roomVersion, verr := r.roomVersion(event.RoomID()) response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion))
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
} }
return nil return nil
@ -186,7 +207,7 @@ func (r *Queryer) QueryMembershipForUser(
response.IsInRoom = stillInRoom response.IsInRoom = stillInRoom
response.HasBeenInRoom = true response.HasBeenInRoom = true
evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID}) evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID})
if err != nil { if err != nil {
return err return err
} }
@ -268,10 +289,10 @@ func (r *Queryer) QueryMembershipAtEvent(
// once. If we have more than one membership event, we need to get the state for each state entry. // once. If we have more than one membership event, we need to get the state for each state entry.
if canShortCircuit { if canShortCircuit {
if len(memberships) == 0 { if len(memberships) == 0 {
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
} }
} else { } else {
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
} }
if err != nil { if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err) return fmt.Errorf("unable to get memberships at state: %w", err)
@ -318,7 +339,7 @@ func (r *Queryer) QueryMembershipsForRoom(
} }
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
} }
events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) events, err = r.DB.Events(ctx, info, eventNIDs)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.Events: %w", err) return fmt.Errorf("r.DB.Events: %w", err)
} }
@ -357,14 +378,14 @@ func (r *Queryer) QueryMembershipsForRoom(
return err return err
} }
events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) events, err = r.DB.Events(ctx, info, eventNIDs)
} else { } else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil { if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err return err
} }
events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly) events, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntries, request.JoinedOnly)
} }
if err != nil { if err != nil {
@ -415,7 +436,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
request *api.QueryServerAllowedToSeeEventRequest, request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse, response *api.QueryServerAllowedToSeeEventResponse,
) (err error) { ) (err error) {
events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID}) events, err := r.DB.EventsFromIDs(ctx, nil, []string{request.EventID})
if err != nil { if err != nil {
return return
} }
@ -466,7 +487,7 @@ func (r *Queryer) QueryMissingEvents(
eventsToFilter[id] = true eventsToFilter[id] = true
} }
} }
events, err := r.DB.EventsFromIDs(ctx, 0, front) events, err := r.DB.EventsFromIDs(ctx, nil, front)
if err != nil { if err != nil {
return err return err
} }
@ -486,7 +507,7 @@ func (r *Queryer) QueryMissingEvents(
return err return err
} }
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
if err != nil { if err != nil {
return err return err
} }
@ -529,7 +550,7 @@ func (r *Queryer) QueryStateAndAuthChain(
// TODO: this probably means it should be a different query operation... // TODO: this probably means it should be a different query operation...
if request.OnlyFetchAuthChain { if request.OnlyFetchAuthChain {
var authEvents []*gomatrixserverlib.Event var authEvents []*gomatrixserverlib.Event
authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, request.AuthEventIDs) authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs)
if err != nil { if err != nil {
return err return err
} }
@ -556,7 +577,7 @@ func (r *Queryer) QueryStateAndAuthChain(
} }
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil { if err != nil {
return err return err
} }
@ -611,18 +632,18 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
return nil, rejected, false, err return nil, rejected, false, err
} }
events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries) events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo, stateEntries)
return events, rejected, false, err return events, rejected, false, err
} }
type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error) type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Event, error)
// GetAuthChain fetches the auth chain for the given auth events. An auth chain // GetAuthChain fetches the auth chain for the given auth events. An auth chain
// is the list of all events that are referenced in the auth_events section, and // is the list of all events that are referenced in the auth_events section, and
// all their auth_events, recursively. The returned set of events contain the // all their auth_events, recursively. The returned set of events contain the
// given events. Will *not* error if we don't have all auth events. // given events. Will *not* error if we don't have all auth events.
func GetAuthChain( func GetAuthChain(
ctx context.Context, fn eventsFromIDs, authEventIDs []string, ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string,
) ([]*gomatrixserverlib.Event, error) { ) ([]*gomatrixserverlib.Event, error) {
// List of event IDs to fetch. On each pass, these events will be requested // List of event IDs to fetch. On each pass, these events will be requested
// from the database and the `eventsToFetch` will be updated with any new // from the database and the `eventsToFetch` will be updated with any new
@ -633,7 +654,7 @@ func GetAuthChain(
for len(eventsToFetch) > 0 { for len(eventsToFetch) > 0 {
// Try to retrieve the events from the database. // Try to retrieve the events from the database.
events, err := fn(ctx, 0, eventsToFetch) events, err := fn(ctx, roomInfo, eventsToFetch)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -852,7 +873,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS
} }
func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error { func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error {
chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs) chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, nil, req.EventIDs)
if err != nil { if err != nil {
return err return err
} }
@ -971,7 +992,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// For each of the joined users, let's see if we can get a valid // For each of the joined users, let's see if we can get a valid
// membership event. // membership event.
for _, joinNID := range joinNIDs { for _, joinNID := range joinNIDs {
events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID}) events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID})
if err != nil || len(events) != 1 { if err != nil || len(events) != 1 {
continue continue
} }

View file

@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
} }
// EventsFromIDs implements RoomserverInternalAPIEventDB // EventsFromIDs implements RoomserverInternalAPIEventDB
func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) { func (db *getEventDB) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) (res []types.Event, err error) {
for _, evID := range eventIDs { for _, evID := range eventIDs {
res = append(res, types.Event{ res = append(res, types.Event{
EventNID: 0, EventNID: 0,
@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) {
t.Fatalf("Failed to add events to db: %v", err) t.Fatalf("Failed to add events to db: %v", err)
} }
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"}) result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e"})
if err != nil { if err != nil {
t.Fatalf("getAuthChain failed: %v", err) t.Fatalf("getAuthChain failed: %v", err)
} }
@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) {
t.Fatalf("Failed to add events to db: %v", err) t.Fatalf("Failed to add events to db: %v", err)
} }
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"}) result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e", "f"})
if err != nil { if err != nil {
t.Fatalf("getAuthChain failed: %v", err) t.Fatalf("getAuthChain failed: %v", err)
} }

View file

@ -41,8 +41,8 @@ type StateResolutionStorage interface {
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
} }
type StateResolution struct { type StateResolution struct {
@ -975,7 +975,7 @@ func (v *StateResolution) resolveConflictsV2(
// Store the newly found auth events in the auth set for this event. // Store the newly found auth events in the auth set for this event.
var authEventMap map[string]types.StateEntry var authEventMap map[string]types.StateEntry
authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo.RoomNID, conflictedEvent, knownAuthEvents) authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo, conflictedEvent, knownAuthEvents)
if err != nil { if err != nil {
return err return err
} }
@ -1091,7 +1091,7 @@ func (v *StateResolution) loadStateEvents(
eventNIDs = append(eventNIDs, entry.EventNID) eventNIDs = append(eventNIDs, entry.EventNID)
} }
} }
events, err := v.db.Events(ctx, v.roomInfo.RoomNID, eventNIDs) events, err := v.db.Events(ctx, v.roomInfo, eventNIDs)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -1120,7 +1120,7 @@ type authEventLoader struct {
// loadAuthEvents loads all of the auth events for a given event recursively, // loadAuthEvents loads all of the auth events for a given event recursively,
// along with a map that contains state entries for all of the auth events. // along with a map that contains state entries for all of the auth events.
func (l *authEventLoader) loadAuthEvents( func (l *authEventLoader) loadAuthEvents(
ctx context.Context, roomNID types.RoomNID, event *gomatrixserverlib.Event, eventMap map[string]types.Event, ctx context.Context, roomInfo *types.RoomInfo, event *gomatrixserverlib.Event, eventMap map[string]types.Event,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
@ -1155,7 +1155,7 @@ func (l *authEventLoader) loadAuthEvents(
// If we need to get events from the database, go and fetch // If we need to get events from the database, go and fetch
// those now. // those now.
if len(l.lookupFromDB) > 0 { if len(l.lookupFromDB) > 0 {
eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomNID, l.lookupFromDB) eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomInfo, l.lookupFromDB)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
} }

View file

@ -29,6 +29,7 @@ type Database interface {
SupportsConcurrentRoomInputs() bool SupportsConcurrentRoomInputs() bool
// RoomInfo returns room information for the given room ID, or nil if there is no room. // RoomInfo returns room information for the given room ID, or nil if there is no room.
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
// Store the room state at an event in the database // Store the room state at an event in the database
AddState( AddState(
ctx context.Context, ctx context.Context,
@ -41,6 +42,8 @@ type Database interface {
ctx context.Context, e *gomatrixserverlib.Event, ctx context.Context, e *gomatrixserverlib.Event,
) (missingAuth, missingPrev []string, err error) ) (missingAuth, missingPrev []string, err error)
// GetEventDatabase returns an EventDatabase to work with events only.
GetEventDatabase() *shared.EventDatabase
// Look up the state of a room at each event for a list of string event IDs. // Look up the state of a room at each event for a list of string event IDs.
// Returns an error if there is an error talking to the database. // Returns an error if there is an error talking to the database.
// The length of []types.StateAtEvent is guaranteed to equal the length of eventIDs if no error is returned. // The length of []types.StateAtEvent is guaranteed to equal the length of eventIDs if no error is returned.
@ -69,12 +72,12 @@ type Database interface {
) ([]types.StateEntryList, error) ) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs. // Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events. // Returns a sorted list of events.
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string // Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. // Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error.
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
// Look up the state entries for a list of string event IDs // Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database // Returns an error if the there is an error talking to the database
// Returns a types.MissingEventError if the event IDs aren't in the database. // Returns a types.MissingEventError if the event IDs aren't in the database.
@ -135,7 +138,7 @@ type Database interface {
// EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was
// not found. // not found.
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
// Publish or unpublish a room from the room directory. // Publish or unpublish a room from the room directory.
PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
// Returns a list of room IDs for rooms which are published. // Returns a list of room IDs for rooms which are published.
@ -179,36 +182,53 @@ type Database interface {
GetMembershipForHistoryVisibility( GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) ) (map[string]*gomatrixserverlib.HeaderedEvent, error)
GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, *types.RoomInfo, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
MaybeRedactEvent(
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error)
} }
type RoomDatabase interface { type RoomDatabase interface {
EventDatabase
// RoomInfo returns room information for the given room ID, or nil if there is no room. // RoomInfo returns room information for the given room ID, or nil if there is no room.
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
// IsEventRejected returns true if the event is known and rejected. // IsEventRejected returns true if the event is known and rejected.
IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error) IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error)
MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error) MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, *types.RoomInfo, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
}
type EventDatabase interface {
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error)
SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
// MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error
// (nil if there was nothing to do)
MaybeRedactEvent(
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error)
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
} }

View file

@ -194,22 +194,27 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
return err return err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: db,
EventDatabase: shared.EventDatabase{
DB: db, DB: db,
Cache: cache, Cache: cache,
Writer: writer, Writer: writer,
EventsTable: events,
EventJSONTable: eventJSON,
EventTypesTable: eventTypes, EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys, EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON, PrevEventsTable: prevEvents,
EventsTable: events, RedactionsTable: redactions,
},
Cache: cache,
Writer: writer,
RoomsTable: rooms, RoomsTable: rooms,
StateBlockTable: stateBlock, StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot, StateSnapshotTable: stateSnapshot,
PrevEventsTable: prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: invites, InvitesTable: invites,
MembershipTable: membership, MembershipTable: membership,
PublishedTable: published, PublishedTable: published,
RedactionsTable: redactions,
Purge: purge, Purge: purge,
} }
return nil return nil

View file

@ -116,8 +116,8 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent
}) })
} }
func (u *RoomUpdater) Events(ctx context.Context, _ types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) { func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
return u.d.events(ctx, u.txn, u.roomInfo.RoomNID, eventNIDs) return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs)
} }
func (u *RoomUpdater) SnapshotNIDFromEventID( func (u *RoomUpdater) SnapshotNIDFromEventID(
@ -195,8 +195,8 @@ func (u *RoomUpdater) StateAtEventIDs(
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
} }
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) { func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, roomNID, eventIDs, NoFilter) return u.d.eventsFromIDs(ctx, u.txn, u.roomInfo, eventIDs, NoFilter)
} }
// IsReferenced implements types.RoomRecentEventsUpdater // IsReferenced implements types.RoomRecentEventsUpdater

View file

@ -7,9 +7,9 @@ import (
"fmt" "fmt"
"sort" "sort"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
@ -28,6 +28,23 @@ import (
const redactionsArePermanent = true const redactionsArePermanent = true
type Database struct { type Database struct {
DB *sql.DB
EventDatabase
Cache caching.RoomServerCaches
Writer sqlutil.Writer
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
Purge tables.Purge
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
}
// EventDatabase contains all tables needed to work with events
type EventDatabase struct {
DB *sql.DB DB *sql.DB
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
Writer sqlutil.Writer Writer sqlutil.Writer
@ -35,17 +52,24 @@ type Database struct {
EventJSONTable tables.EventJSON EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
RedactionsTable tables.Redactions RedactionsTable tables.Redactions
Purge tables.Purge }
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
// GetEventDatabase returns an EventDatabase to work with events only.
func (d *Database) GetEventDatabase() *EventDatabase {
db := &EventDatabase{
DB: d.DB,
Cache: d.Cache,
Writer: d.Writer,
EventsTable: d.EventsTable,
EventJSONTable: d.EventJSONTable,
EventTypesTable: d.EventTypesTable,
EventStateKeysTable: d.EventStateKeysTable,
PrevEventsTable: d.PrevEventsTable,
RedactionsTable: d.RedactionsTable,
}
return db
} }
func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) SupportsConcurrentRoomInputs() bool {
@ -58,13 +82,13 @@ func (d *Database) GetMembershipForHistoryVisibility(
return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...) return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...)
} }
func (d *Database) EventTypeNIDs( func (d *EventDatabase) EventTypeNIDs(
ctx context.Context, eventTypes []string, ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {
return d.eventTypeNIDs(ctx, nil, eventTypes) return d.eventTypeNIDs(ctx, nil, eventTypes)
} }
func (d *Database) eventTypeNIDs( func (d *EventDatabase) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string, ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID) result := make(map[string]types.EventTypeNID)
@ -91,7 +115,7 @@ func (d *Database) eventTypeNIDs(
return result, nil return result, nil
} }
func (d *Database) EventStateKeys( func (d *EventDatabase) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) { ) (map[types.EventStateKeyNID]string, error) {
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
@ -116,13 +140,13 @@ func (d *Database) EventStateKeys(
return result, nil return result, nil
} }
func (d *Database) EventStateKeyNIDs( func (d *EventDatabase) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string, ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
return d.eventStateKeyNIDs(ctx, nil, eventStateKeys) return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
} }
func (d *Database) eventStateKeyNIDs( func (d *EventDatabase) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string, ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID) result := make(map[string]types.EventStateKeyNID)
@ -174,7 +198,7 @@ func (d *Database) eventStateKeyNIDs(
return result, nil return result, nil
} }
func (d *Database) StateEntriesForEventIDs( func (d *EventDatabase) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string, excludeRejected bool, ctx context.Context, eventIDs []string, excludeRejected bool,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected) return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected)
@ -213,6 +237,17 @@ func (d *Database) stateEntriesForTuples(
return lists, nil return lists, nil
} }
func (d *Database) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) {
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{roomNID})
if err != nil {
return nil, err
}
if len(roomIDs) == 0 {
return nil, fmt.Errorf("room does not exist")
}
return d.roomInfo(ctx, nil, roomIDs[0])
}
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.roomInfo(ctx, nil, roomID) return d.roomInfo(ctx, nil, roomID)
} }
@ -292,7 +327,7 @@ func (d *Database) addState(
return return
} }
func (d *Database) EventNIDs( func (d *EventDatabase) EventNIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) (map[string]types.EventMetadata, error) { ) (map[string]types.EventMetadata, error) {
return d.eventNIDs(ctx, nil, eventIDs, NoFilter) return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
@ -305,7 +340,7 @@ const (
FilterUnsentOnly UnsentFilter = true FilterUnsentOnly UnsentFilter = true
) )
func (d *Database) eventNIDs( func (d *EventDatabase) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter, ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
) (map[string]types.EventMetadata, error) { ) (map[string]types.EventMetadata, error) {
switch filter { switch filter {
@ -318,7 +353,7 @@ func (d *Database) eventNIDs(
} }
} }
func (d *Database) SetState( func (d *EventDatabase) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -326,19 +361,19 @@ func (d *Database) SetState(
}) })
} }
func (d *Database) StateAtEventIDs( func (d *EventDatabase) StateAtEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs) return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
} }
func (d *Database) SnapshotNIDFromEventID( func (d *EventDatabase) SnapshotNIDFromEventID(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
return d.snapshotNIDFromEventID(ctx, nil, eventID) return d.snapshotNIDFromEventID(ctx, nil, eventID)
} }
func (d *Database) snapshotNIDFromEventID( func (d *EventDatabase) snapshotNIDFromEventID(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID) _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
@ -351,17 +386,17 @@ func (d *Database) snapshotNIDFromEventID(
return stateNID, err return stateNID, err
} }
func (d *Database) EventIDs( func (d *EventDatabase) EventIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) { ) (map[types.EventNID]string, error) {
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
} }
func (d *Database) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) { func (d *EventDatabase) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) {
return d.eventsFromIDs(ctx, nil, roomNID, eventIDs, NoFilter) return d.eventsFromIDs(ctx, nil, roomInfo, eventIDs, NoFilter)
} }
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter) nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
if err != nil { if err != nil {
return nil, err return nil, err
@ -370,15 +405,9 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types
var nids []types.EventNID var nids []types.EventNID
for _, nid := range nidMap { for _, nid := range nidMap {
nids = append(nids, nid.EventNID) nids = append(nids, nid.EventNID)
if roomNID != 0 && roomNID != nid.RoomNID {
logrus.Errorf("expected events from room %d, but also found %d", roomNID, nid.RoomNID)
}
if roomNID == 0 {
roomNID = nid.RoomNID
}
} }
return d.events(ctx, txn, roomNID, nids) return d.events(ctx, txn, roomInfo, nids)
} }
func (d *Database) LatestEventIDs( func (d *Database) LatestEventIDs(
@ -517,19 +546,18 @@ func (d *Database) GetInvitesForUser(
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
} }
func (d *Database) Events( func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID, return d.events(ctx, nil, roomInfo, eventNIDs)
) ([]types.Event, error) {
return d.events(ctx, nil, roomNID, eventNIDs)
} }
func (d *Database) events( func (d *EventDatabase) events(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, inputEventNIDs types.EventNIDs, ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs,
) ([]types.Event, error) { ) ([]types.Event, error) {
if roomNID == 0 { if roomInfo == nil {
// No need to go further, as we won't find any events for this room.
return nil, nil roomInfo = &types.RoomInfo{RoomVersion: version.DefaultRoomVersion()}
} }
sort.Sort(inputEventNIDs) sort.Sort(inputEventNIDs)
events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs)) events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs))
eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs)) eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
@ -566,31 +594,9 @@ func (d *Database) events(
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
var roomVersion gomatrixserverlib.RoomVersion
var fetchRoomVersion bool
var ok bool
var roomID string
if roomID, ok = d.Cache.GetRoomServerRoomID(roomNID); ok {
roomVersion, ok = d.Cache.GetRoomVersion(roomID)
if !ok {
fetchRoomVersion = true
}
}
if roomVersion == "" || fetchRoomVersion {
var dbRoomVersions map[types.RoomNID]gomatrixserverlib.RoomVersion
dbRoomVersions, err = d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, []types.RoomNID{roomNID})
if err != nil {
return nil, err
}
if roomVersion, ok = dbRoomVersions[roomNID]; !ok {
return nil, fmt.Errorf("unable to find roomversion for room %d", roomNID)
}
}
for _, eventJSON := range eventJSONs { for _, eventJSON := range eventJSONs {
events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion, eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomInfo.RoomVersion,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -660,8 +666,9 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e
return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID)
} }
// GetOrCreateRoomNID gets or creates a new roomNID for the given event // GetOrCreateRoomNID gets or creates a new roomNID for the given event. Also returns a RoomInfo, which is only safe to use
func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (roomNID types.RoomNID, err error) { // with functions only needing a roomVersion or roomNID.
func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (roomNID types.RoomNID, roomInfo *types.RoomInfo, err error) {
// Get the default room version. If the client doesn't supply a room_version // Get the default room version. If the client doesn't supply a room_version
// then we will use our configured default to create the room. // then we will use our configured default to create the room.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
@ -670,7 +677,7 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver
// room. // room.
var roomVersion gomatrixserverlib.RoomVersion var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return 0, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) return 0, nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
} }
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion)
@ -679,7 +686,10 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver
} }
return nil return nil
}) })
return roomNID, err return roomNID, &types.RoomInfo{
RoomVersion: roomVersion,
RoomNID: roomNID,
}, err
} }
func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
@ -710,25 +720,22 @@ func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKe
return eventStateKeyNID, nil return eventStateKeyNID, nil
} }
func (d *Database) StoreEvent( func (d *EventDatabase) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event, ctx context.Context, event *gomatrixserverlib.Event,
roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
authEventNIDs []types.EventNID, isRejected bool, authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { ) (types.EventNID, types.StateAtEvent, error) {
var ( var (
eventNID types.EventNID eventNID types.EventNID
stateNID types.StateSnapshotNID stateNID types.StateSnapshotNID
redactionEvent *gomatrixserverlib.Event
redactedEventID string
err error err error
) )
// Second writer is using the database-provided transaction, probably from the
// room updater, for easy roll-back if required.
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if eventNID, stateNID, err = d.EventsTable.InsertEvent( if eventNID, stateNID, err = d.EventsTable.InsertEvent(
ctx, ctx,
txn, txn,
roomNID, roomInfo.RoomNID,
eventTypeNID, eventTypeNID,
eventStateKeyNID, eventStateKeyNID,
event.EventID(), event.EventID(),
@ -751,16 +758,26 @@ func (d *Database) StoreEvent(
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
} }
if !isRejected { // ignore rejected redaction events
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, roomNID, eventNID, event) if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
if err != nil { // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
return fmt.Errorf("d.handleRedactions: %w", err) // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
// function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`.
// The following is a copy of RoomUpdater.StorePreviousEvents
for _, ref := range prevEvents {
if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
} }
} }
}
return nil return nil
}) })
if err != nil { if err != nil {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) return 0, types.StateAtEvent{}, fmt.Errorf("d.Writer.Do: %w", err)
} }
// We should attempt to update the previous events table with any // We should attempt to update the previous events table with any
@ -768,33 +785,6 @@ func (d *Database) StoreEvent(
// events updater because it somewhat works as a mutex, ensuring // events updater because it somewhat works as a mutex, ensuring
// that there's a row-level lock on the latest room events (well, // that there's a row-level lock on the latest room events (well,
// on Postgres at least). // on Postgres at least).
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
// function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`.
succeeded := false
var roomInfo *types.RoomInfo
roomInfo, err = d.roomInfo(ctx, nil, event.RoomID())
if err != nil {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
if roomInfo == nil && len(prevEvents) > 0 {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
var updater *RoomUpdater
updater, err = d.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
}
succeeded = true
}
return eventNID, types.StateAtEvent{ return eventNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID, BeforeStateSnapshotNID: stateNID,
@ -805,7 +795,7 @@ func (d *Database) StoreEvent(
}, },
EventNID: eventNID, EventNID: eventNID,
}, },
}, redactionEvent, redactedEventID, err }, err
} }
func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error { func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error {
@ -893,7 +883,7 @@ func (d *Database) assignEventTypeNID(
return eventTypeNID, nil return eventTypeNID, nil
} }
func (d *Database) assignStateKeyNID( func (d *EventDatabase) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string, ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey) eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey)
@ -937,7 +927,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
return roomVersion, err return roomVersion, err
} }
// handleRedactions manages the redacted status of events. There's two cases to consider in order to comply with the spec: // MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec:
// "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid." // "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid."
// https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events // https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
// These cases are: // These cases are:
@ -952,16 +942,23 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
// when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need // when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need
// to cross-reference with other tables when loading. // to cross-reference with other tables when loading.
// //
// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. // Returns the redaction event and the redacted event if this call resulted in a redaction.
func (d *Database) handleRedactions( func (d *EventDatabase) MaybeRedactEvent(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event, ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
) (*gomatrixserverlib.Event, string, error) { ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) {
var err error var (
redactionEvent, redactedEvent *types.Event
err error
validated bool
ignoreRedaction bool
)
wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil
if isRedactionEvent { if isRedactionEvent {
// an event which redacts itself should be ignored // an event which redacts itself should be ignored
if event.EventID() == event.Redacts() { if event.EventID() == event.Redacts() {
return nil, "", nil return nil
} }
err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{ err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{
@ -970,44 +967,30 @@ func (d *Database) handleRedactions(
RedactsEventID: event.Redacts(), RedactsEventID: event.Redacts(),
}) })
if err != nil { if err != nil {
return nil, "", fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err) return fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err)
} }
} }
redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, roomNID, eventNID, event) redactionEvent, redactedEvent, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event)
if err != nil {
return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err)
}
if validated || redactedEvent == nil || redactionEvent == nil {
// we've seen this redaction before or there is nothing to redact
return nil, "", nil
}
if redactedEvent.RoomID() != redactionEvent.RoomID() {
// redactions across rooms aren't allowed
return nil, "", nil
}
// Get the power level from the database, so we can verify the user is allowed to redact the event
powerLevels, err := d.GetStateEvent(ctx, event.RoomID(), gomatrixserverlib.MRoomPowerLevels, "")
if err != nil {
return nil, "", fmt.Errorf("d.GetStateEvent: %w", err)
}
if powerLevels == nil {
return nil, "", fmt.Errorf("unable to fetch m.room.power_levels event from database for room %s", event.RoomID())
}
pl, err := powerLevels.PowerLevels()
if err != nil {
return nil, "", fmt.Errorf("unable to get powerlevels for room: %w", err)
}
redactUser := pl.UserLevel(redactionEvent.Sender())
switch { switch {
case redactUser >= pl.Redact: case err != nil:
// The power level of the redaction events sender is greater than or equal to the redact level. return fmt.Errorf("d.loadRedactionPair: %w", err)
case redactedEvent.Sender() == redactionEvent.Sender(): case validated || redactedEvent == nil || redactionEvent == nil:
// The domain of the redaction events sender matches that of the original events sender. // we've seen this redaction before or there is nothing to redact
default: return nil
return nil, "", nil case redactedEvent.RoomID() != redactionEvent.RoomID():
// redactions across rooms aren't allowed
ignoreRedaction = true
return nil
}
// 1. The power level of the redaction events sender is greater than or equal to the redact level. (redactAllowed)
// 2. The domain of the redaction events sender matches that of the original events sender.
_, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender())
_, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender())
if !redactAllowed || sender1 != sender2 {
ignoreRedaction = true
return nil
} }
// mark the event as redacted // mark the event as redacted
@ -1017,30 +1000,37 @@ func (d *Database) handleRedactions(
err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
} }
// NOTSPEC: sytest relies on this unspecced field existing :( // NOTSPEC: sytest relies on this unspecced field existing :(
err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID()) err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID())
if err != nil { if err != nil {
return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
} }
// overwrite the eventJSON table // overwrite the eventJSON table
err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON())
if err != nil { if err != nil {
return nil, "", fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
} }
err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true)
if err != nil { if err != nil {
err = fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err)
} }
return nil
return redactionEvent.Event, redactedEvent.EventID(), err })
if wErr != nil {
return nil, nil, err
}
if ignoreRedaction || redactionEvent == nil || redactedEvent == nil {
return nil, nil, nil
}
return redactionEvent.Event, redactedEvent.Event, nil
} }
// loadRedactionPair returns both the redaction event and the redacted event, else nil. // loadRedactionPair returns both the redaction event and the redacted event, else nil.
func (d *Database) loadRedactionPair( func (d *EventDatabase) loadRedactionPair(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event, ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event,
) (*types.Event, *types.Event, bool, error) { ) (*types.Event, *types.Event, bool, error) {
var redactionEvent, redactedEvent *types.Event var redactionEvent, redactedEvent *types.Event
var info *tables.RedactionInfo var info *tables.RedactionInfo
@ -1072,16 +1062,16 @@ func (d *Database) loadRedactionPair(
} }
if isRedactionEvent { if isRedactionEvent {
redactedEvent = d.loadEvent(ctx, roomNID, info.RedactsEventID) redactedEvent = d.loadEvent(ctx, roomInfo, info.RedactsEventID)
} else { } else {
redactionEvent = d.loadEvent(ctx, roomNID, info.RedactionEventID) redactionEvent = d.loadEvent(ctx, roomInfo, info.RedactionEventID)
} }
return redactionEvent, redactedEvent, info.Validated, nil return redactionEvent, redactedEvent, info.Validated, nil
} }
// applyRedactions will redact events that have an `unsigned.redacted_because` field. // applyRedactions will redact events that have an `unsigned.redacted_because` field.
func (d *Database) applyRedactions(events []types.Event) { func (d *EventDatabase) applyRedactions(events []types.Event) {
for i := range events { for i := range events {
if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() { if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() {
events[i].Redact() events[i].Redact()
@ -1090,7 +1080,7 @@ func (d *Database) applyRedactions(events []types.Event) {
} }
// loadEvent loads a single event or returns nil on any problems/missing event // loadEvent loads a single event or returns nil on any problems/missing event
func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID string) *types.Event { func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo, eventID string) *types.Event {
nids, err := d.EventNIDs(ctx, []string{eventID}) nids, err := d.EventNIDs(ctx, []string{eventID})
if err != nil { if err != nil {
return nil return nil
@ -1098,7 +1088,7 @@ func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID
if len(nids) == 0 { if len(nids) == 0 {
return nil return nil
} }
evs, err := d.Events(ctx, roomNID, []types.EventNID{nids[eventID].EventNID}) evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID})
if err != nil { if err != nil {
return nil return nil
} }
@ -1144,7 +1134,7 @@ func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *type
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
roomInfo, err := d.RoomInfo(ctx, roomID) roomInfo, err := d.roomInfo(ctx, nil, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1209,7 +1199,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error // Same as GetStateEvent but returns all matching state events with this event type. Returns no error
// if there are no events with this event type. // if there are no events with this event type.
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
roomInfo, err := d.RoomInfo(ctx, roomID) roomInfo, err := d.roomInfo(ctx, nil, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1340,7 +1330,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion) eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion)
// TODO: This feels like this is going to be really slow... // TODO: This feels like this is going to be really slow...
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
roomInfo, err2 := d.RoomInfo(ctx, roomID) roomInfo, err2 := d.roomInfo(ctx, nil, roomID)
if err2 != nil { if err2 != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2) return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2)
} }

View file

@ -52,9 +52,11 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat
cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache}
return &shared.Database{ return &shared.Database{
DB: db, DB: db,
EventStateKeysTable: stateKeyTable, EventDatabase: evDb,
MembershipTable: membershipTable, MembershipTable: membershipTable,
Writer: sqlutil.NewExclusiveWriter(), Writer: sqlutil.NewExclusiveWriter(),
Cache: cache, Cache: cache,

View file

@ -203,6 +203,8 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: db,
EventDatabase: shared.EventDatabase{
DB: db, DB: db,
Cache: cache, Cache: cache,
Writer: writer, Writer: writer,
@ -210,15 +212,18 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
EventTypesTable: eventTypes, EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys, EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON, EventJSONTable: eventJSON,
PrevEventsTable: prevEvents,
RedactionsTable: redactions,
},
Cache: cache,
Writer: writer,
RoomsTable: rooms, RoomsTable: rooms,
StateBlockTable: stateBlock, StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot, StateSnapshotTable: stateSnapshot,
PrevEventsTable: prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: invites, InvitesTable: invites,
MembershipTable: membership, MembershipTable: membership,
PublishedTable: published, PublishedTable: published,
RedactionsTable: redactions,
GetRoomUpdaterFn: d.GetRoomUpdater, GetRoomUpdaterFn: d.GetRoomUpdater,
Purge: purge, Purge: purge,
} }

View file

@ -253,7 +253,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo
var res MSC2836EventRelationshipsResponse var res MSC2836EventRelationshipsResponse
var returnEvents []*gomatrixserverlib.HeaderedEvent var returnEvents []*gomatrixserverlib.HeaderedEvent
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue. // Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
event := rc.getLocalEvent(rc.req.EventID) event := rc.getLocalEvent(rc.req.RoomID, rc.req.EventID)
if event == nil { if event == nil {
event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID) event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID)
} }
@ -592,7 +592,7 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelation
// lookForEvent returns the event for the event ID given, by trying to query remote servers // lookForEvent returns the event for the event ID given, by trying to query remote servers
// if the event ID is unknown via /event_relationships. // if the event ID is unknown via /event_relationships.
func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent { func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
event := rc.getLocalEvent(eventID) event := rc.getLocalEvent(rc.req.RoomID, eventID)
if event == nil { if event == nil {
queryRes := rc.remoteEventRelationships(eventID) queryRes := rc.remoteEventRelationships(eventID)
if queryRes != nil { if queryRes != nil {
@ -622,9 +622,10 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent
return nil return nil
} }
func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent { func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.HeaderedEvent {
var queryEventsRes roomserver.QueryEventsByIDResponse var queryEventsRes roomserver.QueryEventsByIDResponse
err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{ err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
RoomID: roomID,
EventIDs: []string{eventID}, EventIDs: []string{eventID},
}, &queryEventsRes) }, &queryEventsRes)
if err != nil { if err != nil {

View file

@ -212,6 +212,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
// Finally, work out if there are any more events missing. // Finally, work out if there are any more events missing.
if len(missingEventIDs) > 0 { if len(missingEventIDs) > 0 {
eventsReq := &api.QueryEventsByIDRequest{ eventsReq := &api.QueryEventsByIDRequest{
RoomID: ev.RoomID(),
EventIDs: missingEventIDs, EventIDs: missingEventIDs,
} }
eventsRes := &api.QueryEventsByIDResponse{} eventsRes := &api.QueryEventsByIDResponse{}

View file

@ -109,7 +109,7 @@ func GetMemberships(
} }
qryRes := &api.QueryEventsByIDResponse{} qryRes := &api.QueryEventsByIDResponse{}
if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs}, qryRes); err != nil { if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs, RoomID: roomID}, qryRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed") util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }