1
0
Fork 0
mirror of https://github.com/matrix-org/dendrite.git synced 2025-03-28 20:44:27 -05:00

Only require room version instead of room info for db.Events() ()

This reduces the API requirements for the Events database to align with
what is actually required.
This commit is contained in:
devonh 2023-05-08 19:25:44 +00:00 committed by GitHub
parent 2b34f88fde
commit a49c9f01e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 74 additions and 36 deletions

View file

@ -91,7 +91,7 @@ func main() {
} }
var eventEntries []types.Event var eventEntries []types.Event
eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs) eventEntries, err = roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -149,7 +149,7 @@ func main() {
} }
fmt.Println("Fetching", len(eventNIDMap), "state events") fmt.Println("Fetching", len(eventNIDMap), "state events")
eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs) eventEntries, err := roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -219,7 +219,12 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID) eventNIDs = append(eventNIDs, eventNID)
} }
} }
if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil {
if roomInfo == nil {
err = types.ErrorInvalidRoomInfo
return
}
if result.events, err = db.Events(ctx, roomInfo.RoomVersion, eventNIDs); err != nil {
return return
} }
roomID := "" roomID := ""

View file

@ -86,7 +86,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
return false, err return false, err
} }
events, err := db.Events(ctx, info, eventNIDs) events, err := db.Events(ctx, info.RoomVersion, eventNIDs)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -183,7 +183,10 @@ func GetMembershipsAtState(
util.Unique(eventNIDs) util.Unique(eventNIDs)
// Get all of the events in this state // Get all of the events in this state
stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if roomInfo == nil {
return nil, types.ErrorInvalidRoomInfo
}
stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -235,7 +238,10 @@ func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types
func LoadEvents( func LoadEvents(
ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID, ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.PDU, error) { ) ([]gomatrixserverlib.PDU, error) {
stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if roomInfo == nil {
return nil, types.ErrorInvalidRoomInfo
}
stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -805,7 +805,10 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
return err return err
} }
memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs) if roomInfo == nil {
return types.ErrorInvalidRoomInfo
}
memberEvents, err := r.DB.Events(ctx, roomInfo.RoomVersion, membershipNIDs)
if err != nil { if err != nil {
return err return err
} }

View file

@ -55,7 +55,7 @@ func (r *Inputer) updateMemberships(
// Load the event JSON so we can look up the "membership" key. // Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that // TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON? // key without having to load the entire event JSON?
events, err := updater.Events(ctx, nil, eventNIDs) events, err := updater.Events(ctx, "", eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -398,7 +398,10 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
for _, entry := range stateEntries { for _, entry := range stateEntries {
stateEventNIDs = append(stateEventNIDs, entry.EventNID) stateEventNIDs = append(stateEventNIDs, entry.EventNID)
} }
stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs) if t.roomInfo == nil {
return nil
}
stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomVersion, stateEventNIDs)
if err != nil { if err != nil {
t.log.WithError(err).Warnf("failed to load state events locally") t.log.WithError(err).Warnf("failed to load state events locally")
return nil return nil

View file

@ -60,7 +60,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil, err return nil, err
} }
memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs) memberEvents, err := r.DB.Events(ctx, roomInfo.RoomVersion, memberNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -533,7 +533,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
roomNID = nid.RoomNID roomNID = nid.RoomNID
} }
} }
eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs) eventsWithNids, err := b.db.Events(ctx, b.roomInfo.RoomVersion, eventNIDs)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err return nil, err
@ -563,7 +563,10 @@ func joinEventsFromHistoryVisibility(
} }
// Get all of the events in this state // Get all of the events in this state
stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if roomInfo == nil {
return nil, gomatrixserverlib.HistoryVisibilityJoined, types.ErrorInvalidRoomInfo
}
stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil { if err != nil {
// even though the default should be shared, restricting the visibility to joined // even though the default should be shared, restricting the visibility to joined
// feels more secure here. // feels more secure here.
@ -586,7 +589,7 @@ func joinEventsFromHistoryVisibility(
if err != nil { if err != nil {
return nil, visibility, err return nil, visibility, err
} }
evs, err := db.Events(ctx, roomInfo, joinEventNIDs) evs, err := db.Events(ctx, roomInfo.RoomVersion, joinEventNIDs)
return evs, visibility, err return evs, visibility, err
} }

View file

@ -269,7 +269,10 @@ func buildInviteStrippedState(
for _, stateNID := range stateEntries { for _, stateNID := range stateEntries {
stateNIDs = append(stateNIDs, stateNID.EventNID) stateNIDs = append(stateNIDs, stateNID.EventNID)
} }
stateEvents, err := db.Events(ctx, info, stateNIDs) if info == nil {
return nil, types.ErrorInvalidRoomInfo
}
stateEvents, err := db.Events(ctx, info.RoomVersion, stateNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -212,7 +212,7 @@ func (r *Queryer) QueryMembershipForUser(
response.IsInRoom = stillInRoom response.IsInRoom = stillInRoom
response.HasBeenInRoom = true response.HasBeenInRoom = true
evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID}) evs, err := r.DB.Events(ctx, info.RoomVersion, []types.EventNID{membershipEventNID})
if err != nil { if err != nil {
return err return err
} }
@ -344,7 +344,7 @@ func (r *Queryer) QueryMembershipsForRoom(
} }
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
} }
events, err = r.DB.Events(ctx, info, eventNIDs) events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.Events: %w", err) return fmt.Errorf("r.DB.Events: %w", err)
} }
@ -383,7 +383,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return err return err
} }
events, err = r.DB.Events(ctx, info, eventNIDs) events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
} else { } else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil { if err != nil {
@ -967,7 +967,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// For each of the joined users, let's see if we can get a valid // For each of the joined users, let's see if we can get a valid
// membership event. // membership event.
for _, joinNID := range joinNIDs { for _, joinNID := range joinNIDs {
events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID}) events, err := r.DB.Events(ctx, roomInfo.RoomVersion, []types.EventNID{joinNID})
if err != nil || len(events) != 1 { if err != nil || len(events) != 1 {
continue continue
} }

View file

@ -571,7 +571,7 @@ func TestRedaction(t *testing.T) {
if ev.Type() == spec.MRoomRedaction { if ev.Type() == spec.MRoomRedaction {
nids, err := db.EventNIDs(ctx, []string{ev.Redacts()}) nids, err := db.EventNIDs(ctx, []string{ev.Redacts()})
assert.NoError(t, err) assert.NoError(t, err)
evs, err := db.Events(ctx, roomInfo, []types.EventNID{nids[ev.Redacts()].EventNID}) evs, err := db.Events(ctx, roomInfo.RoomVersion, []types.EventNID{nids[ev.Redacts()].EventNID})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(evs)) assert.Equal(t, 1, len(evs))
assert.Equal(t, tc.wantRedacted, evs[0].Redacted()) assert.Equal(t, tc.wantRedacted, evs[0].Redacted())

View file

@ -41,7 +41,7 @@ type StateResolutionStorage interface {
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
} }
@ -85,7 +85,10 @@ func (p *StateResolution) Resolve(ctx context.Context, eventID string) (*gomatri
return nil, fmt.Errorf("unable to find power level event") return nil, fmt.Errorf("unable to find power level event")
} }
events, err := p.db.Events(ctx, p.roomInfo, []types.EventNID{plNID}) if p.roomInfo == nil {
return nil, types.ErrorInvalidRoomInfo
}
events, err := p.db.Events(ctx, p.roomInfo.RoomVersion, []types.EventNID{plNID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1134,7 +1137,11 @@ func (v *StateResolution) loadStateEvents(
eventNIDs = append(eventNIDs, entry.EventNID) eventNIDs = append(eventNIDs, entry.EventNID)
} }
} }
events, err := v.db.Events(ctx, v.roomInfo, eventNIDs)
if v.roomInfo == nil {
return nil, nil, types.ErrorInvalidRoomInfo
}
events, err := v.db.Events(ctx, v.roomInfo.RoomVersion, eventNIDs)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -72,7 +72,7 @@ type Database interface {
) ([]types.StateEntryList, error) ) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs. // Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events. // Returns a sorted list of events.
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string // Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
@ -224,7 +224,7 @@ type EventDatabase interface {
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
// MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error // MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error
// (nil if there was nothing to do) // (nil if there was nothing to do)
MaybeRedactEvent( MaybeRedactEvent(

View file

@ -116,8 +116,11 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent
}) })
} }
func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { func (u *RoomUpdater) Events(ctx context.Context, _ gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) {
return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs) if u.roomInfo == nil {
return nil, types.ErrorInvalidRoomInfo
}
return u.d.events(ctx, u.txn, u.roomInfo.RoomVersion, eventNIDs)
} }
func (u *RoomUpdater) SnapshotNIDFromEventID( func (u *RoomUpdater) SnapshotNIDFromEventID(

View file

@ -392,7 +392,10 @@ func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo
nids = append(nids, nid.EventNID) nids = append(nids, nid.EventNID)
} }
return d.events(ctx, txn, roomInfo, nids) if roomInfo == nil {
return nil, types.ErrorInvalidRoomInfo
}
return d.events(ctx, txn, roomInfo.RoomVersion, nids)
} }
func (d *Database) LatestEventIDs( func (d *Database) LatestEventIDs(
@ -531,17 +534,13 @@ func (d *Database) GetInvitesForUser(
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
} }
func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { func (d *EventDatabase) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) {
return d.events(ctx, nil, roomInfo, eventNIDs) return d.events(ctx, nil, roomVersion, eventNIDs)
} }
func (d *EventDatabase) events( func (d *EventDatabase) events(
ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs, ctx context.Context, txn *sql.Tx, roomVersion gomatrixserverlib.RoomVersion, inputEventNIDs types.EventNIDs,
) ([]types.Event, error) { ) ([]types.Event, error) {
if roomInfo == nil { // this should never happen
return nil, fmt.Errorf("unable to parse events without roomInfo")
}
sort.Sort(inputEventNIDs) sort.Sort(inputEventNIDs)
events := make(map[types.EventNID]gomatrixserverlib.PDU, len(inputEventNIDs)) events := make(map[types.EventNID]gomatrixserverlib.PDU, len(inputEventNIDs))
eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs)) eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
@ -579,7 +578,7 @@ func (d *EventDatabase) events(
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1107,7 +1106,10 @@ func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo,
if len(nids) == 0 { if len(nids) == 0 {
return nil return nil
} }
evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID}) if roomInfo == nil {
return nil
}
evs, err := d.Events(ctx, roomInfo.RoomVersion, []types.EventNID{nids[eventID].EventNID})
if err != nil { if err != nil {
return nil return nil
} }

View file

@ -17,6 +17,7 @@ package types
import ( import (
"encoding/json" "encoding/json"
"fmt"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -328,3 +329,5 @@ func (r *RoomInfo) CopyFrom(r2 *RoomInfo) {
r.stateSnapshotNID = r2.stateSnapshotNID r.stateSnapshotNID = r2.stateSnapshotNID
r.isStub = r2.isStub r.isStub = r2.isStub
} }
var ErrorInvalidRoomInfo = fmt.Errorf("room info is invalid")