Deduplicate state entries

This commit is contained in:
Neil Alexander 2020-09-09 15:55:01 +01:00
parent 746e105888
commit a12a36078a
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
5 changed files with 48 additions and 14 deletions

View file

@ -193,19 +193,6 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err) return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err)
} }
// It's possible that the remote server has included our new
// membership event in the room state in the send_join response,
// but if that's the case, then we'll get a duplicate state block
// error if we try to send that along with our own copy of the
// event. The simple way around this is just to prune the event
// from the state if we find it.
for i, ev := range respState.StateEvents {
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(userID) {
respState.StateEvents = append(respState.StateEvents[i:], respState.StateEvents[:i+1]...)
break
}
}
// If we successfully performed a send_join above then the other // If we successfully performed a send_join above then the other
// server now thinks we're a part of the room. Send the newly // server now thinks we're a part of the room. Send the newly
// returned state to the roomserver to update our local view. // returned state to the roomserver to update our local view.

View file

@ -36,7 +36,7 @@ func CheckAuthEvents(
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: check for duplicate state keys here. authStateEntries = types.DeduplicateStateEntries(authStateEntries)
// Work out which of the state events we actually need. // Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})

View file

@ -170,6 +170,7 @@ func (r *Inputer) calculateAndSetState(
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err)
} }
entries = types.DeduplicateStateEntries(entries)
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
return fmt.Errorf("r.DB.AddState: %w", err) return fmt.Errorf("r.DB.AddState: %w", err)

View file

@ -16,6 +16,8 @@
package types package types
import ( import (
"sort"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -72,6 +74,24 @@ func (a StateEntry) LessThan(b StateEntry) bool {
return a.EventNID < b.EventNID return a.EventNID < b.EventNID
} }
// Deduplicate ensures that the latest NIDs are always presented in the case of duplicates.
func DeduplicateStateEntries(a []StateEntry) []StateEntry {
result := a
if len(a) < 2 {
return a
}
sort.SliceStable(a, func(i, j int) bool {
return a[i].LessThan(a[j])
})
for i := 0; i < len(result)-1; i++ {
if result[i].StateKeyTuple == result[i+1].StateKeyTuple {
result = append(result[:i], result[i+1:]...)
i--
}
}
return result
}
// StateAtEvent is the state before and after a matrix event. // StateAtEvent is the state before and after a matrix event.
type StateAtEvent struct { type StateAtEvent struct {
// Should this state overwrite the latest events and memberships of the room? // Should this state overwrite the latest events and memberships of the room?

View file

@ -0,0 +1,26 @@
package types
import (
"testing"
)
func TestDeduplicateStateEntries(t *testing.T) {
entries := []StateEntry{
{StateKeyTuple{1, 1}, 1},
{StateKeyTuple{1, 1}, 2},
{StateKeyTuple{1, 1}, 3},
{StateKeyTuple{2, 2}, 4},
{StateKeyTuple{2, 3}, 5},
{StateKeyTuple{3, 3}, 6},
}
expected := []EventNID{3, 4, 5, 6}
entries = DeduplicateStateEntries(entries)
if len(entries) != 4 {
t.Fatalf("Expected 4 entries, got %d entries", len(entries))
}
for i, v := range entries {
if v.EventNID != expected[i] {
t.Fatalf("Expected position %d to be %d but got %d", i, expected[i], v.EventNID)
}
}
}