From b8a97b6ee083d1047d09f99daf2594d0821705ca Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 23 Feb 2022 10:45:07 +0000 Subject: [PATCH 01/24] Update to matrix-org/pinecone@0f0afd1a46aabf7b1e48cb607bc9fa6a083aade6 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 2316096df..b104b4192 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed - github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf + github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 github.com/morikuni/aec v1.0.0 // indirect diff --git a/go.sum b/go.sum index e79015e51..23ffe8b85 100644 --- a/go.sum +++ b/go.sum @@ -985,8 +985,8 @@ github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed h1:R8EiLWArq7KT96DrUq1xq9scPh8vLwKKeCTnORPyjhU= github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= -github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= -github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= +github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= +github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= From fea8d152e7d81a9cbd387108d263950d8ec6ddc7 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 23 Feb 2022 15:41:32 +0000 Subject: [PATCH 02/24] Relax roomserver input transactional isolation (#2224) * Don't force full transactional isolation on roomserver input * Set succeeded * Tweak `MissingAuthPrevEvents` --- roomserver/internal/helpers/auth.go | 11 +-- roomserver/internal/input/input.go | 57 +----------- roomserver/internal/input/input_events.go | 89 ++++++++++--------- .../internal/input/input_latest_events.go | 11 ++- roomserver/internal/input/input_missing.go | 20 +++-- roomserver/storage/interface.go | 5 ++ roomserver/storage/shared/room_updater.go | 54 ----------- roomserver/storage/shared/storage.go | 23 +++++ 8 files changed, 105 insertions(+), 165 deletions(-) diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 9af0bf591..0229f822f 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -20,22 +20,17 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) -type checkForAuthAndSoftFailStorage interface { - state.StateResolutionStorage - StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) - RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) -} - // CheckForSoftFail returns true if the event should be soft-failed // and false otherwise. The return error value should be checked before // the soft-fail bool. func CheckForSoftFail( ctx context.Context, - db checkForAuthAndSoftFailStorage, + db storage.Database, event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { @@ -97,7 +92,7 @@ func CheckForSoftFail( // Returns the numeric IDs for the auth events. func CheckAuthEvents( ctx context.Context, - db checkForAuthAndSoftFailStorage, + db storage.Database, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 22e4b67a0..178533ded 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,7 +19,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "sync" "time" @@ -40,19 +39,6 @@ import ( "github.com/tidwall/gjson" ) -type retryAction int -type commitAction int - -const ( - doNotRetry retryAction = iota - retryLater -) - -const ( - commitTransaction commitAction = iota - rollbackTransaction -) - var keyContentFields = map[string]string{ "m.room.join_rules": "join_rule", "m.room.history_visibility": "history_visibility", @@ -117,8 +103,7 @@ func (r *Inputer) Start() error { _ = msg.InProgress() // resets the acknowledgement wait timer defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent) - if err != nil { + if err := r.processRoomEvent(r.ProcessContext.Context(), &inputRoomEvent); err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } @@ -127,11 +112,8 @@ func (r *Inputer) Start() error { "event_id": inputRoomEvent.Event.EventID(), "type": inputRoomEvent.Event.Type(), }).Warn("Roomserver failed to process async event") - } - switch action { - case retryLater: - _ = msg.Nak() - case doNotRetry: + _ = msg.Term() + } else { _ = msg.Ack() } }) @@ -153,37 +135,6 @@ func (r *Inputer) Start() error { return err } -// processRoomEventUsingUpdater opens up a room updater and tries to -// process the event. It returns whether or not we should positively -// or negatively acknowledge the event (i.e. for NATS) and an error -// if it occurred. -func (r *Inputer) processRoomEventUsingUpdater( - ctx context.Context, - roomID string, - inputRoomEvent *api.InputRoomEvent, -) (retryAction, error) { - roomInfo, err := r.DB.RoomInfo(ctx, roomID) - if err != nil { - return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err) - } - updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) - if err != nil { - return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err) - } - action, err := r.processRoomEvent(ctx, updater, inputRoomEvent) - switch action { - case commitTransaction: - if cerr := updater.Commit(); cerr != nil { - return retryLater, fmt.Errorf("updater.Commit: %w", cerr) - } - case rollbackTransaction: - if rerr := updater.Rollback(); rerr != nil { - return retryLater, fmt.Errorf("updater.Rollback: %w", rerr) - } - } - return doNotRetry, err -} - // InputRoomEvents implements api.RoomserverInternalAPI func (r *Inputer) InputRoomEvents( ctx context.Context, @@ -230,7 +181,7 @@ func (r *Inputer) InputRoomEvents( worker.Act(nil, func() { defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - _, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent) + err := r.processRoomEvent(ctx, &inputRoomEvent) if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 4e151699e..531d6959e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -26,10 +26,10 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/hooks" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -68,15 +68,14 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, - updater *shared.RoomUpdater, input *api.InputRoomEvent, -) (commitAction, error) { +) error { select { case <-ctx.Done(): // Before we do anything, make sure the context hasn't expired for this pending task. // If it has then we'll give up straight away — it's probably a synchronous input // request and the caller has already given up, but the inbox task was still queued. - return rollbackTransaction, context.DeadlineExceeded + return context.DeadlineExceeded default: } @@ -109,7 +108,7 @@ func (r *Inputer) processRoomEvent( // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. if input.Kind == api.KindOutlier { - evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()}) + evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) if err2 == nil && len(evs) == 1 { // check hash matches if we're on early room versions where the event ID was a random string idFormat, err2 := headered.RoomVersion.EventIDFormat() @@ -118,11 +117,11 @@ func (r *Inputer) processRoomEvent( case gomatrixserverlib.EventIDFormatV1: if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) { logger.Debugf("Already processed event; ignoring") - return rollbackTransaction, nil + return nil } default: logger.Debugf("Already processed event; ignoring") - return rollbackTransaction, nil + return nil } } } @@ -131,17 +130,21 @@ func (r *Inputer) processRoomEvent( // Don't waste time processing the event if the room doesn't exist. // A room entry locally will only be created in response to a create // event. + roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID()) + if rerr != nil { + return fmt.Errorf("r.DB.RoomInfo: %w", rerr) + } isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") - if !updater.RoomExists() && !isCreateEvent { - return rollbackTransaction, fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) + if roomInfo == nil && !isCreateEvent { + return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } var missingAuth, missingPrev bool serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} if !isCreateEvent { - missingAuthIDs, missingPrevIDs, err := updater.MissingAuthPrevEvents(ctx, event) + missingAuthIDs, missingPrevIDs, err := r.DB.MissingAuthPrevEvents(ctx, event) if err != nil { - return rollbackTransaction, fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) + return fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) } missingAuth = len(missingAuthIDs) > 0 missingPrev = !input.HasState && len(missingPrevIDs) > 0 @@ -153,7 +156,7 @@ func (r *Inputer) processRoomEvent( ExcludeSelf: true, } if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { - return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) } // Sort all of the servers into a map so that we can randomise // their order. Then make sure that the input origin and the @@ -182,8 +185,8 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err) + if err := r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return fmt.Errorf("r.fetchAuthEvents: %w", err) } // Check if the event is allowed by its auth events. If it isn't then @@ -205,12 +208,12 @@ func (r *Inputer) processRoomEvent( // but weren't found. if isRejected { if event.StateKey() != nil { - return commitTransaction, fmt.Errorf( + return fmt.Errorf( "missing auth event %s for state event %s (type %q, state key %q)", authEventID, event.EventID(), event.Type(), *event.StateKey(), ) } else { - return commitTransaction, fmt.Errorf( + return fmt.Errorf( "missing auth event %s for timeline event %s (type %q)", authEventID, event.EventID(), event.Type(), ) @@ -226,7 +229,7 @@ func (r *Inputer) processRoomEvent( // Check that the event passes authentication checks based on the // current room state. var err error - softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs) + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) if err != nil { logger.WithError(err).Warn("Error authing soft-failed event") } @@ -250,7 +253,8 @@ func (r *Inputer) processRoomEvent( missingState := missingStateReq{ origin: input.Origin, inputer: r, - db: updater, + db: r.DB, + roomInfo: roomInfo, federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), @@ -290,16 +294,16 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected) + _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) if err != nil { - return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err) + return fmt.Errorf("updater.StoreEvent: %w", err) } // if storing this event results in it being redacted then do so. if !isRejected && redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, event) if rerr != nil { - return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr) + return fmt.Errorf("eventutil.RedactEvent: %w", rerr) } event = r } @@ -310,23 +314,25 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindOutlier { logger.Debug("Stored outlier") hooks.Run(hooks.KindNewEventPersisted, headered) - return commitTransaction, nil + return nil } - roomInfo, err := updater.RoomInfo(ctx, event.RoomID()) + // Request the room info again — it's possible that the room has been + // created by now if it didn't exist already. + roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID()) if err != nil { - return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err) + return fmt.Errorf("updater.RoomInfo: %w", err) } if roomInfo == nil { - return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) + return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) } if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected) + err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) if err != nil { - return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err) + return fmt.Errorf("r.calculateAndSetState: %w", err) } } @@ -337,16 +343,15 @@ func (r *Inputer) processRoomEvent( "missing_prev": missingPrev, }).Warn("Stored rejected event") if rejectionErr != nil { - return commitTransaction, types.RejectedError(rejectionErr.Error()) + return types.RejectedError(rejectionErr.Error()) } - return commitTransaction, nil + return nil } switch input.Kind { case api.KindNew: if err = r.updateLatestEvents( ctx, // context - updater, // room updater roomInfo, // room info for the room being updated stateAtEvent, // state at event (below) event, // event @@ -354,7 +359,7 @@ func (r *Inputer) processRoomEvent( input.TransactionID, // transaction ID input.HasState, // rewrites state? ); err != nil { - return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err) + return fmt.Errorf("r.updateLatestEvents: %w", err) } case api.KindOld: err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{ @@ -366,7 +371,7 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err) + return fmt.Errorf("r.WriteOutputEvents (old): %w", err) } } @@ -385,14 +390,14 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) + return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) } } // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) - return commitTransaction, nil + return nil } // fetchAuthEvents will check to see if any of the @@ -404,7 +409,6 @@ func (r *Inputer) processRoomEvent( // they are now in the database. func (r *Inputer) fetchAuthEvents( ctx context.Context, - updater *shared.RoomUpdater, logger *logrus.Entry, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, @@ -418,7 +422,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID}) + authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -495,7 +499,7 @@ nextAuthEvent: } // Finally, store the event in the database. - eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -520,14 +524,18 @@ nextAuthEvent: func (r *Inputer) calculateAndSetState( ctx context.Context, - updater *shared.RoomUpdater, input *api.InputRoomEvent, roomInfo *types.RoomInfo, stateAtEvent *types.StateAtEvent, event *gomatrixserverlib.Event, isRejected bool, ) error { - var err error + var succeeded bool + updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return fmt.Errorf("r.DB.GetRoomUpdater: %w", err) + } + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) roomState := state.NewStateResolution(updater, roomInfo) if input.HasState { @@ -536,7 +544,7 @@ func (r *Inputer) calculateAndSetState( // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err) } entries = types.DeduplicateStateEntries(entries) @@ -557,5 +565,6 @@ func (r *Inputer) calculateAndSetState( if err != nil { return fmt.Errorf("r.DB.SetState: %w", err) } + succeeded = true return nil } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index ae28ebefa..f4a52031a 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -20,6 +20,7 @@ import ( "context" "fmt" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -47,7 +48,6 @@ import ( // Can only be called once at a time func (r *Inputer) updateLatestEvents( ctx context.Context, - updater *shared.RoomUpdater, roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, event *gomatrixserverlib.Event, @@ -55,6 +55,14 @@ func (r *Inputer) updateLatestEvents( transactionID *api.TransactionID, rewritesState bool, ) (err error) { + var succeeded bool + updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return fmt.Errorf("r.DB.GetRoomUpdater: %w", err) + } + + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) + u := latestEventsUpdater{ ctx: ctx, api: r, @@ -71,6 +79,7 @@ func (r *Inputer) updateLatestEvents( return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } + succeeded = true return } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index fc3be7987..4655e92a9 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -25,7 +25,8 @@ type parsedRespState struct { type missingStateReq struct { origin gomatrixserverlib.ServerName - db *shared.RoomUpdater + db storage.Database + roomInfo *types.RoomInfo inputer *Inputer keys gomatrixserverlib.JSONVerifier federation fedapi.FederationInternalAPI @@ -80,7 +81,7 @@ func (t *missingStateReq) processEventWithMissingState( // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -139,8 +140,7 @@ func (t *missingStateReq) processEventWithMissingState( }) } for _, ire := range outlierRoomEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &ire) - if err != nil { + if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { if _, ok := err.(types.RejectedError); !ok { return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) } @@ -163,7 +163,7 @@ func (t *missingStateReq) processEventWithMissingState( stateIDs = append(stateIDs, event.EventID()) } - _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ Kind: api.KindOld, Event: backwardsExtremity.Headered(roomVersion), Origin: t.origin, @@ -182,7 +182,7 @@ func (t *missingStateReq) processEventWithMissingState( // they will automatically fast-forward based on the room state at the // extremity in the last step. for _, newEvent := range newEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -473,8 +473,10 @@ retryAllowedState: // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - - latest := t.db.LatestEvents() + latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) + if err != nil { + return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err) + } latestEvents := make([]string, len(latest)) for i, ev := range latest { latestEvents[i] = ev.EventID diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index a9851e05b..685505d52 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -35,6 +35,11 @@ type Database interface { stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (types.StateSnapshotNID, error) + + MissingAuthPrevEvents( + ctx context.Context, e *gomatrixserverlib.Event, + ) (missingAuth, missingPrev []string, err error) + // Look up the state of a room at each event for a list of string event IDs. // Returns an error if there is an error talking to the database. // The length of []types.StateAtEvent is guaranteed to equal the length of eventIDs if no error is returned. diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 810a18ef2..d4a2ee3b9 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -103,25 +103,6 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { return u.currentStateSnapshotNID } -func (u *RoomUpdater) MissingAuthPrevEvents( - ctx context.Context, e *gomatrixserverlib.Event, -) (missingAuth, missingPrev []string, err error) { - for _, authEventID := range e.AuthEventIDs() { - if nids, err := u.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 { - missingAuth = append(missingAuth, authEventID) - } - } - - for _, prevEventID := range e.PrevEventIDs() { - state, err := u.StateAtEventIDs(ctx, []string{prevEventID}) - if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { - missingPrev = append(missingPrev, prevEventID) - } - } - - return -} - // StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { @@ -146,13 +127,6 @@ func (u *RoomUpdater) SnapshotNIDFromEventID( return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID) } -func (u *RoomUpdater) StoreEvent( - ctx context.Context, event *gomatrixserverlib.Event, - authEventNIDs []types.EventNID, isRejected bool, -) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { - return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected) -} - func (u *RoomUpdater) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { @@ -212,44 +186,16 @@ func (u *RoomUpdater) EventIDs( return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) } -func (u *RoomUpdater) EventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter) -} - -func (u *RoomUpdater) UnsentEventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly) -} - func (u *RoomUpdater) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) } -func (u *RoomUpdater) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateEntry, error) { - return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs) -} - -func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) -} - func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) } -func (u *RoomUpdater) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, -) ([]types.EventNID, error) { - return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly) -} - // IsReferenced implements types.RoomRecentEventsUpdater func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e270e121c..6e84b2832 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -674,6 +674,29 @@ func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) } +func (d *Database) MissingAuthPrevEvents( + ctx context.Context, e *gomatrixserverlib.Event, +) (missingAuth, missingPrev []string, err error) { + authEventNIDs, err := d.EventNIDs(ctx, e.AuthEventIDs()) + if err != nil { + return nil, nil, fmt.Errorf("d.EventNIDs: %w", err) + } + for _, authEventID := range e.AuthEventIDs() { + if _, ok := authEventNIDs[authEventID]; !ok { + missingAuth = append(missingAuth, authEventID) + } + } + + for _, prevEventID := range e.PrevEventIDs() { + state, err := d.StateAtEventIDs(ctx, []string{prevEventID}) + if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { + missingPrev = append(missingPrev, prevEventID) + } + } + + return +} + func (d *Database) assignRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, From 4b01f1cd122ed155ff8b05aa848f4daf2aa60c34 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 24 Feb 2022 11:09:01 +0000 Subject: [PATCH 03/24] State resolution v2 micro-optimisations (#2226) * Don't populate duplicates into auth events * Only append the single event * Potentially reduce number of iterations in `isInAllAuthLists --- roomserver/state/state.go | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index e5f69521e..187b996cd 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -814,6 +814,7 @@ func (v *StateResolution) resolveConflictsV2( // events may be duplicated across these sets but that's OK. authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted)) authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) + gotAuthEvents := make(map[string]struct{}, estimate*3) authDifference := make([]*gomatrixserverlib.Event, 0, estimate) // For each conflicted event, let's try and get the needed auth events. @@ -850,9 +851,22 @@ func (v *StateResolution) resolveConflictsV2( if err != nil { return nil, err } - authEvents = append(authEvents, authSets[key]...) + + // Only add auth events into the authEvents slice once, otherwise the + // check for the auth difference can become expensive and produce + // duplicate entries, which just waste memory and CPU time. + for _, event := range authSets[key] { + if _, ok := gotAuthEvents[event.EventID()]; !ok { + authEvents = append(authEvents, event) + gotAuthEvents[event.EventID()] = struct{}{} + } + } } + // Kill the reference to this so that the GC may pick it up, since we no + // longer need this after this point. + gotAuthEvents = nil // nolint:ineffassign + // This function helps us to work out whether an event exists in one of the // auth sets. isInAuthList := func(k string, event *gomatrixserverlib.Event) bool { @@ -866,11 +880,12 @@ func (v *StateResolution) resolveConflictsV2( // This function works out if an event exists in all of the auth sets. isInAllAuthLists := func(event *gomatrixserverlib.Event) bool { - found := true for k := range authSets { - found = found && isInAuthList(k, event) + if !isInAuthList(k, event) { + return false + } } - return found + return true } // Look through all of the auth events that we've been given and work out if From 4c07374c42e5d671cfb137634475f43f84f9db0e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 24 Feb 2022 17:05:49 +0000 Subject: [PATCH 04/24] Reduce allocations significantly in state res v2, which should help to keep memory down when joining rooms too (update to matrix-org/gomatrixserverlib@f6ab9c5) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index b104b4192..a416ec98f 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed + github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index 23ffe8b85..ecb0496bc 100644 --- a/go.sum +++ b/go.sum @@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed h1:R8EiLWArq7KT96DrUq1xq9scPh8vLwKKeCTnORPyjhU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 h1:+4Q/JQ3fGgA7sIHaLMlqREX8yEpsI+HlVoW9WId7SNc= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= From cf27e26712c5aa655377f21a2efe34237c3d681f Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Fri, 25 Feb 2022 14:33:02 +0100 Subject: [PATCH 05/24] Remember parameters on registration (#2225) * Remember parameters for sessions Cleanup sessions on successfully registering or after a while * Add flakey test * Update to use time.AfterFunc, add more tests * Try to drain the channel, if possible --- clientapi/routing/auth_fallback.go | 2 +- clientapi/routing/key_crosssigning.go | 2 +- clientapi/routing/password.go | 2 +- clientapi/routing/register.go | 122 ++++++++++++++++++++------ clientapi/routing/register_test.go | 45 +++++++++- sytest-blacklist | 1 + sytest-whitelist | 4 + 7 files changed, 149 insertions(+), 29 deletions(-) diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index 839ca9e54..abfe830fb 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -162,7 +162,7 @@ func AuthFallback( } // Success. Add recaptcha as a completed login flow - AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) serveSuccess() return nil diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 7ecab9d4e..4426b7fdc 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -70,7 +70,7 @@ func UploadCrossSigningDeviceKeys( if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { return *authErr } - AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 499510193..acac60fa5 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -74,7 +74,7 @@ func Password( if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { return *authErr } - AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. if resErr = validatePassword(r.NewPassword); resErr != nil { diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index d00d9886e..10cfa4325 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -72,14 +72,19 @@ func init() { // sessionsDict keeps track of completed auth stages for each session. // It shouldn't be passed by value because it contains a mutex. type sessionsDict struct { - sync.Mutex + sync.RWMutex sessions map[string][]authtypes.LoginType + params map[string]registerRequest + timer map[string]*time.Timer } -// GetCompletedStages returns the completed stages for a session. -func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { - d.Lock() - defer d.Unlock() +// defaultTimeout is the timeout used to clean up sessions +const defaultTimeOut = time.Minute * 5 + +// getCompletedStages returns the completed stages for a session. +func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType { + d.RLock() + defer d.RUnlock() if completedStages, ok := d.sessions[sessionID]; ok { return completedStages @@ -88,28 +93,79 @@ func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginTyp return make([]authtypes.LoginType, 0) } -func newSessionsDict() *sessionsDict { - return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), +// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest +func (d *sessionsDict) addParams(sessionID string, r registerRequest) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + d.params[sessionID] = r +} + +func (d *sessionsDict) getParams(sessionID string) (registerRequest, bool) { + d.RLock() + defer d.RUnlock() + r, ok := d.params[sessionID] + return r, ok +} + +// deleteSession cleans up a given session, either because the registration completed +// successfully, or because a given timeout (default: 5min) was reached. +func (d *sessionsDict) deleteSession(sessionID string) { + d.Lock() + defer d.Unlock() + delete(d.params, sessionID) + delete(d.sessions, sessionID) + // stop the timer, e.g. because the registration was completed + if t, ok := d.timer[sessionID]; ok { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + delete(d.timer, sessionID) } } -// AddCompletedSessionStage records that a session has completed an auth stage. -func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) { - sessions.Lock() - defer sessions.Unlock() +func newSessionsDict() *sessionsDict { + return &sessionsDict{ + sessions: make(map[string][]authtypes.LoginType), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + } +} - for _, completedStage := range sessions.sessions[sessionID] { +func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) { + d.Lock() + defer d.Unlock() + t, ok := d.timer[sessionID] + if ok { + if !t.Stop() { + <-t.C + } + t.Reset(duration) + return + } + d.timer[sessionID] = time.AfterFunc(duration, func() { + d.deleteSession(sessionID) + }) +} + +// addCompletedSessionStage records that a session has completed an auth stage +// also starts a timer to delete the session once done. +func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + for _, completedStage := range d.sessions[sessionID] { if completedStage == stage { return } } - sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage) + d.sessions[sessionID] = append(sessions.sessions[sessionID], stage) } var ( - // TODO: Remove old sessions. Need to do so on a session-specific timeout. - // sessions stores the completed flow stages for all sessions. Referenced using their sessionID. sessions = newSessionsDict() validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) ) @@ -167,7 +223,7 @@ func newUserInteractiveResponse( params map[string]interface{}, ) userInteractiveResponse { return userInteractiveResponse{ - fs, sessions.GetCompletedStages(sessionID), params, sessionID, + fs, sessions.getCompletedStages(sessionID), params, sessionID, } } @@ -645,12 +701,12 @@ func handleRegistrationFlow( } // Add Recaptcha to the list of completed registration stages - AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) case authtypes.LoginTypeDummy: // there is nothing to do // Add Dummy to the list of completed registration stages - AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) case "": // An empty auth type means that we want to fetch the available @@ -666,7 +722,7 @@ func handleRegistrationFlow( // Check if the user's registration flow has been completed successfully // A response with current registration flow and remaining available methods // will be returned if a flow has not been successfully completed yet - return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID), + return checkAndCompleteFlow(sessions.getCompletedStages(sessionID), req, r, sessionID, cfg, userAPI) } @@ -708,7 +764,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), + req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) } @@ -727,11 +783,11 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), + req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) } - + sessions.addParams(sessionID, r) // There are still more stages to complete. // Return the flows and those that have been completed. return util.JSONResponse{ @@ -750,11 +806,25 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.UserInternalAPI, - username, password, appserviceID, ipAddr, userAgent string, + username, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, accType userapi.AccountType, ) util.JSONResponse { + var registrationOK bool + defer func() { + if registrationOK { + sessions.deleteSession(sessionID) + } + }() + + if data, ok := sessions.getParams(sessionID); ok { + username = data.Username + password = data.Password + deviceID = data.DeviceID + displayName = data.InitialDisplayName + inhibitLogin = data.InhibitLogin + } if username == "" { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -795,6 +865,7 @@ func completeRegistration( // Check whether inhibit_login option is set. If so, don't create an access // token or a device for this user if inhibitLogin { + registrationOK = true return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ @@ -828,6 +899,7 @@ func completeRegistration( } } + registrationOK = true return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ @@ -976,5 +1048,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 1f615dc26..c6b7e61cf 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -17,6 +17,7 @@ package routing import ( "regexp" "testing" + "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/setup/config" @@ -140,7 +141,7 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) { func TestEmptyCompletedFlows(t *testing.T) { fakeEmptySessions := newSessionsDict() fakeSessionID := "aRandomSessionIDWhichDoesNotExist" - ret := fakeEmptySessions.GetCompletedStages(fakeSessionID) + ret := fakeEmptySessions.getCompletedStages(fakeSessionID) // check for [] if ret == nil || len(ret) != 0 { @@ -208,3 +209,45 @@ func TestValidationOfApplicationServices(t *testing.T) { t.Errorf("user_id should not have been valid: @_something_else:localhost") } } + +func TestSessionCleanUp(t *testing.T) { + s := newSessionsDict() + + t.Run("session is cleaned up after a while", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld" + // manually added, as s.addParams() would start the timer with the default timeout + s.params[dummySession] = registerRequest{Username: "Testing"} + s.startTimer(time.Millisecond, dummySession) + time.Sleep(time.Millisecond * 2) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) + + t.Run("session is deleted, once the registration completed", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld2" + s.startTimer(time.Minute, dummySession) + s.deleteSession(dummySession) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) + + t.Run("session timer is restarted after second call", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld3" + // the following will start a timer with the default timeout of 5min + s.addParams(dummySession, registerRequest{Username: "Testing"}) + s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha) + s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy) + s.getCompletedStages(dummySession) + // reset the timer with a lower timeout + s.startTimer(time.Millisecond, dummySession) + time.Sleep(time.Millisecond * 2) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) +} \ No newline at end of file diff --git a/sytest-blacklist b/sytest-blacklist index 16abce8da..e8617dcdf 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -24,6 +24,7 @@ Local device key changes get to remote servers with correct prev_id # Flakey Local device key changes appear in /keys/changes +/context/ with lazy_load_members filter works # we don't support groups Remove group category diff --git a/sytest-whitelist b/sytest-whitelist index d3144572d..12522cfb3 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -601,3 +601,7 @@ Can query remote device keys using POST after notification Device deletion propagates over federation Get left notifs in sync and /keys/changes when other user leaves Remote banned user is kicked and may not rejoin until unbanned +registration remembers parameters +registration accepts non-ascii passwords +registration with inhibit_login inhibits login + From 264165eb8c0e8ae618e53de47b814f8a9781d567 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 25 Feb 2022 17:35:10 +0000 Subject: [PATCH 06/24] Update systemd example to set `LimitNOFILE` --- docs/systemd/monolith-example.service | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/systemd/monolith-example.service b/docs/systemd/monolith-example.service index 731c6159b..237120ffb 100644 --- a/docs/systemd/monolith-example.service +++ b/docs/systemd/monolith-example.service @@ -13,6 +13,7 @@ Group=dendrite WorkingDirectory=/opt/dendrite/ ExecStart=/opt/dendrite/bin/dendrite-monolith-server Restart=always +LimitNOFILE=65535 [Install] WantedBy=multi-user.target From ac77732185a2473dd39984669ea3f454548aa7f5 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Mon, 28 Feb 2022 13:57:56 +0100 Subject: [PATCH 07/24] Add possibility to reset password using create-account (#2231) * Add possibility to reset password * Invalidate logins * Fix test --- cmd/create-account/main.go | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 3003896c8..307fa17e6 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -48,24 +48,27 @@ Example: # read password from stdin %s --config dendrite.yaml -username alice -passwordstdin < my.pass cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin + # reset password for a user, can be used with a combination above to read the password + %s --config dendrite.yaml -reset-password -username alice -password foobarbaz Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - askPass = flag.Bool("ask-pass", false, "Ask for the password to use") - isAdmin = flag.Bool("admin", false, "Create an admin account") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + askPass = flag.Bool("ask-pass", false, "Ask for the password to use") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username") ) func main() { name := os.Args[0] flag.Usage = func() { - _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name) + _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name) flag.PrintDefaults() } cfg := setup.ParseFlags(true) @@ -93,6 +96,19 @@ func main() { if *isAdmin { accType = api.AccountTypeAdmin } + + if *resetPassword { + err = accountDB.SetPassword(context.Background(), *username, pass) + if err != nil { + logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error()) + } + if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil { + logrus.Fatalf("Failed to remove all devices: %s", err.Error()) + } + logrus.Infof("Updated password for user %s and invalidated all logins\n", *username) + return + } + _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType) if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) From a23fda662607e9160230335503e912f626abf616 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 28 Feb 2022 14:51:40 +0000 Subject: [PATCH 08/24] Update `Events` call-sites which now don't return an error, update `parsedRespState` to sort (#2227) * Topologically sort with `SendEventWithState`, so that earlier events should satisfy auth for later ones * Revert "Topologically sort with `SendEventWithState`, so that earlier events should satisfy auth for later ones" This reverts commit b0cd706012b4c9b6724b11e16f19c4cb732ab286. * Update to matrix-org/gomatrixserverlib#293 * `Events` no longer returns an error, other tweaks * Make sure `Events` is sorted for `parsedRespState` too --- go.mod | 2 +- go.sum | 4 ++-- roomserver/api/wrapper.go | 8 ++------ roomserver/internal/input/input_missing.go | 22 +++++++++++++++++----- setup/mscs/msc2836/msc2836.go | 8 ++------ 5 files changed, 24 insertions(+), 20 deletions(-) diff --git a/go.mod b/go.mod index a416ec98f..78af67e6e 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438 github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index ecb0496bc..4ba0723f2 100644 --- a/go.sum +++ b/go.sum @@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 h1:+4Q/JQ3fGgA7sIHaLMlqREX8yEpsI+HlVoW9WId7SNc= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438 h1:3B0ZEJ5YVVbRKHV7WFgji5z6s262YIVRZvtDotdpbsI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 012094c62..5491d36b3 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -51,12 +51,8 @@ func SendEventWithState( state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { - outliers, err := state.Events(event.RoomVersion) - if err != nil { - return err - } - - var ires []InputRoomEvent + outliers := state.Events(event.RoomVersion) + ires := make([]InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if haveEventIDs[outlier.EventID()] { continue diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 4655e92a9..a7da9b06d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -23,6 +23,21 @@ type parsedRespState struct { StateEvents []*gomatrixserverlib.Event } +func (p *parsedRespState) Events() []*gomatrixserverlib.Event { + eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents)) + for i, event := range p.AuthEvents { + eventsByID[event.EventID()] = p.AuthEvents[i] + } + for i, event := range p.StateEvents { + eventsByID[event.EventID()] = p.StateEvents[i] + } + allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID)) + for _, event := range eventsByID { + allEvents = append(allEvents, event) + } + return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents) +} + type missingStateReq struct { origin gomatrixserverlib.ServerName db storage.Database @@ -124,11 +139,8 @@ func (t *missingStateReq) processEventWithMissingState( t.hadEventsMutex.Unlock() sendOutliers := func(resolvedState *parsedRespState) error { - outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) - if oerr != nil { - return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr) - } - var outlierRoomEvents []api.InputRoomEvent + outliers := resolvedState.Events() + outlierRoomEvents := make([]api.InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if hadEvents[outlier.EventID()] { continue diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 0af22c19a..29c781a88 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -654,11 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo AuthEvents: res.AuthChain, StateEvents: stateEvents, } - eventsInOrder, err := respState.Events(rc.roomVersion) - if err != nil { - util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse") - return - } + eventsInOrder := respState.Events(rc.roomVersion) // everything gets sent as an outlier because auth chain events may be disjoint from the DAG // as may the threaded events. var ires []roomserver.InputRoomEvent @@ -669,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo }) } // we've got the data by this point so use a background context - err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) + err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") } From 58bf91a585ec78f6ca6ff0c9ad0c10c5db9715a7 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 11:00:54 +0000 Subject: [PATCH 09/24] Check for changes in `PerformUploadDeviceKeys` (#2233) * Don't generate key change notifs if nothing changed on cross-signing upload * Check both directions of changes --- keyserver/internal/cross_signing.go | 53 ++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index bfb2037f8..5124f37e6 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -166,26 +166,53 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } // We can't have a self-signing or user-signing key without a master - // key, so make sure we have one of those. - if !hasMasterKey { - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) - if err != nil { - res.Error = &api.KeyError{ - Err: "Retrieving cross-signing keys from database failed: " + err.Error(), - } - return + // key, so make sure we have one of those. We will also only actually do + // something if any of the specified keys in the request are different + // to what we've got in the database, to avoid generating key change + // notifications unnecessarily. + existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) + if err != nil { + res.Error = &api.KeyError{ + Err: "Retrieving cross-signing keys from database failed: " + err.Error(), } - - _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster] + return } // If we still can't find a master key for the user then stop the upload. // This satisfies the "Fails to upload self-signing key without master key" test. if !hasMasterKey { - res.Error = &api.KeyError{ - Err: "No master key was found", - IsMissingParam: true, + if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey { + res.Error = &api.KeyError{ + Err: "No master key was found", + IsMissingParam: true, + } + return } + } + + // Check if anything actually changed compared to what we have in the database. + changed := false + for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{ + gomatrixserverlib.CrossSigningKeyPurposeMaster, + gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, + gomatrixserverlib.CrossSigningKeyPurposeUserSigning, + } { + old, gotOld := existingKeys[purpose] + new, gotNew := toStore[purpose] + if gotOld != gotNew { + // A new key purpose has been specified that we didn't know before, + // or one has been removed. + changed = true + break + } + if !bytes.Equal(old, new) { + // One of the existing keys for a purpose we already knew about has + // changed. + changed = true + break + } + } + if !changed { return } From 530f05885dccba91559aff09eaaa20540a08a419 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 13:01:38 +0000 Subject: [PATCH 10/24] Limit `JoinedUsersSetInRooms` to interested users (#2234) * Limit database work in `JoinedUsersSetInRooms` to changed user IDs only * Comments * Fix variadic params for SQLite, update comments --- roomserver/api/query.go | 1 + roomserver/internal/query/query.go | 2 +- roomserver/storage/interface.go | 4 ++-- .../storage/postgres/membership_table.go | 10 ++++----- roomserver/storage/shared/storage.go | 20 +++++++++++------- .../storage/sqlite3/membership_table.go | 21 ++++++++++++------- roomserver/storage/tables/interface.go | 5 ++--- syncapi/internal/keychange.go | 17 +++++++++++---- 8 files changed, 49 insertions(+), 31 deletions(-) diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 96d6711c6..6255ffa3a 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -269,6 +269,7 @@ type QueryAuthChainResponse struct { type QuerySharedUsersRequest struct { UserID string + OtherUserIDs []string ExcludeRoomIDs []string IncludeRoomIDs []string } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index c8bbe7705..20c5a2585 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -696,7 +696,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser } roomIDs = roomIDs[:j] - users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs) + users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs) if err != nil { return err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 685505d52..cd232e3e9 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -151,8 +151,8 @@ type Database interface { // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) - // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. - JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms. + JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) // GetServerInRoom returns true if we think a server is in a given room or false otherwise. diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 48c2c35cd..127178743 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" @@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectJoinedUsersSetForRooms( ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, + userNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]int, error) { - roomIDarray := make([]int64, len(roomNIDs)) - for i := range roomNIDs { - roomIDarray[i] = int64(roomNIDs[i]) - } stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) - rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 6e84b2832..6dc40816c 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1104,13 +1104,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu return result, nil } -// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. -func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { +// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms. +func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) { roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) if err != nil { return nil, err } - userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) + userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs) + if err != nil { + return nil, err + } + userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap)) + nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap)) + for id, nid := range userNIDsMap { + userNIDs = append(userNIDs, nid) + nidToUserID[nid] = id + } + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs) if err != nil { return nil, err } @@ -1120,10 +1130,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) stateKeyNIDs[i] = nid i++ } - nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs) - if err != nil { - return nil, err - } if len(nidToUserID) != len(userNIDToCount) { logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID)) } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 181b4b4c9..43567a94c 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -42,7 +42,8 @@ const membershipSchema = ` ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid IN ($1) AND target_nid IN ($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" @@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { - iRoomNIDs := make([]interface{}, len(roomNIDs)) - for i, v := range roomNIDs { - iRoomNIDs[i] = v +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) { + params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs)) + for _, v := range roomNIDs { + params = append(params, v) } - query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) + for _, v := range userNIDs { + params = append(params, v) + } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) var rows *sql.Rows var err error if txn != nil { - rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) + rows, err = txn.QueryContext(ctx, query, params...) } else { - rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) + rows, err = s.db.QueryContext(ctx, query, params...) } if err != nil { return nil, err diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index e3fed700b..04e3c96cc 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -127,9 +127,8 @@ type Membership interface { SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) - // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the - // counts of how many rooms they are joined. - SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. + SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 37a9e2d39..dc4acd8da 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -82,7 +82,16 @@ func DeviceListCatchup( util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") return to, hasNew, nil } - // QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user. + + // Work out which user IDs we care about — that includes those in the original request, + // the response from QueryKeyChanges (which includes ALL users who have changed keys) + // as well as every user who has a join or leave event in the current sync response. We + // will request information about which rooms these users are joined to, so that we can + // see if we still share any rooms with them. + joinUserIDs, leaveUserIDs := membershipEvents(res) + queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...) + queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...) + queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs) var sharedUsersMap map[string]int sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) util.GetLogger(ctx).Debugf( @@ -100,9 +109,8 @@ func DeviceListCatchup( userSet[userID] = true } } - // if the response has any join/leave events, add them now. + // Finally, add in users who have joined or left. // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. - joinUserIDs, leaveUserIDs := membershipEvents(res) for _, userID := range joinUserIDs { if !userSet[userID] { res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) @@ -213,7 +221,8 @@ func filterSharedUsers( var result []string var sharedUsersRes roomserverAPI.QuerySharedUsersResponse err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ - UserID: userID, + UserID: userID, + OtherUserIDs: usersWithChangedKeys, }, &sharedUsersRes) if err != nil { // default to all users so we do needless queries rather than miss some important device update From f1b92de0175d74be7b167c5e3f168d0048c84843 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 1 Mar 2022 13:40:07 +0000 Subject: [PATCH 11/24] MSC2946: Spaces Summary (round 2) (#2232) * Initial cut at fixing up MSC2946 to work with latest spec * bugfix: send response back correctly * Initial working version of MSC2946 * msc2946: handle suggested_only; remove custom database As the MSC doesn't require reverse lookups, we can just pull the room state and inspect via the roomserver database. To handle this, expand QueryCurrentState to support wildcards. Use all this and handle `?suggested_only`. * Sort child rooms * msc2946: Make TestClientSpacesSummary pass * msc2946: allow invited rooms to be spidered * msc2946: support basic federation requests * fix up go mod --- federationapi/api/api.go | 2 +- federationapi/internal/federationclient.go | 4 +- federationapi/inthttp/client.go | 18 +- federationapi/inthttp/server.go | 2 +- go.mod | 7 +- go.sum | 12 +- roomserver/api/query.go | 5 +- roomserver/internal/query/query.go | 25 +- roomserver/storage/interface.go | 1 + roomserver/storage/shared/storage.go | 56 +++ setup/mscs/msc2946/msc2946.go | 543 ++++++++++++--------- setup/mscs/msc2946/msc2946_test.go | 464 ------------------ setup/mscs/msc2946/storage.go | 182 ------- 13 files changed, 411 insertions(+), 910 deletions(-) delete mode 100644 setup/mscs/msc2946/msc2946_test.go delete mode 100644 setup/mscs/msc2946/storage.go diff --git a/federationapi/api/api.go b/federationapi/api/api.go index f5ee75b4b..4d6b0211c 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -21,7 +21,7 @@ type FederationClient interface { QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index b31db466c..b8bd5beda 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2946Spaces(ctx, s, roomID, r) + return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) }) if err != nil { return res, err diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index f9b2a33d2..01ca6595d 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -526,23 +526,23 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships( } type spacesReq struct { - S gomatrixserverlib.ServerName - Req gomatrixserverlib.MSC2946SpacesRequest - RoomID string - Res gomatrixserverlib.MSC2946SpacesResponse - Err *api.FederationClientError + S gomatrixserverlib.ServerName + SuggestedOnly bool + RoomID string + Res gomatrixserverlib.MSC2946SpacesResponse + Err *api.FederationClientError } func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, + ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces") defer span.Finish() request := spacesReq{ - S: dst, - Req: r, - RoomID: roomID, + S: dst, + SuggestedOnly: suggestedOnly, + RoomID: roomID, } var response spacesReq apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 8d193d9c9..ca4930f20 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req) + res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly) if err != nil { ferr, ok := err.(*api.FederationClientError) if ok { diff --git a/go.mod b/go.mod index 78af67e6e..e44e7393d 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 + github.com/google/uuid v1.2.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.3 // indirect @@ -39,7 +40,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 @@ -61,11 +62,11 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.2 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/mobile v0.0.0-20220112015953-858099ff7816 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd - golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect + golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 4ba0723f2..43b363d30 100644 --- a/go.sum +++ b/go.sum @@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438 h1:3B0ZEJ5YVVbRKHV7WFgji5z6s262YIVRZvtDotdpbsI= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220225115648-d2a338a15438/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 h1:CqbM5tPbF1mV2h1J6N0NSF8et6HvFlwA98TLCUcjiSI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1510,8 +1510,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5 golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= -golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1737,8 +1737,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= -golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 6255ffa3a..66e85f2f3 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -313,7 +313,10 @@ type QueryBulkStateContentResponse struct { } type QueryCurrentStateRequest struct { - RoomID string + RoomID string + AllowWildcards bool + // State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching + // state events with that event type. StateTuples []gomatrixserverlib.StateKeyTuple } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 20c5a2585..70cc5d62c 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -621,12 +621,25 @@ func (r *Queryer) QueryPublishedRooms( func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) for _, tuple := range req.StateTuples { - ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) - if err != nil { - return err - } - if ev != nil { - res.StateEvents[tuple] = ev + if tuple.StateKey == "*" && req.AllowWildcards { + events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType) + if err != nil { + return err + } + for _, e := range events { + res.StateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = e + } + } else { + ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) + if err != nil { + return err + } + if ev != nil { + res.StateEvents[tuple] = ev + } } } return nil diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index cd232e3e9..a2b22b401 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -146,6 +146,7 @@ type Database interface { // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 6dc40816c..f87782776 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -979,6 +979,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s return nil, nil } +// Same as GetStateEvent but returns all matching state events with this event type. Returns no error +// if there are no events with this event type. +func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + if roomInfo == nil { + return nil, fmt.Errorf("room %s doesn't exist", roomID) + } + // e.g invited rooms + if roomInfo.IsStub { + return nil, nil + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err == sql.ErrNoRows { + // No rooms have an event of this type, otherwise we'd have an event type NID + return nil, nil + } + if err != nil { + return nil, err + } + entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + if err != nil { + return nil, err + } + var eventNIDs []types.EventNID + for _, e := range entries { + if e.EventTypeNID == eventTypeNID { + eventNIDs = append(eventNIDs, e.EventNID) + } + } + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) + if err != nil { + eventIDs = map[types.EventNID]string{} + } + // return the events requested + eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) + if err != nil { + return nil, err + } + if len(eventPairs) == 0 { + return nil, nil + } + var result []*gomatrixserverlib.HeaderedEvent + for _, pair := range eventPairs { + ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion) + if err != nil { + return nil, err + } + result = append(result, ev.Headered(roomInfo.RoomVersion)) + } + + return result, nil +} + // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { var membershipState tables.MembershipState diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 3824c99a2..3bb56f4b3 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -18,17 +18,18 @@ package msc2946 import ( "context" "encoding/json" - "fmt" "net/http" + "net/url" + "sort" + "strconv" "strings" "sync" "time" + "github.com/google/uuid" "github.com/gorilla/mux" - chttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/httputil" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -39,15 +40,15 @@ import ( ) const ( - ConstCreateEventContentKey = "type" - ConstSpaceChildEventType = "m.space.child" - ConstSpaceParentEventType = "m.space.parent" + ConstCreateEventContentKey = "type" + ConstCreateEventContentValueSpace = "m.space" + ConstSpaceChildEventType = "m.space.child" + ConstSpaceParentEventType = "m.space.parent" ) -// Defaults sets the request defaults -func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) { - r.Limit = 2000 - r.MaxRoomsPerSpace = -1 +type MSC2946ClientResponse struct { + Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"` + NextBatch string `json:"next_batch,omitempty"` } // Enable this MSC @@ -55,26 +56,11 @@ func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, ) error { - db, err := NewDatabase(&base.Cfg.MSCs.Database) - if err != nil { - return fmt.Errorf("cannot enable MSC2946: %w", err) - } - hooks.Enable() - hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { - he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) - hookErr := db.StoreReference(context.Background(), he) - if hookErr != nil { - util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error( - "failed to StoreReference", - ) - } - }) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, base.Cfg.Global.ServerName)) + base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) + base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) - base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces", - httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)), - ).Methods(http.MethodPost, http.MethodOptions) - - base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI( + fedAPI := httputil.MakeExternalAPI( "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( req, time.Now(), base.Cfg.Global.ServerName, keyRing, @@ -88,105 +74,99 @@ func Enable( return util.ErrorResponse(err) } roomID := params["roomID"] - return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName) + return federatedSpacesHandler(req.Context(), fedReq, roomID, rsAPI, fsAPI, base.Cfg.Global.ServerName) }, - )).Methods(http.MethodPost, http.MethodOptions) + ) + base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) + base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) return nil } func federatedSpacesHandler( - ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database, + ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) - var r gomatrixserverlib.MSC2946SpacesRequest - Defaults(&r) - if err := json.Unmarshal(fedReq.Content(), &r); err != nil { + u, err := url.Parse(fedReq.RequestURI()) + if err != nil { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + Code: 400, + JSON: jsonerror.InvalidParam("bad request uri"), } } - w := walker{ - req: &r, - rootRoomID: roomID, - serverName: fedReq.Origin(), - thisServer: thisServer, - ctx: ctx, - db: db, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, - } - res := w.walk() - return util.JSONResponse{ - Code: 200, - JSON: res, + w := walker{ + rootRoomID: roomID, + serverName: fedReq.Origin(), + thisServer: thisServer, + ctx: ctx, + suggestedOnly: u.Query().Get("suggested_only") == "true", + limit: 1000, + // The main difference is that it does not recurse into spaces and does not support pagination. + // This is somewhat equivalent to a Client-Server request with a max_depth=1. + maxDepth: 1, + + rsAPI: rsAPI, + fsAPI: fsAPI, + // inline cache as we don't have pagination in federation mode + paginationCache: make(map[string]paginationInfo), } + return w.walk() } func spacesHandler( - db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, + rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) func(*http.Request, *userapi.Device) util.JSONResponse { + // declared outside the returned handler so it persists between calls + // TODO: clear based on... time? + paginationCache := make(map[string]paginationInfo) + return func(req *http.Request, device *userapi.Device) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) // Extract the room ID from the request. Sanity check request data. params, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } roomID := params["roomID"] - var r gomatrixserverlib.MSC2946SpacesRequest - Defaults(&r) - if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil { - return *resErr - } w := walker{ - req: &r, - rootRoomID: roomID, - caller: device, - thisServer: thisServer, - ctx: req.Context(), + suggestedOnly: req.URL.Query().Get("suggested_only") == "true", + limit: parseInt(req.URL.Query().Get("limit"), 1000), + maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1), + paginationToken: req.URL.Query().Get("from"), + rootRoomID: roomID, + caller: device, + thisServer: thisServer, + ctx: req.Context(), - db: db, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, - } - res := w.walk() - return util.JSONResponse{ - Code: 200, - JSON: res, + rsAPI: rsAPI, + fsAPI: fsAPI, + paginationCache: paginationCache, } + return w.walk() } } +type paginationInfo struct { + processed set + unvisited []roomVisit +} + type walker struct { - req *gomatrixserverlib.MSC2946SpacesRequest - rootRoomID string - caller *userapi.Device - serverName gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName - db Database - rsAPI roomserver.RoomserverInternalAPI - fsAPI fs.FederationInternalAPI - ctx context.Context + rootRoomID string + caller *userapi.Device + serverName gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName + rsAPI roomserver.RoomserverInternalAPI + fsAPI fs.FederationInternalAPI + ctx context.Context + suggestedOnly bool + limit int + maxDepth int + paginationToken string - // user ID|device ID|batch_num => event/room IDs sent to client - inMemoryBatchCache map[string]set - mu sync.Mutex -} - -func (w *walker) roomIsExcluded(roomID string) bool { - for _, exclRoom := range w.req.ExcludeRooms { - if exclRoom == roomID { - return true - } - } - return false + paginationCache map[string]paginationInfo + mu sync.Mutex } func (w *walker) callerID() string { @@ -196,144 +176,207 @@ func (w *walker) callerID() string { return string(w.serverName) } -func (w *walker) alreadySent(id string) bool { - w.mu.Lock() - defer w.mu.Unlock() - m, ok := w.inMemoryBatchCache[w.callerID()] - if !ok { - return false +func (w *walker) newPaginationCache() (string, paginationInfo) { + p := paginationInfo{ + processed: make(set), + unvisited: nil, } - return m[id] + tok := uuid.NewString() + return tok, p } -func (w *walker) markSent(id string) { +func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo { w.mu.Lock() defer w.mu.Unlock() - m := w.inMemoryBatchCache[w.callerID()] - if m == nil { - m = make(set) - } - m[id] = true - w.inMemoryBatchCache[w.callerID()] = m + p := w.paginationCache[paginationToken] + return &p } -func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse { - var res gomatrixserverlib.MSC2946SpacesResponse - // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms - unvisited := []string{w.rootRoomID} - processed := make(set) +func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) { + w.mu.Lock() + defer w.mu.Unlock() + w.paginationCache[paginationToken] = cache +} + +type roomVisit struct { + roomID string + depth int + vias []string // vias to query this room by +} + +func (w *walker) walk() util.JSONResponse { + if !w.authorised(w.rootRoomID) { + if w.caller != nil { + // CS API format + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("room is unknown/forbidden"), + } + } else { + // SS API format + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("room is unknown/forbidden"), + } + } + } + + var discoveredRooms []gomatrixserverlib.MSC2946Room + + var cache *paginationInfo + if w.paginationToken != "" { + cache = w.loadPaginationCache(w.paginationToken) + if cache == nil { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("invalid from"), + } + } + } else { + tok, c := w.newPaginationCache() + cache = &c + w.paginationToken = tok + // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms + c.unvisited = append(c.unvisited, roomVisit{ + roomID: w.rootRoomID, + depth: 0, + }) + } + + processed := cache.processed + unvisited := cache.unvisited + + // Depth first -> stack data structure for len(unvisited) > 0 { - roomID := unvisited[0] - unvisited = unvisited[1:] - // If this room has already been processed, skip. NB: do not remember this between calls - if processed[roomID] || roomID == "" { + if len(discoveredRooms) >= w.limit { + break + } + + // pop the stack + rv := unvisited[len(unvisited)-1] + unvisited = unvisited[:len(unvisited)-1] + // If this room has already been processed, skip. + // If this room exceeds the specified depth, skip. + if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) { continue } + // Mark this room as processed. - processed[roomID] = true + processed.set(rv.roomID) + + // if this room is not a space room, skip. + var roomType string + create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "") + if create != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + } // Collect rooms/events to send back (either locally or fetched via federation) - var discoveredRooms []gomatrixserverlib.MSC2946Room - var discoveredEvents []gomatrixserverlib.MSC2946StrippedEvent + var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent // If we know about this room and the caller is authorised (joined/world_readable) then pull // events locally - if w.roomExists(roomID) && w.authorised(roomID) { - // Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get - // all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`) - // this room. This requires servers to store reverse lookups. - events, err := w.references(roomID) + if w.roomExists(rv.roomID) && w.authorised(rv.roomID) { + // Get all `m.space.child` state events for this room + events, err := w.childReferences(rv.roomID) if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room") + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room") continue } - discoveredEvents = events + discoveredChildEvents = events - pubRoom := w.publicRoomsChunk(roomID) - roomType := "" - create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "") - if create != nil { - // escape the `.`s so gjson doesn't think it's nested - roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str - } + pubRoom := w.publicRoomsChunk(rv.roomID) - // Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`. discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ - PublicRoom: *pubRoom, - NumRefs: len(discoveredEvents), - RoomType: roomType, + PublicRoom: *pubRoom, + RoomType: roomType, + ChildrenState: events, }) } else { // attempt to query this room over federation, as either we've never heard of it before // or we've left it and hence are not authorised (but info may be exposed regardless) - fedRes, err := w.federatedRoomInfo(roomID) + fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias) if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces") + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Errorf("failed to query federated spaces") continue } if fedRes != nil { - discoveredRooms = fedRes.Rooms - discoveredEvents = fedRes.Events + discoveredChildEvents = fedRes.Room.ChildrenState + discoveredRooms = append(discoveredRooms, fedRes.Room) + if len(fedRes.Children) > 0 { + discoveredRooms = append(discoveredRooms, fedRes.Children...) + } + // mark this room as a space room as the federated server responded. + // we need to do this so we add the children of this room to the unvisited stack + // as these children may be rooms we do know about. + roomType = ConstCreateEventContentValueSpace } } - // If this room has not ever been in `rooms` (across multiple requests), send it now - for _, room := range discoveredRooms { - if !w.alreadySent(room.RoomID) && !w.roomIsExcluded(room.RoomID) { - res.Rooms = append(res.Rooms, room) - w.markSent(room.RoomID) - } + // don't walk the children + // if the parent is not a space room + if roomType != ConstCreateEventContentValueSpace { + continue } - uniqueRooms := make(set) - - // If this is the root room from the original request, insert all these events into `events` if - // they haven't been added before (across multiple requests). - if w.rootRoomID == roomID { - for _, ev := range discoveredEvents { - if !w.alreadySent(eventKey(&ev)) { - res.Events = append(res.Events, ev) - uniqueRooms[ev.RoomID] = true - uniqueRooms[spaceTargetStripped(&ev)] = true - w.markSent(eventKey(&ev)) - } - } - } else { - // Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either - // are exceeded, stop adding events. If the event has already been added, do not add it again. - numAdded := 0 - for _, ev := range discoveredEvents { - if w.req.Limit > 0 && len(res.Events) >= w.req.Limit { - break - } - if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace { - break - } - if w.alreadySent(eventKey(&ev)) { - continue - } - // Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still - // want to catch arrows which point to excluded rooms. - if w.roomIsExcluded(ev.RoomID) { - continue - } - res.Events = append(res.Events, ev) - uniqueRooms[ev.RoomID] = true - uniqueRooms[spaceTargetStripped(&ev)] = true - w.markSent(eventKey(&ev)) - // we don't distinguish between child state events and parent state events for the purposes of - // max_rooms_per_space, maybe we should? - numAdded++ - } - } - - // For each referenced room ID in the events being returned to the caller (both parent and child) + // For each referenced room ID in the child events being returned to the caller // add the room ID to the queue of unvisited rooms. Loop from the beginning. - for roomID := range uniqueRooms { - unvisited = append(unvisited, roomID) + // We need to invert the order here because the child events are lo->hi on the timestamp, + // so we need to ensure we pop in the same lo->hi order, which won't be the case if we + // insert the highest timestamp last in a stack. + for i := len(discoveredChildEvents) - 1; i >= 0; i-- { + spaceContent := struct { + Via []string `json:"via"` + }{} + ev := discoveredChildEvents[i] + _ = json.Unmarshal(ev.Content, &spaceContent) + unvisited = append(unvisited, roomVisit{ + roomID: ev.StateKey, + depth: rv.depth + 1, + vias: spaceContent.Via, + }) } } - return &res + + if len(unvisited) > 0 { + // we still have more rooms so we need to send back a pagination token, + // we probably hit a room limit + cache.processed = processed + cache.unvisited = unvisited + w.storePaginationCache(w.paginationToken, *cache) + } else { + // clear the pagination token so we don't send it back to the client + // Note we do NOT nuke the cache just in case this response is lost + // and the client retries it. + w.paginationToken = "" + } + + if w.caller != nil { + // return CS API format + return util.JSONResponse{ + Code: 200, + JSON: MSC2946ClientResponse{ + Rooms: discoveredRooms, + NextBatch: w.paginationToken, + }, + } + } + // return SS API format + // the first discovered room will be the room asked for, and subsequent ones the depth=1 children + if len(discoveredRooms) == 0 { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("room is unknown/forbidden"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: gomatrixserverlib.MSC2946SpacesResponse{ + Room: discoveredRooms[0], + Children: discoveredRooms[1:], + }, + } } func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent { @@ -366,42 +409,19 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom { // federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was // unsuccessful. -func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { +func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { // only do federated requests for client requests if w.caller == nil { return nil, nil } - // extract events which point to this room ID and extract their vias - events, err := w.db.References(w.ctx, roomID) - if err != nil { - return nil, fmt.Errorf("failed to get References events: %w", err) - } - vias := make(set) - for _, ev := range events { - if ev.StateKeyEquals(roomID) { - // event points at this room, extract vias - content := struct { - Vias []string `json:"via"` - }{} - if err = json.Unmarshal(ev.Content(), &content); err != nil { - continue // silently ignore corrupted state events - } - for _, v := range content.Vias { - vias[v] = true - } - } - } - util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias) + util.GetLogger(w.ctx).Infof("Querying %s via %+v", roomID, vias) ctx := context.Background() // query more of the spaces graph using these servers - for serverName := range vias { + for _, serverName := range vias { if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{ - Limit: w.req.Limit, - MaxRoomsPerSpace: w.req.MaxRoomsPerSpace, - }) + res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue @@ -501,7 +521,7 @@ func (w *walker) authorisedUser(roomID string) bool { hisVisEv := queryRes.StateEvents[hisVisTuple] if memberEv != nil { membership, _ := memberEv.Membership() - if membership == gomatrixserverlib.Join { + if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { return true } } @@ -514,40 +534,85 @@ func (w *walker) authorisedUser(roomID string) bool { return false } -// references returns all references pointing to or from this room. -func (w *walker) references(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { - events, err := w.db.References(w.ctx, roomID) +// references returns all child references pointing to or from this room. +func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { + createTuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomCreate, + StateKey: "", + } + var res roomserver.QueryCurrentStateResponse + err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{ + RoomID: roomID, + AllowWildcards: true, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + createTuple, { + EventType: ConstSpaceChildEventType, + StateKey: "*", + }, + }, + }, &res) if err != nil { return nil, err } - el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events)) - for _, ev := range events { + + // don't return any child refs if the room is not a space room + if res.StateEvents[createTuple] != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + if roomType != ConstCreateEventContentValueSpace { + return nil, nil + } + } + delete(res.StateEvents, createTuple) + + el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents)) + for _, ev := range res.StateEvents { + content := gjson.ParseBytes(ev.Content()) // only return events that have a `via` key as per MSC1772 // else we'll incorrectly walk redacted events (as the link // is in the state_key) - if gjson.GetBytes(ev.Content(), "via").Exists() { + if content.Get("via").Exists() { strip := stripped(ev.Event) if strip == nil { continue } + // if suggested only and this child isn't suggested, skip it. + // if suggested only = false we include everything so don't need to check the content. + if w.suggestedOnly && !content.Get("suggested").Bool() { + continue + } el = append(el, *strip) } } + // sort by origin_server_ts as per MSC2946 + sort.Slice(el, func(i, j int) bool { + return el[i].OriginServerTS < el[j].OriginServerTS + }) + return el, nil } -type set map[string]bool +type set map[string]struct{} + +func (s set) set(val string) { + s[val] = struct{}{} +} +func (s set) isSet(val string) bool { + _, ok := s[val] + return ok +} func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent { if ev.StateKey() == nil { return nil } return &gomatrixserverlib.MSC2946StrippedEvent{ - Type: ev.Type(), - StateKey: *ev.StateKey(), - Content: ev.Content(), - Sender: ev.Sender(), - RoomID: ev.RoomID(), + Type: ev.Type(), + StateKey: *ev.StateKey(), + Content: ev.Content(), + Sender: ev.Sender(), + RoomID: ev.RoomID(), + OriginServerTS: ev.OriginServerTS(), } } @@ -567,3 +632,11 @@ func spaceTargetStripped(event *gomatrixserverlib.MSC2946StrippedEvent) string { } return "" } + +func parseInt(intstr string, defaultVal int) int { + i, err := strconv.ParseInt(intstr, 10, 32) + if err != nil { + return defaultVal + } + return int(i) +} diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go deleted file mode 100644 index e8066c34d..000000000 --- a/setup/mscs/msc2946/msc2946_test.go +++ /dev/null @@ -1,464 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msc2946_test - -import ( - "bytes" - "context" - "crypto/ed25519" - "encoding/json" - "io/ioutil" - "net/http" - "net/url" - "testing" - "time" - - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal/hooks" - "github.com/matrix-org/dendrite/internal/httputil" - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/mscs/msc2946" - userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" -) - -var ( - client = &http.Client{ - Timeout: 10 * time.Second, - } - roomVer = gomatrixserverlib.RoomVersionV6 -) - -// Basic sanity check of MSC2946 logic. Tests a single room with a few state events -// and a bit of recursion to subspaces. Makes a graph like: -// Root -// ____|_____ -// | | | -// R1 R2 S1 -// |_________ -// | | | -// R3 R4 S2 -// | <-- this link is just a parent, not a child -// R5 -// -// Alice is not joined to R4, but R4 is "world_readable". -func TestMSC2946(t *testing.T) { - alice := "@alice:localhost" - // give access token to alice - nopUserAPI := &testUserAPI{ - accessTokens: make(map[string]userapi.Device), - } - nopUserAPI.accessTokens["alice"] = userapi.Device{ - AccessToken: "alice", - DisplayName: "Alice", - UserID: alice, - } - rootSpace := "!rootspace:localhost" - subSpaceS1 := "!subspaceS1:localhost" - subSpaceS2 := "!subspaceS2:localhost" - room1 := "!room1:localhost" - room2 := "!room2:localhost" - room3 := "!room3:localhost" - room4 := "!room4:localhost" - empty := "" - room5 := "!room5:localhost" - allRooms := []string{ - rootSpace, subSpaceS1, subSpaceS2, - room1, room2, room3, room4, room5, - } - rootToR1 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room1, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - rootToR2 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - rootToS1 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &subSpaceS1, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToR3 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room3, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToR4 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room4, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToS2 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &subSpaceS2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - // This is a parent link only - s2ToR5 := mustCreateEvent(t, fledglingEvent{ - RoomID: room5, - Sender: alice, - Type: msc2946.ConstSpaceParentEventType, - StateKey: &subSpaceS2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - // history visibility for R4 - r4HisVis := mustCreateEvent(t, fledglingEvent{ - RoomID: room4, - Sender: "@someone:localhost", - Type: gomatrixserverlib.MRoomHistoryVisibility, - StateKey: &empty, - Content: map[string]interface{}{ - "history_visibility": "world_readable", - }, - }) - var joinEvents []*gomatrixserverlib.HeaderedEvent - for _, roomID := range allRooms { - if roomID == room4 { - continue // not joined to that room - } - joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{ - RoomID: roomID, - Sender: alice, - StateKey: &alice, - Type: gomatrixserverlib.MRoomMember, - Content: map[string]interface{}{ - "membership": "join", - }, - })) - } - roomNameTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.name", - StateKey: "", - } - hisVisTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.history_visibility", - StateKey: "", - } - nopRsAPI := &testRoomserverAPI{ - joinEvents: joinEvents, - events: map[string]*gomatrixserverlib.HeaderedEvent{ - rootToR1.EventID(): rootToR1, - rootToR2.EventID(): rootToR2, - rootToS1.EventID(): rootToS1, - s1ToR3.EventID(): s1ToR3, - s1ToR4.EventID(): s1ToR4, - s1ToS2.EventID(): s1ToS2, - s2ToR5.EventID(): s2ToR5, - r4HisVis.EventID(): r4HisVis, - }, - pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{ - rootSpace: { - roomNameTuple: "Root", - hisVisTuple: "shared", - }, - subSpaceS1: { - roomNameTuple: "Sub-Space 1", - hisVisTuple: "joined", - }, - subSpaceS2: { - roomNameTuple: "Sub-Space 2", - hisVisTuple: "shared", - }, - room1: { - hisVisTuple: "joined", - }, - room2: { - hisVisTuple: "joined", - }, - room3: { - hisVisTuple: "joined", - }, - room4: { - hisVisTuple: "world_readable", - }, - room5: { - hisVisTuple: "joined", - }, - }, - } - allEvents := []*gomatrixserverlib.HeaderedEvent{ - rootToR1, rootToR2, rootToS1, - s1ToR3, s1ToR4, s1ToS2, - s2ToR5, r4HisVis, - } - allEvents = append(allEvents, joinEvents...) - router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents) - cancel := runServer(t, router) - defer cancel() - - t.Run("returns no events for unknown rooms", func(t *testing.T) { - res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{})) - if len(res.Events) > 0 { - t.Errorf("got %d events, want 0", len(res.Events)) - } - if len(res.Rooms) > 0 { - t.Errorf("got %d rooms, want 0", len(res.Rooms)) - } - }) - t.Run("returns the entire graph", func(t *testing.T) { - res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) - if len(res.Events) != 7 { - t.Errorf("got %d events, want 7", len(res.Events)) - } - if len(res.Rooms) != len(allRooms) { - t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)) - } - }) - t.Run("can update the graph", func(t *testing.T) { - // remove R3 from the graph - rmS1ToR3 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room3, - Content: map[string]interface{}{}, // redacted - }) - nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3 - hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3) - - res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) - if len(res.Events) != 6 { // one less since we don't return redacted events - t.Errorf("got %d events, want 6", len(res.Events)) - } - if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3 - t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1) - } - }) -} - -func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest { - t.Helper() - b, err := json.Marshal(jsonBody) - if err != nil { - t.Fatalf("Failed to marshal request: %s", err) - } - var r gomatrixserverlib.MSC2946SpacesRequest - if err := json.Unmarshal(b, &r); err != nil { - t.Fatalf("Failed to unmarshal request: %s", err) - } - return &r -} - -func runServer(t *testing.T, router *mux.Router) func() { - t.Helper() - externalServ := &http.Server{ - Addr: string(":8010"), - WriteTimeout: 60 * time.Second, - Handler: router, - } - go func() { - externalServ.ListenAndServe() - }() - // wait to listen on the port - time.Sleep(500 * time.Millisecond) - return func() { - externalServ.Shutdown(context.TODO()) - } -} - -func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse { - t.Helper() - var r gomatrixserverlib.MSC2946SpacesRequest - msc2946.Defaults(&r) - data, err := json.Marshal(req) - if err != nil { - t.Fatalf("failed to marshal request: %s", err) - } - httpReq, err := http.NewRequest( - "POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces", - bytes.NewBuffer(data), - ) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - if err != nil { - t.Fatalf("failed to prepare request: %s", err) - } - res, err := client.Do(httpReq) - if err != nil { - t.Fatalf("failed to do request: %s", err) - } - if res.StatusCode != expectCode { - body, _ := ioutil.ReadAll(res.Body) - t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body)) - } - if res.StatusCode == 200 { - var result gomatrixserverlib.MSC2946SpacesResponse - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("response 200 OK but failed to read response body: %s", err) - } - t.Logf("Body: %s", string(body)) - if err := json.Unmarshal(body, &result); err != nil { - t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body)) - } - return &result - } - return nil -} - -type testUserAPI struct { - userapi.UserInternalAPITrace - accessTokens map[string]userapi.Device -} - -func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { - dev, ok := u.accessTokens[req.AccessToken] - if !ok { - res.Err = "unknown token" - return nil - } - res.Device = &dev - return nil -} - -type testRoomserverAPI struct { - // use a trace API as it implements method stubs so we don't need to have them here. - // We'll override the functions we care about. - roomserver.RoomserverInternalAPITrace - joinEvents []*gomatrixserverlib.HeaderedEvent - events map[string]*gomatrixserverlib.HeaderedEvent - pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string -} - -func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error { - res.IsInRoom = true - res.RoomExists = true - return nil -} - -func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error { - res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) - for _, roomID := range req.RoomIDs { - pubRoomData, ok := r.pubRoomState[roomID] - if ok { - res.Rooms[roomID] = pubRoomData - } - } - return nil -} - -func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error { - res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) - checkEvent := func(he *gomatrixserverlib.HeaderedEvent) { - if he.RoomID() != req.RoomID { - return - } - if he.StateKey() == nil { - return - } - tuple := gomatrixserverlib.StateKeyTuple{ - EventType: he.Type(), - StateKey: *he.StateKey(), - } - for _, t := range req.StateTuples { - if t == tuple { - res.StateEvents[t] = he - } - } - } - for _, he := range r.joinEvents { - checkEvent(he) - } - for _, he := range r.events { - checkEvent(he) - } - return nil -} - -func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router { - t.Helper() - cfg := &config.Dendrite{} - cfg.Defaults(true) - cfg.Global.ServerName = "localhost" - cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db" - cfg.MSCs.MSCs = []string{"msc2946"} - base := &base.BaseDendrite{ - Cfg: cfg, - PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(), - PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(), - } - - err := msc2946.Enable(base, rsAPI, userAPI, nil, nil) - if err != nil { - t.Fatalf("failed to enable MSC2946: %s", err) - } - for _, ev := range events { - hooks.Run(hooks.KindNewEventPersisted, ev) - } - return base.PublicClientAPIMux -} - -type fledglingEvent struct { - Type string - StateKey *string - Content interface{} - Sender string - RoomID string -} - -func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { - t.Helper() - seed := make([]byte, ed25519.SeedSize) // zero seed - key := ed25519.NewKeyFromSeed(seed) - eb := gomatrixserverlib.EventBuilder{ - Sender: ev.Sender, - Depth: 999, - Type: ev.Type, - StateKey: ev.StateKey, - RoomID: ev.RoomID, - } - err := eb.SetContent(ev.Content) - if err != nil { - t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) - } - // make sure the origin_server_ts changes so we can test recency - time.Sleep(1 * time.Millisecond) - signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer) - if err != nil { - t.Fatalf("mustCreateEvent: failed to sign event: %s", err) - } - h := signedEvent.Headered(roomVer) - return h -} diff --git a/setup/mscs/msc2946/storage.go b/setup/mscs/msc2946/storage.go deleted file mode 100644 index 20db18594..000000000 --- a/setup/mscs/msc2946/storage.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msc2946 - -import ( - "context" - "database/sql" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" -) - -var ( - relTypes = map[string]int{ - ConstSpaceChildEventType: 1, - ConstSpaceParentEventType: 2, - } -) - -type Database interface { - // StoreReference persists a child or parent space mapping. - StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error - // References returns all events which have the given roomID as a parent or child space. - References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) -} - -type DB struct { - db *sql.DB - writer sqlutil.Writer - insertEdgeStmt *sql.Stmt - selectEdgesStmt *sql.Stmt -} - -// NewDatabase loads the database for msc2836 -func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - if dbOpts.ConnectionString.IsPostgres() { - return newPostgresDatabase(dbOpts) - } - return newSQLiteDatabase(dbOpts) -} - -func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewDummyWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { - return nil, err - } - _, err = d.db.Exec(` - CREATE TABLE IF NOT EXISTS msc2946_edges ( - room_version TEXT NOT NULL, - -- the room ID of the event, the source of the arrow - source_room_id TEXT NOT NULL, - -- the target room ID, the arrow destination - dest_room_id TEXT NOT NULL, - -- the kind of relation, either child or parent (1,2) - rel_type SMALLINT NOT NULL, - event_json TEXT NOT NULL, - CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type) - ); - `) - if err != nil { - return nil, err - } - if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5 - `); err != nil { - return nil, err - } - if d.selectEdgesStmt, err = d.db.Prepare(` - SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 - `); err != nil { - return nil, err - } - return &d, err -} - -func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewExclusiveWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { - return nil, err - } - _, err = d.db.Exec(` - CREATE TABLE IF NOT EXISTS msc2946_edges ( - room_version TEXT NOT NULL, - -- the room ID of the event, the source of the arrow - source_room_id TEXT NOT NULL, - -- the target room ID, the arrow destination - dest_room_id TEXT NOT NULL, - -- the kind of relation, either child or parent (1,2) - rel_type SMALLINT NOT NULL, - event_json TEXT NOT NULL, - UNIQUE (source_room_id, dest_room_id, rel_type) - ); - `) - if err != nil { - return nil, err - } - if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5 - `); err != nil { - return nil, err - } - if d.selectEdgesStmt, err = d.db.Prepare(` - SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 - `); err != nil { - return nil, err - } - return &d, err -} - -func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error { - target := SpaceTarget(he) - if target == "" { - return nil // malformed event - } - relType := relTypes[he.Type()] - _, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON()) - return err -} - -func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) { - rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "failed to close References") - refs := make([]*gomatrixserverlib.HeaderedEvent, 0) - for rows.Next() { - var roomVer string - var jsonBytes []byte - if err := rows.Scan(&roomVer, &jsonBytes); err != nil { - return nil, err - } - ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer)) - if err != nil { - return nil, err - } - he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer)) - refs = append(refs, he) - } - return refs, nil -} - -// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent -// depending on the event type. -func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string { - if he.StateKey() == nil { - return "" // no-op - } - switch he.Type() { - case ConstSpaceParentEventType: - return *he.StateKey() - case ConstSpaceChildEventType: - return *he.StateKey() - } - return "" -} From 1a79060b46d3b9c6dc0083fd7b20a2471f795a90 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 1 Mar 2022 14:16:47 +0000 Subject: [PATCH 12/24] Bump GMSL version --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index e44e7393d..ddf3f2175 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index 43b363d30..74a181528 100644 --- a/go.sum +++ b/go.sum @@ -985,6 +985,8 @@ github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 h1:CqbM5tPbF1mV2h1J6N0NSF8et6HvFlwA98TLCUcjiSI= github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= From 18e3c40da43017a9e1abe775a406844aca5e1643 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 1 Mar 2022 14:22:59 +0000 Subject: [PATCH 13/24] Always send [] from federated rooms, not null --- setup/mscs/msc2946/msc2946.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 3bb56f4b3..bf26e9910 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -426,6 +426,18 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserve util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue } + // ensure nil slices are empty as we send this to the client sometimes + if res.Room.ChildrenState == nil { + res.Room.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{} + } + for i := 0; i < len(res.Children); i++ { + child := res.Children[i] + if child.ChildrenState == nil { + child.ChildrenState = []gomatrixserverlib.MSC2946StrippedEvent{} + } + res.Children[i] = child + } + return &res, nil } return nil, nil From 471fda810a24f8945f5da3a76d2730a1ffe80166 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 14:39:06 +0000 Subject: [PATCH 14/24] Remove unnecessary error line (#2237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously this error line would print because we were pulling out all user memberships, but now this is no longer necessary — an event state key that we don't know will no longer get passed to `SelectJoinedUsersSetForRooms` at all. --- roomserver/storage/shared/storage.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index f87782776..dd49f35ca 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -13,7 +13,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) @@ -1186,9 +1185,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ stateKeyNIDs[i] = nid i++ } - if len(nidToUserID) != len(userNIDToCount) { - logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID)) - } result := make(map[string]int, len(userNIDToCount)) for nid, count := range userNIDToCount { result[nidToUserID[nid]] = count From af610df85a6a2145d5ddcd6419bb8085ae224207 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Tue, 1 Mar 2022 15:39:56 +0100 Subject: [PATCH 15/24] Return state on calls to /message and lazy load members (#2218) Co-authored-by: Neil Alexander --- syncapi/routing/context.go | 4 +-- syncapi/routing/context_test.go | 6 ++-- syncapi/routing/messages.go | 61 +++++++++++++++++++-------------- sytest-whitelist | 1 + 4 files changed, 42 insertions(+), 30 deletions(-) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index ef7efa2b0..591139717 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -44,7 +44,7 @@ func Context( syncDB storage.Database, roomID, eventID string, ) util.JSONResponse { - filter, err := parseContextParams(req) + filter, err := parseRoomEventFilter(req) if err != nil { errMsg := "" switch err.(type) { @@ -164,7 +164,7 @@ func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter return newState } -func parseContextParams(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { +func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { // Default room filter filter := &gomatrixserverlib.RoomEventFilter{Limit: 10} diff --git a/syncapi/routing/context_test.go b/syncapi/routing/context_test.go index 1b430d83a..e79a5d5f1 100644 --- a/syncapi/routing/context_test.go +++ b/syncapi/routing/context_test.go @@ -55,13 +55,13 @@ func Test_parseContextParams(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotFilter, err := parseContextParams(tt.req) + gotFilter, err := parseRoomEventFilter(tt.req) if (err != nil) != tt.wantErr { - t.Errorf("parseContextParams() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("parseRoomEventFilter() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(gotFilter, tt.wantFilter) { - t.Errorf("parseContextParams() gotFilter = %v, want %v", gotFilter, tt.wantFilter) + t.Errorf("parseRoomEventFilter() gotFilter = %v, want %v", gotFilter, tt.wantFilter) } }) } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 9bb8c6d24..7cd54eef0 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -19,7 +19,6 @@ import ( "fmt" "net/http" "sort" - "strconv" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" @@ -45,8 +44,8 @@ type messagesReq struct { fromStream *types.StreamingToken device *userapi.Device wasToProvided bool - limit int backwardOrdering bool + filter *gomatrixserverlib.RoomEventFilter } type messagesResp struct { @@ -54,10 +53,9 @@ type messagesResp struct { StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token End string `json:"end"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` + State []gomatrixserverlib.ClientEvent `json:"state"` } -const defaultMessagesLimit = 10 - // OnIncomingMessagesRequest implements the /messages endpoint from the // client-server API. // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages @@ -83,6 +81,14 @@ func OnIncomingMessagesRequest( } } + filter, err := parseRoomEventFilter(req) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unable to parse filter"), + } + } + // Extract parameters from the request's URL. // Pagination tokens. var fromStream *types.StreamingToken @@ -143,18 +149,6 @@ func OnIncomingMessagesRequest( wasToProvided = false } - // Maximum number of events to return; defaults to 10. - limit := defaultMessagesLimit - if len(req.URL.Query().Get("limit")) > 0 { - limit, err = strconv.Atoi(req.URL.Query().Get("limit")) - - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()), - } - } - } // TODO: Implement filtering (#587) // Check the room ID's format. @@ -176,7 +170,7 @@ func OnIncomingMessagesRequest( to: &to, fromStream: fromStream, wasToProvided: wasToProvided, - limit: limit, + filter: filter, backwardOrdering: backwardOrdering, device: device, } @@ -187,10 +181,27 @@ func OnIncomingMessagesRequest( return jsonerror.InternalServerError() } + // at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set + state := []gomatrixserverlib.ClientEvent{} + if filter.LazyLoadMembers { + memberShipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent) + for _, evt := range clientEvents { + memberShip, err := db.GetStateEvent(req.Context(), roomID, gomatrixserverlib.MRoomMember, evt.Sender) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to get membership event for user") + continue + } + memberShipToUser[evt.Sender] = memberShip + } + for _, evt := range memberShipToUser { + state = append(state, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll)) + } + } + util.GetLogger(req.Context()).WithFields(logrus.Fields{ "from": from.String(), "to": to.String(), - "limit": limit, + "limit": filter.Limit, "backwards": backwardOrdering, "return_start": start.String(), "return_end": end.String(), @@ -200,6 +211,7 @@ func OnIncomingMessagesRequest( Chunk: clientEvents, Start: start.String(), End: end.String(), + State: state, } if emptyFromSupplied { res.StartStream = fromStream.String() @@ -234,19 +246,18 @@ func (r *messagesReq) retrieveEvents() ( clientEvents []gomatrixserverlib.ClientEvent, start, end types.TopologyToken, err error, ) { - eventFilter := gomatrixserverlib.DefaultRoomEventFilter() - eventFilter.Limit = r.limit + eventFilter := r.filter // Retrieve the events from the local database. var streamEvents []types.StreamEvent if r.fromStream != nil { toStream := r.to.StreamToken() streamEvents, err = r.db.GetEventsInStreamingRange( - r.ctx, r.fromStream, &toStream, r.roomID, &eventFilter, r.backwardOrdering, + r.ctx, r.fromStream, &toStream, r.roomID, eventFilter, r.backwardOrdering, ) } else { streamEvents, err = r.db.GetEventsInTopologicalRange( - r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, + r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering, ) } if err != nil { @@ -434,7 +445,7 @@ func (r *messagesReq) handleEmptyEventsSlice() ( // Check if we have backward extremities for this room. if len(backwardExtremities) > 0 { // If so, retrieve as much events as needed through backfilling. - events, err = r.backfill(r.roomID, backwardExtremities, r.limit) + events, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit) if err != nil { return } @@ -456,7 +467,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent events []*gomatrixserverlib.HeaderedEvent, err error, ) { // Check if we have enough events. - isSetLargeEnough := len(streamEvents) >= r.limit + isSetLargeEnough := len(streamEvents) >= r.filter.Limit if !isSetLargeEnough { // it might be fine we don't have up to 'limit' events, let's find out if r.backwardOrdering { @@ -483,7 +494,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering { var pdus []*gomatrixserverlib.HeaderedEvent // Only ask the remote server for enough events to reach the limit. - pdus, err = r.backfill(r.roomID, backwardExtremities, r.limit-len(streamEvents)) + pdus, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit-len(streamEvents)) if err != nil { return } diff --git a/sytest-whitelist b/sytest-whitelist index 12522cfb3..2dd56d260 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -597,6 +597,7 @@ Device list doesn't change if remote server is down /context/ on non world readable room does not work /context/ returns correct number of events /context/ with lazy_load_members filter works +GET /rooms/:room_id/messages lazy loads members correctly Can query remote device keys using POST after notification Device deletion propagates over federation Get left notifs in sync and /keys/changes when other user leaves From 8dfc958dddc5f7125633f9b5cf55371ab35ceb6c Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 1 Mar 2022 14:40:47 +0000 Subject: [PATCH 16/24] Also don't send null back when the target room isn't a space room --- setup/mscs/msc2946/msc2946.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index bf26e9910..f53e7b249 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -572,7 +572,7 @@ func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946Stri // escape the `.`s so gjson doesn't think it's nested roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str if roomType != ConstCreateEventContentValueSpace { - return nil, nil + return []gomatrixserverlib.MSC2946StrippedEvent{}, nil } } delete(res.StateEvents, createTuple) From ae840590b643bd74786f3e8a869a8ea9e769995b Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 1 Mar 2022 16:03:54 +0000 Subject: [PATCH 17/24] Make complement go fast (#2240) --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 124940f71..4a1720295 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -63,7 +63,7 @@ jobs: # Run Complement - run: | set -o pipefail && - go test -v -p 1 -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt + go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: From 352e63915f110cbe4907349a7e59f43f179657e6 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 1 Mar 2022 16:32:48 +0000 Subject: [PATCH 18/24] msc2946: add federation cache (#2238) --- internal/caching/cache_space_rooms.go | 28 +++++++++++++++++++++++++ internal/caching/caches.go | 1 + internal/caching/impl_inmemorylru.go | 12 ++++++++++- setup/mscs/msc2946/msc2946.go | 30 ++++++++++++++++----------- setup/mscs/mscs.go | 2 +- 5 files changed, 59 insertions(+), 14 deletions(-) create mode 100644 internal/caching/cache_space_rooms.go diff --git a/internal/caching/cache_space_rooms.go b/internal/caching/cache_space_rooms.go new file mode 100644 index 000000000..71b81abcf --- /dev/null +++ b/internal/caching/cache_space_rooms.go @@ -0,0 +1,28 @@ +package caching + +import "github.com/matrix-org/gomatrixserverlib" + +const ( + SpaceSummaryRoomsCacheName = "space_summary_rooms" + SpaceSummaryRoomsCacheMaxEntries = 100 + SpaceSummaryRoomsCacheMutable = true +) + +type SpaceSummaryRoomsCache interface { + GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) + StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) +} + +func (c Caches) GetSpaceSummary(roomID string) (r gomatrixserverlib.MSC2946SpacesResponse, ok bool) { + val, found := c.SpaceSummaryRooms.Get(roomID) + if found && val != nil { + if resp, ok := val.(gomatrixserverlib.MSC2946SpacesResponse); ok { + return resp, true + } + } + return r, false +} + +func (c Caches) StoreSpaceSummary(roomID string, r gomatrixserverlib.MSC2946SpacesResponse) { + c.SpaceSummaryRooms.Set(roomID, r) +} diff --git a/internal/caching/caches.go b/internal/caching/caches.go index e1642a663..39926d735 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -10,6 +10,7 @@ type Caches struct { RoomServerRoomIDs Cache // RoomServerNIDsCache RoomInfos Cache // RoomInfoCache FederationEvents Cache // FederationEventsCache + SpaceSummaryRooms Cache // SpaceSummaryRoomsCache } // Cache is the interface that an implementation must satisfy. diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index ccb92852b..7b21f1362 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -55,9 +55,18 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { if err != nil { return nil, err } + spaceRooms, err := NewInMemoryLRUCachePartition( + SpaceSummaryRoomsCacheName, + SpaceSummaryRoomsCacheMutable, + SpaceSummaryRoomsCacheMaxEntries, + enablePrometheus, + ) + if err != nil { + return nil, err + } go cacheCleaner( roomVersions, serverKeys, roomServerRoomIDs, - roomInfos, federationEvents, + roomInfos, federationEvents, spaceRooms, ) return &Caches{ RoomVersions: roomVersions, @@ -65,6 +74,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomServerRoomIDs: roomServerRoomIDs, RoomInfos: roomInfos, FederationEvents: federationEvents, + SpaceSummaryRooms: spaceRooms, }, nil } diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index f53e7b249..7ab50c32e 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -30,6 +30,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -54,9 +55,9 @@ type MSC2946ClientResponse struct { // Enable this MSC func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, - fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, + fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache, ) error { - clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, base.Cfg.Global.ServerName)) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName)) base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) @@ -74,7 +75,7 @@ func Enable( return util.ErrorResponse(err) } roomID := params["roomID"] - return federatedSpacesHandler(req.Context(), fedReq, roomID, rsAPI, fsAPI, base.Cfg.Global.ServerName) + return federatedSpacesHandler(req.Context(), fedReq, roomID, cache, rsAPI, fsAPI, base.Cfg.Global.ServerName) }, ) base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) @@ -84,6 +85,7 @@ func Enable( func federatedSpacesHandler( ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, + cache caching.SpaceSummaryRoomsCache, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) util.JSONResponse { @@ -100,6 +102,7 @@ func federatedSpacesHandler( serverName: fedReq.Origin(), thisServer: thisServer, ctx: ctx, + cache: cache, suggestedOnly: u.Query().Get("suggested_only") == "true", limit: 1000, // The main difference is that it does not recurse into spaces and does not support pagination. @@ -115,7 +118,9 @@ func federatedSpacesHandler( } func spacesHandler( - rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, + rsAPI roomserver.RoomserverInternalAPI, + fsAPI fs.FederationInternalAPI, + cache caching.SpaceSummaryRoomsCache, thisServer gomatrixserverlib.ServerName, ) func(*http.Request, *userapi.Device) util.JSONResponse { // declared outside the returned handler so it persists between calls @@ -138,6 +143,7 @@ func spacesHandler( caller: device, thisServer: thisServer, ctx: req.Context(), + cache: cache, rsAPI: rsAPI, fsAPI: fsAPI, @@ -160,6 +166,7 @@ type walker struct { rsAPI roomserver.RoomserverInternalAPI fsAPI fs.FederationInternalAPI ctx context.Context + cache caching.SpaceSummaryRoomsCache suggestedOnly bool limit int maxDepth int @@ -169,13 +176,6 @@ type walker struct { mu sync.Mutex } -func (w *walker) callerID() string { - if w.caller != nil { - return w.caller.UserID + "|" + w.caller.ID - } - return string(w.serverName) -} - func (w *walker) newPaginationCache() (string, paginationInfo) { p := paginationInfo{ processed: make(set), @@ -414,7 +414,12 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserve if w.caller == nil { return nil, nil } - util.GetLogger(w.ctx).Infof("Querying %s via %+v", roomID, vias) + resp, ok := w.cache.GetSpaceSummary(roomID) + if ok { + util.GetLogger(w.ctx).Debugf("Returning cached response for %s", roomID) + return &resp, nil + } + util.GetLogger(w.ctx).Debugf("Querying %s via %+v", roomID, vias) ctx := context.Background() // query more of the spaces graph using these servers for _, serverName := range vias { @@ -437,6 +442,7 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserve } res.Children[i] = child } + w.cache.StoreSpaceSummary(roomID, res) return &res, nil } diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go index 10df6a15d..35b7bba3b 100644 --- a/setup/mscs/mscs.go +++ b/setup/mscs/mscs.go @@ -42,7 +42,7 @@ func EnableMSC(base *base.BaseDendrite, monolith *setup.Monolith, msc string) er case "msc2836": return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) case "msc2946": - return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing) + return msc2946.Enable(base, monolith.RoomserverAPI, monolith.UserAPI, monolith.FederationAPI, monolith.KeyRing, base.Caches) case "msc2444": // enabled inside federationapi case "msc2753": // enabled inside clientapi default: From cda2452ba00afffa9a73870ca09047ce52dd28c7 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Tue, 1 Mar 2022 17:39:57 +0100 Subject: [PATCH 19/24] Only allow device deletion from session UIA was initiated from (#2235) * Only allow device deletion if the session matches * Make the challenge response available to other packages * Remove userID, as it's not in the spec * Remove tests * Add passing test & remove obsolete config * Rename field, add comment Co-authored-by: Neil Alexander --- clientapi/auth/user_interactive.go | 24 ++++++++++++---------- clientapi/routing/device.go | 33 ++++++++++++++++++++++++++++++ clientapi/routing/register.go | 26 ++++++++++++++++++++--- clientapi/routing/register_test.go | 10 +++++++++ dendrite-config.yaml | 5 ----- sytest-whitelist | 2 ++ 6 files changed, 81 insertions(+), 19 deletions(-) diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 9cab7956c..4db75809f 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -144,21 +144,23 @@ func (u *UserInteractive) AddCompletedStage(sessionID, authType string) { delete(u.Sessions, sessionID) } +type Challenge struct { + Completed []string `json:"completed"` + Flows []userInteractiveFlow `json:"flows"` + Session string `json:"session"` + // TODO: Return any additional `params` + Params map[string]interface{} `json:"params"` +} + // Challenge returns an HTTP 401 with the supported flows for authenticating func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse { return &util.JSONResponse{ Code: 401, - JSON: struct { - Completed []string `json:"completed"` - Flows []userInteractiveFlow `json:"flows"` - Session string `json:"session"` - // TODO: Return any additional `params` - Params map[string]interface{} `json:"params"` - }{ - u.Completed, - u.Flows, - sessionID, - make(map[string]interface{}), + JSON: Challenge{ + Completed: u.Completed, + Flows: u.Flows, + Session: sessionID, + Params: make(map[string]interface{}), }, } } diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 9f54a625a..161bc2731 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/tidwall/gjson" ) // https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-devices @@ -163,6 +164,15 @@ func DeleteDeviceById( req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { + var ( + deleteOK bool + sessionID string + ) + defer func() { + if deleteOK { + sessions.deleteSession(sessionID) + } + }() ctx := req.Context() defer req.Body.Close() // nolint:errcheck bodyBytes, err := ioutil.ReadAll(req.Body) @@ -172,8 +182,29 @@ func DeleteDeviceById( JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), } } + + // check that we know this session, and it matches with the device to delete + s := gjson.GetBytes(bodyBytes, "auth.session").Str + if dev, ok := sessions.getDeviceToDelete(s); ok { + if dev != deviceID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("session & device mismatch"), + } + } + } + + if s != "" { + sessionID = s + } + login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device) if errRes != nil { + switch data := errRes.JSON.(type) { + case auth.Challenge: + sessions.addDeviceToDelete(data.Session, deviceID) + default: + } return *errRes } @@ -201,6 +232,8 @@ func DeleteDeviceById( return jsonerror.InternalServerError() } + deleteOK = true + return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 10cfa4325..c97876594 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -76,6 +76,10 @@ type sessionsDict struct { sessions map[string][]authtypes.LoginType params map[string]registerRequest timer map[string]*time.Timer + // deleteSessionToDeviceID protects requests to DELETE /devices/{deviceID} from being abused. + // If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2, + // the delete request will fail for device2 since the UIA was initiated by trying to delete device1. + deleteSessionToDeviceID map[string]string } // defaultTimeout is the timeout used to clean up sessions @@ -115,6 +119,7 @@ func (d *sessionsDict) deleteSession(sessionID string) { defer d.Unlock() delete(d.params, sessionID) delete(d.sessions, sessionID) + delete(d.deleteSessionToDeviceID, sessionID) // stop the timer, e.g. because the registration was completed if t, ok := d.timer[sessionID]; ok { if !t.Stop() { @@ -129,9 +134,10 @@ func (d *sessionsDict) deleteSession(sessionID string) { func newSessionsDict() *sessionsDict { return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), - params: make(map[string]registerRequest), - timer: make(map[string]*time.Timer), + sessions: make(map[string][]authtypes.LoginType), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + deleteSessionToDeviceID: make(map[string]string), } } @@ -165,6 +171,20 @@ func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtype d.sessions[sessionID] = append(sessions.sessions[sessionID], stage) } +func (d *sessionsDict) addDeviceToDelete(sessionID, deviceID string) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + d.deleteSessionToDeviceID[sessionID] = deviceID +} + +func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) { + d.RLock() + defer d.RUnlock() + deviceID, ok := d.deleteSessionToDeviceID[sessionID] + return deviceID, ok +} + var ( sessions = newSessionsDict() validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index c6b7e61cf..520f5ddeb 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -242,6 +242,7 @@ func TestSessionCleanUp(t *testing.T) { s.addParams(dummySession, registerRequest{Username: "Testing"}) s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha) s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy) + s.addDeviceToDelete(dummySession, "dummyDevice") s.getCompletedStages(dummySession) // reset the timer with a lower timeout s.startTimer(time.Millisecond, dummySession) @@ -249,5 +250,14 @@ func TestSessionCleanUp(t *testing.T) { if data, ok := s.getParams(dummySession); ok { t.Errorf("expected session to be deleted: %+v", data) } + if _, ok := s.timer[dummySession]; ok { + t.Error("expected timer to be delete") + } + if _, ok := s.sessions[dummySession]; ok { + t.Error("expected session to be delete") + } + if _, ok := s.getDeviceToDelete(dummySession); ok { + t.Error("expected session to device to be delete") + } }) } \ No newline at end of file diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 6d086ed77..533b5c952 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -345,11 +345,6 @@ user_api: max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 - device_database: - connection_string: file:userapi_devices.db - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 # The length of time that a token issued for a relying party from # /_matrix/client/r0/user/{userId}/openid/request_token endpoint # is considered to be valid in milliseconds. diff --git a/sytest-whitelist b/sytest-whitelist index 2dd56d260..0cb80b7b6 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -605,4 +605,6 @@ Remote banned user is kicked and may not rejoin until unbanned registration remembers parameters registration accepts non-ascii passwords registration with inhibit_login inhibits login +The operation must be consistent through an interactive authentication session +Multiple calls to /sync should not cause 500 errors From 726529fe996519c93f4f329c03a968a432b0bb0e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 16:59:11 +0000 Subject: [PATCH 20/24] Hopefully fix read receipts (#2241) --- syncapi/storage/postgres/receipt_table.go | 2 +- syncapi/storage/sqlite3/receipt_table.go | 2 +- syncapi/streams/stream_receipt.go | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index f93081e1a..474d0c020 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -96,7 +96,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room } func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) { - lastPos := streamPos + var lastPos types.StreamPosition rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) if err != nil { return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index 6b39ee879..9111a39f6 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -101,7 +101,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) { selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) - lastPos := streamPos + var lastPos types.StreamPosition params := make([]interface{}, len(roomIDs)+1) params[0] = streamPos for k, v := range roomIDs { diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index cccadb525..35ffd3a1e 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -63,7 +63,6 @@ func (p *ReceiptStreamProvider) IncrementalSync( if existing, ok := req.Response.Rooms.Join[roomID]; ok { jr = existing } - var ok bool ev := gomatrixserverlib.ClientEvent{ Type: gomatrixserverlib.MReceipt, @@ -71,8 +70,8 @@ func (p *ReceiptStreamProvider) IncrementalSync( } content := make(map[string]eduAPI.ReceiptMRead) for _, receipt := range receipts { - var read eduAPI.ReceiptMRead - if read, ok = content[receipt.EventID]; !ok { + read, ok := content[receipt.EventID] + if !ok { read = eduAPI.ReceiptMRead{ User: make(map[string]eduAPI.ReceiptTS), } From bb2380c254b65a6586137e9e0ab9e08354aa19f4 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 16:59:52 +0000 Subject: [PATCH 21/24] Allow specifying max age for caches (#2239) * Allow specifying max age for caches * Evict cache entry if it's found to be stale when we call Get * Fix bugs --- internal/caching/cache_federationevents.go | 1 + internal/caching/cache_roominfo.go | 3 ++ internal/caching/cache_roomservernids.go | 1 + internal/caching/cache_roomversions.go | 1 + internal/caching/cache_serverkeys.go | 1 + internal/caching/caches.go | 4 +++ internal/caching/impl_inmemorylru.go | 42 +++++++++++++++++++--- 7 files changed, 48 insertions(+), 5 deletions(-) diff --git a/internal/caching/cache_federationevents.go b/internal/caching/cache_federationevents.go index d10b333a6..b79cc809f 100644 --- a/internal/caching/cache_federationevents.go +++ b/internal/caching/cache_federationevents.go @@ -10,6 +10,7 @@ const ( FederationEventCacheName = "federation_event" FederationEventCacheMaxEntries = 256 FederationEventCacheMutable = true // to allow use of Unset only + FederationEventCacheMaxAge = CacheNoMaxAge ) // FederationCache contains the subset of functions needed for diff --git a/internal/caching/cache_roominfo.go b/internal/caching/cache_roominfo.go index f32d6ba9b..60d221285 100644 --- a/internal/caching/cache_roominfo.go +++ b/internal/caching/cache_roominfo.go @@ -1,6 +1,8 @@ package caching import ( + "time" + "github.com/matrix-org/dendrite/roomserver/types" ) @@ -16,6 +18,7 @@ const ( RoomInfoCacheName = "roominfo" RoomInfoCacheMaxEntries = 1024 RoomInfoCacheMutable = true + RoomInfoCacheMaxAge = time.Minute * 5 ) // RoomInfosCache contains the subset of functions needed for diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go index 6d413093f..1918a2f1e 100644 --- a/internal/caching/cache_roomservernids.go +++ b/internal/caching/cache_roomservernids.go @@ -10,6 +10,7 @@ const ( RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMutable = false + RoomServerRoomIDsCacheMaxAge = CacheNoMaxAge ) type RoomServerCaches interface { diff --git a/internal/caching/cache_roomversions.go b/internal/caching/cache_roomversions.go index 0b46d3d4b..92d2eab08 100644 --- a/internal/caching/cache_roomversions.go +++ b/internal/caching/cache_roomversions.go @@ -6,6 +6,7 @@ const ( RoomVersionCacheName = "room_versions" RoomVersionCacheMaxEntries = 1024 RoomVersionCacheMutable = false + RoomVersionCacheMaxAge = CacheNoMaxAge ) // RoomVersionsCache contains the subset of functions needed for diff --git a/internal/caching/cache_serverkeys.go b/internal/caching/cache_serverkeys.go index 4697fb4d2..4eb10fe6f 100644 --- a/internal/caching/cache_serverkeys.go +++ b/internal/caching/cache_serverkeys.go @@ -10,6 +10,7 @@ const ( ServerKeyCacheName = "server_key" ServerKeyCacheMaxEntries = 4096 ServerKeyCacheMutable = true + ServerKeyCacheMaxAge = CacheNoMaxAge ) // ServerKeyCache contains the subset of functions needed for diff --git a/internal/caching/caches.go b/internal/caching/caches.go index 39926d735..722405de6 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -1,5 +1,7 @@ package caching +import "time" + // Caches contains a set of references to caches. They may be // different implementations as long as they satisfy the Cache // interface. @@ -19,3 +21,5 @@ type Cache interface { Set(key string, value interface{}) Unset(key string) } + +const CacheNoMaxAge = time.Duration(0) diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index 7b21f1362..c9c8fd08d 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -14,6 +14,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomVersionCacheName, RoomVersionCacheMutable, RoomVersionCacheMaxEntries, + RoomVersionCacheMaxAge, enablePrometheus, ) if err != nil { @@ -23,6 +24,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { ServerKeyCacheName, ServerKeyCacheMutable, ServerKeyCacheMaxEntries, + ServerKeyCacheMaxAge, enablePrometheus, ) if err != nil { @@ -32,6 +34,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomServerRoomIDsCacheName, RoomServerRoomIDsCacheMutable, RoomServerRoomIDsCacheMaxEntries, + RoomServerRoomIDsCacheMaxAge, enablePrometheus, ) if err != nil { @@ -41,6 +44,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { RoomInfoCacheName, RoomInfoCacheMutable, RoomInfoCacheMaxEntries, + RoomInfoCacheMaxAge, enablePrometheus, ) if err != nil { @@ -50,6 +54,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { FederationEventCacheName, FederationEventCacheMutable, FederationEventCacheMaxEntries, + FederationEventCacheMaxAge, enablePrometheus, ) if err != nil { @@ -96,15 +101,22 @@ type InMemoryLRUCachePartition struct { name string mutable bool maxEntries int + maxAge time.Duration lru *lru.Cache } -func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, enablePrometheus bool) (*InMemoryLRUCachePartition, error) { +type inMemoryLRUCacheEntry struct { + value interface{} + created time.Time +} + +func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, maxAge time.Duration, enablePrometheus bool) (*InMemoryLRUCachePartition, error) { var err error cache := InMemoryLRUCachePartition{ name: name, mutable: mutable, maxEntries: maxEntries, + maxAge: maxAge, } cache.lru, err = lru.New(maxEntries) if err != nil { @@ -124,11 +136,16 @@ func NewInMemoryLRUCachePartition(name string, mutable bool, maxEntries int, ena func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) { if !c.mutable { - if peek, ok := c.lru.Peek(key); ok && peek != value { - panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key)) + if peek, ok := c.lru.Peek(key); ok { + if entry, ok := peek.(*inMemoryLRUCacheEntry); ok && entry.value != value { + panic(fmt.Sprintf("invalid use of immutable cache tries to mutate existing value of %q", key)) + } } } - c.lru.Add(key, value) + c.lru.Add(key, &inMemoryLRUCacheEntry{ + value: value, + created: time.Now(), + }) } func (c *InMemoryLRUCachePartition) Unset(key string) { @@ -139,5 +156,20 @@ func (c *InMemoryLRUCachePartition) Unset(key string) { } func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) { - return c.lru.Get(key) + v, ok := c.lru.Get(key) + if !ok { + return nil, false + } + entry, ok := v.(*inMemoryLRUCacheEntry) + switch { + case ok && c.maxAge == CacheNoMaxAge: + return entry.value, ok // There's no maximum age policy + case ok && time.Since(entry.created) < c.maxAge: + return entry.value, ok // The value for the key isn't stale + default: + // Either the key was found and it was stale, or the key + // wasn't found at all + c.lru.Remove(key) + return nil, false + } } From 8e82739d77ea472932fd2ff7800ce744ace53bfa Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 17:01:08 +0000 Subject: [PATCH 22/24] Set max age of 5 minutes for spaces summary cache --- internal/caching/cache_space_rooms.go | 7 ++++++- internal/caching/impl_inmemorylru.go | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/internal/caching/cache_space_rooms.go b/internal/caching/cache_space_rooms.go index 71b81abcf..6d56cce5f 100644 --- a/internal/caching/cache_space_rooms.go +++ b/internal/caching/cache_space_rooms.go @@ -1,11 +1,16 @@ package caching -import "github.com/matrix-org/gomatrixserverlib" +import ( + "time" + + "github.com/matrix-org/gomatrixserverlib" +) const ( SpaceSummaryRoomsCacheName = "space_summary_rooms" SpaceSummaryRoomsCacheMaxEntries = 100 SpaceSummaryRoomsCacheMutable = true + SpaceSummaryRoomsCacheMaxAge = time.Minute * 5 ) type SpaceSummaryRoomsCache interface { diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index c9c8fd08d..94fdd1a9b 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -64,6 +64,7 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { SpaceSummaryRoomsCacheName, SpaceSummaryRoomsCacheMutable, SpaceSummaryRoomsCacheMaxEntries, + SpaceSummaryRoomsCacheMaxAge, enablePrometheus, ) if err != nil { From 23f028cf6e492f9e176b788797f2da2d274c1705 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 1 Mar 2022 17:18:06 +0000 Subject: [PATCH 23/24] Add unit test for device list update debouncing (#2220) * Add unit test for device list update debouncing * bugfix: actually return stale device lists in the test... Co-authored-by: Neil Alexander --- keyserver/internal/device_list_update_test.go | 89 ++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 164be6bec..e8d1bfe8b 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -48,13 +48,19 @@ type mockDeviceListUpdaterDatabase struct { staleUsers map[string]bool prevIDsExist func(string, []int) bool storedKeys []api.DeviceMessage + mu sync.Mutex // protect staleUsers } // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + d.mu.Lock() + defer d.mu.Unlock() var result []string - for userID := range d.staleUsers { + for userID, isStale := range d.staleUsers { + if !isStale { + continue + } _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return nil, err @@ -75,6 +81,8 @@ func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, do // MarkDeviceListStale sets the stale bit for this user to isStale. func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + d.mu.Lock() + defer d.mu.Unlock() d.staleUsers[userID] = isStale return nil } @@ -247,3 +255,82 @@ func TestUpdateNoPrevID(t *testing.T) { } } + +// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the +// update is still ongoing. +func TestDebounce(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int) bool { + return true + }, + } + ap := &mockDeviceListUpdaterAPI{} + producer := &mockKeyChangeProducer{} + fedCh := make(chan *http.Response, 1) + srv := gomatrixserverlib.ServerName("example.com") + userID := "@alice:example.com" + keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` + incomingFedReq := make(chan struct{}) + fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { + if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) { + return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) + } + close(incomingFedReq) + return <-fedCh, nil + }) + updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1) + if err := updater.Start(); err != nil { + t.Fatalf("failed to start updater: %s", err) + } + + // hit this 5 times + var wg sync.WaitGroup + wg.Add(5) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil { + t.Errorf("ManualUpdate: %s", err) + } + }() + } + + // wait until the updater hits federation + select { + case <-incomingFedReq: + case <-time.After(time.Second): + t.Fatalf("timed out waiting for updater to hit federation") + } + + // user should be marked as stale + if !db.staleUsers[userID] { + t.Errorf("user %s not marked as stale", userID) + } + // now send the response over federation + fedCh <- &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(` + { + "user_id": "` + userID + `", + "stream_id": 5, + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": ` + keyJSON + `, + "device_display_name": "Mobile Phone" + } + ] + } + `)), + } + close(fedCh) + // wait until all 5 ManualUpdates return. If we hit federation again we won't send a response + // and should panic with read on a closed channel + wg.Wait() + + // user is no longer stale now + if db.staleUsers[userID] { + t.Errorf("user %s is marked as stale", userID) + } +} From 849e40d45653d3d42bbddf2d58d6d5f488cebdba Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Mar 2022 17:25:26 +0000 Subject: [PATCH 24/24] Use correct stream provider in `Latest` for `ReceiptPosition` --- syncapi/streams/streams.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index 6b02c75ea..c71095af6 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -69,7 +69,7 @@ func (s *Streams) Latest(ctx context.Context) types.StreamingToken { return types.StreamingToken{ PDUPosition: s.PDUStreamProvider.LatestPosition(ctx), TypingPosition: s.TypingStreamProvider.LatestPosition(ctx), - ReceiptPosition: s.PDUStreamProvider.LatestPosition(ctx), + ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx), InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx),