Only require room version instead of room info for db.Events() ()

This reduces the API requirements for the Events database to align with
what is actually required.
This commit is contained in:
devonh 2023-05-08 19:25:44 +00:00 committed by GitHub
parent 2b34f88fde
commit a49c9f01e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 74 additions and 36 deletions

View file

@ -91,7 +91,7 @@ func main() {
}
var eventEntries []types.Event
eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs)
eventEntries, err = roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
panic(err)
}
@ -149,7 +149,7 @@ func main() {
}
fmt.Println("Fetching", len(eventNIDMap), "state events")
eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs)
eventEntries, err := roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
panic(err)
}

View file

@ -219,7 +219,12 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID)
}
}
if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil {
if roomInfo == nil {
err = types.ErrorInvalidRoomInfo
return
}
if result.events, err = db.Events(ctx, roomInfo.RoomVersion, eventNIDs); err != nil {
return
}
roomID := ""

View file

@ -86,7 +86,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
return false, err
}
events, err := db.Events(ctx, info, eventNIDs)
events, err := db.Events(ctx, info.RoomVersion, eventNIDs)
if err != nil {
return false, err
}
@ -183,7 +183,10 @@ func GetMembershipsAtState(
util.Unique(eventNIDs)
// Get all of the events in this state
stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if roomInfo == nil {
return nil, types.ErrorInvalidRoomInfo
}
stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
return nil, err
}
@ -235,7 +238,10 @@ func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types
func LoadEvents(
ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.PDU, error) {
stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if roomInfo == nil {
return nil, types.ErrorInvalidRoomInfo
}
stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
return nil, err
}

View file

@ -805,7 +805,10 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
return err
}
memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs)
if roomInfo == nil {
return types.ErrorInvalidRoomInfo
}
memberEvents, err := r.DB.Events(ctx, roomInfo.RoomVersion, membershipNIDs)
if err != nil {
return err
}

View file

@ -55,7 +55,7 @@ func (r *Inputer) updateMemberships(
// Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON?
events, err := updater.Events(ctx, nil, eventNIDs)
events, err := updater.Events(ctx, "", eventNIDs)
if err != nil {
return nil, err
}

View file

@ -398,7 +398,10 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
for _, entry := range stateEntries {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs)
if t.roomInfo == nil {
return nil
}
stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomVersion, stateEventNIDs)
if err != nil {
t.log.WithError(err).Warnf("failed to load state events locally")
return nil

View file

@ -60,7 +60,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil, err
}
memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs)
memberEvents, err := r.DB.Events(ctx, roomInfo.RoomVersion, memberNIDs)
if err != nil {
return nil, err
}

View file

@ -533,7 +533,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
roomNID = nid.RoomNID
}
}
eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs)
eventsWithNids, err := b.db.Events(ctx, b.roomInfo.RoomVersion, eventNIDs)
if err != nil {
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err
@ -563,7 +563,10 @@ func joinEventsFromHistoryVisibility(
}
// Get all of the events in this state
stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if roomInfo == nil {
return nil, gomatrixserverlib.HistoryVisibilityJoined, types.ErrorInvalidRoomInfo
}
stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
// even though the default should be shared, restricting the visibility to joined
// feels more secure here.
@ -586,7 +589,7 @@ func joinEventsFromHistoryVisibility(
if err != nil {
return nil, visibility, err
}
evs, err := db.Events(ctx, roomInfo, joinEventNIDs)
evs, err := db.Events(ctx, roomInfo.RoomVersion, joinEventNIDs)
return evs, visibility, err
}

View file

@ -269,7 +269,10 @@ func buildInviteStrippedState(
for _, stateNID := range stateEntries {
stateNIDs = append(stateNIDs, stateNID.EventNID)
}
stateEvents, err := db.Events(ctx, info, stateNIDs)
if info == nil {
return nil, types.ErrorInvalidRoomInfo
}
stateEvents, err := db.Events(ctx, info.RoomVersion, stateNIDs)
if err != nil {
return nil, err
}

View file

@ -212,7 +212,7 @@ func (r *Queryer) QueryMembershipForUser(
response.IsInRoom = stillInRoom
response.HasBeenInRoom = true
evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID})
evs, err := r.DB.Events(ctx, info.RoomVersion, []types.EventNID{membershipEventNID})
if err != nil {
return err
}
@ -344,7 +344,7 @@ func (r *Queryer) QueryMembershipsForRoom(
}
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
}
events, err = r.DB.Events(ctx, info, eventNIDs)
events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
if err != nil {
return fmt.Errorf("r.DB.Events: %w", err)
}
@ -383,7 +383,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return err
}
events, err = r.DB.Events(ctx, info, eventNIDs)
events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
} else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil {
@ -967,7 +967,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// For each of the joined users, let's see if we can get a valid
// membership event.
for _, joinNID := range joinNIDs {
events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID})
events, err := r.DB.Events(ctx, roomInfo.RoomVersion, []types.EventNID{joinNID})
if err != nil || len(events) != 1 {
continue
}

View file

@ -571,7 +571,7 @@ func TestRedaction(t *testing.T) {
if ev.Type() == spec.MRoomRedaction {
nids, err := db.EventNIDs(ctx, []string{ev.Redacts()})
assert.NoError(t, err)
evs, err := db.Events(ctx, roomInfo, []types.EventNID{nids[ev.Redacts()].EventNID})
evs, err := db.Events(ctx, roomInfo.RoomVersion, []types.EventNID{nids[ev.Redacts()].EventNID})
assert.NoError(t, err)
assert.Equal(t, 1, len(evs))
assert.Equal(t, tc.wantRedacted, evs[0].Redacted())

View file

@ -41,7 +41,7 @@ type StateResolutionStorage interface {
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
}
@ -85,7 +85,10 @@ func (p *StateResolution) Resolve(ctx context.Context, eventID string) (*gomatri
return nil, fmt.Errorf("unable to find power level event")
}
events, err := p.db.Events(ctx, p.roomInfo, []types.EventNID{plNID})
if p.roomInfo == nil {
return nil, types.ErrorInvalidRoomInfo
}
events, err := p.db.Events(ctx, p.roomInfo.RoomVersion, []types.EventNID{plNID})
if err != nil {
return nil, err
}
@ -1134,7 +1137,11 @@ func (v *StateResolution) loadStateEvents(
eventNIDs = append(eventNIDs, entry.EventNID)
}
}
events, err := v.db.Events(ctx, v.roomInfo, eventNIDs)
if v.roomInfo == nil {
return nil, nil, types.ErrorInvalidRoomInfo
}
events, err := v.db.Events(ctx, v.roomInfo.RoomVersion, eventNIDs)
if err != nil {
return nil, nil, err
}

View file

@ -72,7 +72,7 @@ type Database interface {
) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events.
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
@ -224,7 +224,7 @@ type EventDatabase interface {
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
// MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error
// (nil if there was nothing to do)
MaybeRedactEvent(

View file

@ -116,8 +116,11 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent
})
}
func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs)
func (u *RoomUpdater) Events(ctx context.Context, _ gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) {
if u.roomInfo == nil {
return nil, types.ErrorInvalidRoomInfo
}
return u.d.events(ctx, u.txn, u.roomInfo.RoomVersion, eventNIDs)
}
func (u *RoomUpdater) SnapshotNIDFromEventID(

View file

@ -392,7 +392,10 @@ func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo
nids = append(nids, nid.EventNID)
}
return d.events(ctx, txn, roomInfo, nids)
if roomInfo == nil {
return nil, types.ErrorInvalidRoomInfo
}
return d.events(ctx, txn, roomInfo.RoomVersion, nids)
}
func (d *Database) LatestEventIDs(
@ -531,17 +534,13 @@ func (d *Database) GetInvitesForUser(
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
}
func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
return d.events(ctx, nil, roomInfo, eventNIDs)
func (d *EventDatabase) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) {
return d.events(ctx, nil, roomVersion, eventNIDs)
}
func (d *EventDatabase) events(
ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs,
ctx context.Context, txn *sql.Tx, roomVersion gomatrixserverlib.RoomVersion, inputEventNIDs types.EventNIDs,
) ([]types.Event, error) {
if roomInfo == nil { // this should never happen
return nil, fmt.Errorf("unable to parse events without roomInfo")
}
sort.Sort(inputEventNIDs)
events := make(map[types.EventNID]gomatrixserverlib.PDU, len(inputEventNIDs))
eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
@ -579,7 +578,7 @@ func (d *EventDatabase) events(
eventIDs = map[types.EventNID]string{}
}
verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion)
verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion)
if err != nil {
return nil, err
}
@ -1107,7 +1106,10 @@ func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo,
if len(nids) == 0 {
return nil
}
evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID})
if roomInfo == nil {
return nil
}
evs, err := d.Events(ctx, roomInfo.RoomVersion, []types.EventNID{nids[eventID].EventNID})
if err != nil {
return nil
}

View file

@ -17,6 +17,7 @@ package types
import (
"encoding/json"
"fmt"
"sort"
"strings"
"sync"
@ -328,3 +329,5 @@ func (r *RoomInfo) CopyFrom(r2 *RoomInfo) {
r.stateSnapshotNID = r2.stateSnapshotNID
r.isStub = r2.isStub
}
var ErrorInvalidRoomInfo = fmt.Errorf("room info is invalid")