From 1d6b620081abcd98f468e69ca9cb2eb3d3610d4e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 9 Feb 2022 12:00:38 +0000 Subject: [PATCH] Review comments --- roomserver/internal/input/input_events.go | 8 ++++---- roomserver/internal/query/query.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 785a74840..4de33db46 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -255,20 +255,20 @@ func (r *Inputer) processRoomEvent( hadEvents: map[string]bool{}, haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, } - if override, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { // Something went wrong with retrieving the missing state, so we can't // really do anything with the event other than reject it at this point. isRejected = true rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) - } else if override != nil { + } else if stateSnapshot != nil { // We retrieved some state and we ended up having to call /state_ids for // the new event in question (probably because closing the gap by using // /get_missing_events didn't do what we hoped) so we'll instead overwrite // the state snapshot with the newly resolved state. missingPrev = false input.HasState = true - input.StateEventIDs = make([]string, 0, len(override.StateEvents)) - for _, e := range override.StateEvents { + input.StateEventIDs = make([]string, 0, len(stateSnapshot.StateEvents)) + for _, e := range stateSnapshot.StateEvents { input.StateEventIDs = append(input.StateEventIDs, e.EventID()) } } else { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 112346472..fa7196fdd 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -150,7 +150,7 @@ func (r *Queryer) QueryMissingAuthPrevEvents( for _, prevEventID := range request.PrevEventIDs { state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}) - if err != nil || len(state) == 0 || state[0].BeforeStateSnapshotNID == 0 { + if err != nil || len(state) == 0 || (state[0].EventTypeNID != types.MRoomCreateNID && state[0].EventStateKeyNID == types.EmptyStateKeyNID && state[0].BeforeStateSnapshotNID == 0) { response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) } }