diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index 30950f528..f9c295644 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -193,19 +193,6 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer( return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err) } - // It's possible that the remote server has included our new - // membership event in the room state in the send_join response, - // but if that's the case, then we'll get a duplicate state block - // error if we try to send that along with our own copy of the - // event. The simple way around this is just to prune the event - // from the state if we find it. - for i, ev := range respState.StateEvents { - if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(userID) { - respState.StateEvents = append(respState.StateEvents[i:], respState.StateEvents[:i+1]...) - break - } - } - // If we successfully performed a send_join above then the other // server now thinks we're a part of the room. Send the newly // returned state to the roomserver to update our local view. diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 060f0a0e9..524a54510 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -36,7 +36,7 @@ func CheckAuthEvents( if err != nil { return nil, err } - // TODO: check for duplicate state keys here. + authStateEntries = types.DeduplicateStateEntries(authStateEntries) // Work out which of the state events we actually need. stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()}) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 19089cbd3..ec429b2dd 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -170,6 +170,7 @@ func (r *Inputer) calculateAndSetState( if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) } + entries = types.DeduplicateStateEntries(entries) if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { return fmt.Errorf("r.DB.AddState: %w", err) diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 60f4b0fd5..f573ce3e0 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -16,6 +16,8 @@ package types import ( + "sort" + "github.com/matrix-org/gomatrixserverlib" ) @@ -72,6 +74,24 @@ func (a StateEntry) LessThan(b StateEntry) bool { return a.EventNID < b.EventNID } +// Deduplicate ensures that the latest NIDs are always presented in the case of duplicates. +func DeduplicateStateEntries(a []StateEntry) []StateEntry { + result := a + if len(a) < 2 { + return a + } + sort.SliceStable(a, func(i, j int) bool { + return a[i].LessThan(a[j]) + }) + for i := 0; i < len(result)-1; i++ { + if result[i].StateKeyTuple == result[i+1].StateKeyTuple { + result = append(result[:i], result[i+1:]...) + i-- + } + } + return result +} + // StateAtEvent is the state before and after a matrix event. type StateAtEvent struct { // Should this state overwrite the latest events and memberships of the room? diff --git a/roomserver/types/types_test.go b/roomserver/types/types_test.go new file mode 100644 index 000000000..b1e84b821 --- /dev/null +++ b/roomserver/types/types_test.go @@ -0,0 +1,26 @@ +package types + +import ( + "testing" +) + +func TestDeduplicateStateEntries(t *testing.T) { + entries := []StateEntry{ + {StateKeyTuple{1, 1}, 1}, + {StateKeyTuple{1, 1}, 2}, + {StateKeyTuple{1, 1}, 3}, + {StateKeyTuple{2, 2}, 4}, + {StateKeyTuple{2, 3}, 5}, + {StateKeyTuple{3, 3}, 6}, + } + expected := []EventNID{3, 4, 5, 6} + entries = DeduplicateStateEntries(entries) + if len(entries) != 4 { + t.Fatalf("Expected 4 entries, got %d entries", len(entries)) + } + for i, v := range entries { + if v.EventNID != expected[i] { + t.Fatalf("Expected position %d to be %d but got %d", i, expected[i], v.EventNID) + } + } +}