diff --git a/CHANGES.md b/CHANGES.md index 4df8e869a..ee608194d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,39 @@ # Changelog +## Dendrite 0.6.4 (2022-02-21) + +### Features + +* All Client-Server API endpoints are now available under the `/v3` namespace +* The `/whoami` response format now matches the latest Matrix spec version +* Support added for the `/context` endpoint, which should help clients to render quote-replies correctly +* Accounts now have an optional account type field, allowing admin accounts to be created +* Server notices are now supported +* Refactored the user API storage to deduplicate a significant amount of code, as well as merging both user API databases into a single database + * The account database is now used for all user API storage and the device database is now obsolete + * For some installations that have separate account and device databases, this may result in access tokens being revoked and client sessions being logged out — users may need to log in again + * The above can be avoided by moving the `device_devices` table into the account database manually +* Guest registration can now be separately disabled with the new `client_api.guests_disabled` configuration option +* Outbound connections now obey proxy settings from the environment, deprecating the `federation_api.proxy_outbound` configuration options + +### Fixes + +* The roomserver input API will now strictly consume only one database transaction per room, which should prevent situations where the roomserver can deadlock waiting for database connections to become available +* Room joins will now fall back to federation if the local room state is insufficient to create a membership event +* Create events are now correctly filtered from federation `/send` transactions +* Excessive logging when federation is disabled should now be fixed +* Dendrite will no longer panic if trying to retire an invite event that has not been seen yet +* The device list updater will now wait for longer after a connection issue, rather than flooding the logs with errors +* The device list updater will no longer produce unnecessary output events for federated key updates with no changes, which should help to reduce CPU usage +* Local device name changes will now generate key change events correctly +* The sync API will now try to share device list update notifications even if all state key NIDs cannot be fetched +* An off-by-one error in the sync stream token handling which could result in a crash has been fixed +* State events will no longer be re-sent unnecessary by the roomserver to other components if they have already been sent, which should help to reduce the NATS message sizes on the roomserver output topic in some cases +* The roomserver input API now uses the process context and should handle graceful shutdowns better +* Guest registration is now correctly disabled when the `client_api.registration_disabled` configuration option is set +* One-time encryption keys are now cleaned up correctly when a device is logged out or removed +* Invalid state snapshots in the state storage refactoring migration are now reset rather than causing a panic at startup + ## Dendrite 0.6.3 (2022-02-10) ### Features diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index 839ca9e54..abfe830fb 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -162,7 +162,7 @@ func AuthFallback( } // Success. Add recaptcha as a completed login flow - AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) serveSuccess() return nil diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 7ecab9d4e..4426b7fdc 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -70,7 +70,7 @@ func UploadCrossSigningDeviceKeys( if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { return *authErr } - AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 499510193..acac60fa5 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -74,7 +74,7 @@ func Password( if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { return *authErr } - AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. if resErr = validatePassword(r.NewPassword); resErr != nil { diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index d00d9886e..10cfa4325 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -72,14 +72,19 @@ func init() { // sessionsDict keeps track of completed auth stages for each session. // It shouldn't be passed by value because it contains a mutex. type sessionsDict struct { - sync.Mutex + sync.RWMutex sessions map[string][]authtypes.LoginType + params map[string]registerRequest + timer map[string]*time.Timer } -// GetCompletedStages returns the completed stages for a session. -func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { - d.Lock() - defer d.Unlock() +// defaultTimeout is the timeout used to clean up sessions +const defaultTimeOut = time.Minute * 5 + +// getCompletedStages returns the completed stages for a session. +func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType { + d.RLock() + defer d.RUnlock() if completedStages, ok := d.sessions[sessionID]; ok { return completedStages @@ -88,28 +93,79 @@ func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginTyp return make([]authtypes.LoginType, 0) } -func newSessionsDict() *sessionsDict { - return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), +// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest +func (d *sessionsDict) addParams(sessionID string, r registerRequest) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + d.params[sessionID] = r +} + +func (d *sessionsDict) getParams(sessionID string) (registerRequest, bool) { + d.RLock() + defer d.RUnlock() + r, ok := d.params[sessionID] + return r, ok +} + +// deleteSession cleans up a given session, either because the registration completed +// successfully, or because a given timeout (default: 5min) was reached. +func (d *sessionsDict) deleteSession(sessionID string) { + d.Lock() + defer d.Unlock() + delete(d.params, sessionID) + delete(d.sessions, sessionID) + // stop the timer, e.g. because the registration was completed + if t, ok := d.timer[sessionID]; ok { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + delete(d.timer, sessionID) } } -// AddCompletedSessionStage records that a session has completed an auth stage. -func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) { - sessions.Lock() - defer sessions.Unlock() +func newSessionsDict() *sessionsDict { + return &sessionsDict{ + sessions: make(map[string][]authtypes.LoginType), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + } +} - for _, completedStage := range sessions.sessions[sessionID] { +func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) { + d.Lock() + defer d.Unlock() + t, ok := d.timer[sessionID] + if ok { + if !t.Stop() { + <-t.C + } + t.Reset(duration) + return + } + d.timer[sessionID] = time.AfterFunc(duration, func() { + d.deleteSession(sessionID) + }) +} + +// addCompletedSessionStage records that a session has completed an auth stage +// also starts a timer to delete the session once done. +func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + for _, completedStage := range d.sessions[sessionID] { if completedStage == stage { return } } - sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage) + d.sessions[sessionID] = append(sessions.sessions[sessionID], stage) } var ( - // TODO: Remove old sessions. Need to do so on a session-specific timeout. - // sessions stores the completed flow stages for all sessions. Referenced using their sessionID. sessions = newSessionsDict() validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) ) @@ -167,7 +223,7 @@ func newUserInteractiveResponse( params map[string]interface{}, ) userInteractiveResponse { return userInteractiveResponse{ - fs, sessions.GetCompletedStages(sessionID), params, sessionID, + fs, sessions.getCompletedStages(sessionID), params, sessionID, } } @@ -645,12 +701,12 @@ func handleRegistrationFlow( } // Add Recaptcha to the list of completed registration stages - AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) case authtypes.LoginTypeDummy: // there is nothing to do // Add Dummy to the list of completed registration stages - AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) case "": // An empty auth type means that we want to fetch the available @@ -666,7 +722,7 @@ func handleRegistrationFlow( // Check if the user's registration flow has been completed successfully // A response with current registration flow and remaining available methods // will be returned if a flow has not been successfully completed yet - return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID), + return checkAndCompleteFlow(sessions.getCompletedStages(sessionID), req, r, sessionID, cfg, userAPI) } @@ -708,7 +764,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), + req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) } @@ -727,11 +783,11 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), + req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) } - + sessions.addParams(sessionID, r) // There are still more stages to complete. // Return the flows and those that have been completed. return util.JSONResponse{ @@ -750,11 +806,25 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.UserInternalAPI, - username, password, appserviceID, ipAddr, userAgent string, + username, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, accType userapi.AccountType, ) util.JSONResponse { + var registrationOK bool + defer func() { + if registrationOK { + sessions.deleteSession(sessionID) + } + }() + + if data, ok := sessions.getParams(sessionID); ok { + username = data.Username + password = data.Password + deviceID = data.DeviceID + displayName = data.InitialDisplayName + inhibitLogin = data.InhibitLogin + } if username == "" { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -795,6 +865,7 @@ func completeRegistration( // Check whether inhibit_login option is set. If so, don't create an access // token or a device for this user if inhibitLogin { + registrationOK = true return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ @@ -828,6 +899,7 @@ func completeRegistration( } } + registrationOK = true return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ @@ -976,5 +1048,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 1f615dc26..c6b7e61cf 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -17,6 +17,7 @@ package routing import ( "regexp" "testing" + "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/setup/config" @@ -140,7 +141,7 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) { func TestEmptyCompletedFlows(t *testing.T) { fakeEmptySessions := newSessionsDict() fakeSessionID := "aRandomSessionIDWhichDoesNotExist" - ret := fakeEmptySessions.GetCompletedStages(fakeSessionID) + ret := fakeEmptySessions.getCompletedStages(fakeSessionID) // check for [] if ret == nil || len(ret) != 0 { @@ -208,3 +209,45 @@ func TestValidationOfApplicationServices(t *testing.T) { t.Errorf("user_id should not have been valid: @_something_else:localhost") } } + +func TestSessionCleanUp(t *testing.T) { + s := newSessionsDict() + + t.Run("session is cleaned up after a while", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld" + // manually added, as s.addParams() would start the timer with the default timeout + s.params[dummySession] = registerRequest{Username: "Testing"} + s.startTimer(time.Millisecond, dummySession) + time.Sleep(time.Millisecond * 2) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) + + t.Run("session is deleted, once the registration completed", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld2" + s.startTimer(time.Minute, dummySession) + s.deleteSession(dummySession) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) + + t.Run("session timer is restarted after second call", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld3" + // the following will start a timer with the default timeout of 5min + s.addParams(dummySession, registerRequest{Username: "Testing"}) + s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha) + s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy) + s.getCompletedStages(dummySession) + // reset the timer with a lower timeout + s.startTimer(time.Millisecond, dummySession) + time.Sleep(time.Millisecond * 2) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) +} \ No newline at end of file diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 088e412c6..d25ee8237 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -235,7 +235,7 @@ func OnIncomingStateTypeRequest( } // If the user has never been in the room then stop at this point. // We won't tell the user about a room they have never joined. - if !membershipRes.HasBeenInRoom { + if !membershipRes.HasBeenInRoom || membershipRes.Membership == gomatrixserverlib.Ban { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 3003896c8..307fa17e6 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -48,24 +48,27 @@ Example: # read password from stdin %s --config dendrite.yaml -username alice -passwordstdin < my.pass cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin + # reset password for a user, can be used with a combination above to read the password + %s --config dendrite.yaml -reset-password -username alice -password foobarbaz Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - askPass = flag.Bool("ask-pass", false, "Ask for the password to use") - isAdmin = flag.Bool("admin", false, "Create an admin account") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + askPass = flag.Bool("ask-pass", false, "Ask for the password to use") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username") ) func main() { name := os.Args[0] flag.Usage = func() { - _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name) + _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name) flag.PrintDefaults() } cfg := setup.ParseFlags(true) @@ -93,6 +96,19 @@ func main() { if *isAdmin { accType = api.AccountTypeAdmin } + + if *resetPassword { + err = accountDB.SetPassword(context.Background(), *username, pass) + if err != nil { + logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error()) + } + if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil { + logrus.Fatalf("Failed to remove all devices: %s", err.Error()) + } + logrus.Infof("Updated password for user %s and invalidated all logins\n", *username) + return + } + _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType) if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) diff --git a/docs/systemd/monolith-example.service b/docs/systemd/monolith-example.service index 731c6159b..237120ffb 100644 --- a/docs/systemd/monolith-example.service +++ b/docs/systemd/monolith-example.service @@ -13,6 +13,7 @@ Group=dendrite WorkingDirectory=/opt/dendrite/ ExecStart=/opt/dendrite/bin/dendrite-monolith-server Restart=always +LimitNOFILE=65535 [Install] WantedBy=multi-user.target diff --git a/federationapi/api/api.go b/federationapi/api/api.go index f5ee75b4b..4d6b0211c 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -21,7 +21,7 @@ type FederationClient interface { QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index b31db466c..b8bd5beda 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2946Spaces(ctx, s, roomID, r) + return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) }) if err != nil { return res, err diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index f9b2a33d2..01ca6595d 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -526,23 +526,23 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships( } type spacesReq struct { - S gomatrixserverlib.ServerName - Req gomatrixserverlib.MSC2946SpacesRequest - RoomID string - Res gomatrixserverlib.MSC2946SpacesResponse - Err *api.FederationClientError + S gomatrixserverlib.ServerName + SuggestedOnly bool + RoomID string + Res gomatrixserverlib.MSC2946SpacesResponse + Err *api.FederationClientError } func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, + ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces") defer span.Finish() request := spacesReq{ - S: dst, - Req: r, - RoomID: roomID, + S: dst, + SuggestedOnly: suggestedOnly, + RoomID: roomID, } var response spacesReq apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 8d193d9c9..ca4930f20 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req) + res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly) if err != nil { ferr, ok := err.(*api.FederationClientError) if ok { diff --git a/go.mod b/go.mod index 2316096df..e44e7393d 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 + github.com/google/uuid v1.2.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.3 // indirect @@ -39,8 +40,8 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed - github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf + github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 + github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 github.com/morikuni/aec v1.0.0 // indirect @@ -61,11 +62,11 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.2 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/mobile v0.0.0-20220112015953-858099ff7816 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd - golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect + golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index e79015e51..43b363d30 100644 --- a/go.sum +++ b/go.sum @@ -983,10 +983,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed h1:R8EiLWArq7KT96DrUq1xq9scPh8vLwKKeCTnORPyjhU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= -github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= -github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 h1:CqbM5tPbF1mV2h1J6N0NSF8et6HvFlwA98TLCUcjiSI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= +github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= +github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1510,8 +1510,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5 golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= -golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1737,8 +1737,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= -golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= diff --git a/internal/version.go b/internal/version.go index a07f01b61..2ea1c5201 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 6 - VersionPatch = 3 + VersionPatch = 4 VersionTag = "" // example: "rc1" ) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 3933961c1..54eb04f8a 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -15,6 +15,7 @@ package api import ( + "bytes" "context" "encoding/json" "strings" @@ -73,6 +74,26 @@ type DeviceMessage struct { DeviceChangeID int64 } +// DeviceKeysEqual returns true if the device keys updates contain the +// same display name and key JSON. This will return false if either of +// the updates is not a device keys update, or if the user ID/device ID +// differ between the two. +func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { + if m1.DeviceKeys == nil || m2.DeviceKeys == nil { + return false + } + if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { + return false + } + if m1.DisplayName != m2.DisplayName { + return false // different display names + } + if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { + return false // either is empty + } + return bytes.Equal(m1.KeyJSON, m2.KeyJSON) +} + // DeviceKeys represents a set of device keys for a single device // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload type DeviceKeys struct { diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index bfb2037f8..5124f37e6 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -166,26 +166,53 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } // We can't have a self-signing or user-signing key without a master - // key, so make sure we have one of those. - if !hasMasterKey { - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) - if err != nil { - res.Error = &api.KeyError{ - Err: "Retrieving cross-signing keys from database failed: " + err.Error(), - } - return + // key, so make sure we have one of those. We will also only actually do + // something if any of the specified keys in the request are different + // to what we've got in the database, to avoid generating key change + // notifications unnecessarily. + existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) + if err != nil { + res.Error = &api.KeyError{ + Err: "Retrieving cross-signing keys from database failed: " + err.Error(), } - - _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster] + return } // If we still can't find a master key for the user then stop the upload. // This satisfies the "Fails to upload self-signing key without master key" test. if !hasMasterKey { - res.Error = &api.KeyError{ - Err: "No master key was found", - IsMissingParam: true, + if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey { + res.Error = &api.KeyError{ + Err: "No master key was found", + IsMissingParam: true, + } + return } + } + + // Check if anything actually changed compared to what we have in the database. + changed := false + for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{ + gomatrixserverlib.CrossSigningKeyPurposeMaster, + gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, + gomatrixserverlib.CrossSigningKeyPurposeUserSigning, + } { + old, gotOld := existingKeys[purpose] + new, gotNew := toStore[purpose] + if gotOld != gotNew { + // A new key purpose has been specified that we didn't know before, + // or one has been removed. + changed = true + break + } + if !bytes.Equal(old, new) { + // One of the existing keys for a purpose we already knew about has + // changed. + changed = true + break + } + } + if !changed { return } diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index b208f0ce5..974d0196b 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -224,7 +224,7 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. }).Info("DeviceListUpdater.Update") // if we haven't missed anything update the database and notify users - if exists { + if exists || event.Deleted { k := event.Keys if event.Deleted { k = nil @@ -267,7 +267,7 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err) } - if err = emitDeviceKeyChanges(u.producer, existingKeys, keys); err != nil { + if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil { return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err) } return false, nil @@ -473,7 +473,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi if err != nil { return fmt.Errorf("failed to mark device list as fresh: %w", err) } - err = emitDeviceKeyChanges(u.producer, existingKeys, keys) + err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false) if err != nil { return fmt.Errorf("failed to emit key changes for fresh device list: %w", err) } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 0c264b718..0a8bef95d 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -648,7 +648,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } return } - err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore) + err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) if err != nil { util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) } @@ -710,7 +710,11 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } -func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage) error { +func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { + // if we only want to update the display names, we can skip the checks below + if onlyUpdateDisplayName { + return producer.ProduceKeyChanges(new) + } // find keys in new that are not in existing var keysAdded []api.DeviceMessage for _, newKey := range new { @@ -718,7 +722,7 @@ func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.Device for _, existingKey := range existing { // Do not treat the absence of keys as equal, or else we will not emit key changes // when users delete devices which never had a key to begin with as both KeyJSONs are nil. - if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) && len(existingKey.KeyJSON) > 0 { + if existingKey.DeviceKeysEqual(&newKey) { exists = true break } diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 96d6711c6..66e85f2f3 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -269,6 +269,7 @@ type QueryAuthChainResponse struct { type QuerySharedUsersRequest struct { UserID string + OtherUserIDs []string ExcludeRoomIDs []string IncludeRoomIDs []string } @@ -312,7 +313,10 @@ type QueryBulkStateContentResponse struct { } type QueryCurrentStateRequest struct { - RoomID string + RoomID string + AllowWildcards bool + // State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching + // state events with that event type. StateTuples []gomatrixserverlib.StateKeyTuple } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 012094c62..5491d36b3 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -51,12 +51,8 @@ func SendEventWithState( state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { - outliers, err := state.Events(event.RoomVersion) - if err != nil { - return err - } - - var ires []InputRoomEvent + outliers := state.Events(event.RoomVersion) + ires := make([]InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if haveEventIDs[outlier.EventID()] { continue diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 9af0bf591..0229f822f 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -20,22 +20,17 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) -type checkForAuthAndSoftFailStorage interface { - state.StateResolutionStorage - StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) - RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) -} - // CheckForSoftFail returns true if the event should be soft-failed // and false otherwise. The return error value should be checked before // the soft-fail bool. func CheckForSoftFail( ctx context.Context, - db checkForAuthAndSoftFailStorage, + db storage.Database, event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { @@ -97,7 +92,7 @@ func CheckForSoftFail( // Returns the numeric IDs for the auth events. func CheckAuthEvents( ctx context.Context, - db checkForAuthAndSoftFailStorage, + db storage.Database, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 22e4b67a0..178533ded 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,7 +19,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "sync" "time" @@ -40,19 +39,6 @@ import ( "github.com/tidwall/gjson" ) -type retryAction int -type commitAction int - -const ( - doNotRetry retryAction = iota - retryLater -) - -const ( - commitTransaction commitAction = iota - rollbackTransaction -) - var keyContentFields = map[string]string{ "m.room.join_rules": "join_rule", "m.room.history_visibility": "history_visibility", @@ -117,8 +103,7 @@ func (r *Inputer) Start() error { _ = msg.InProgress() // resets the acknowledgement wait timer defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent) - if err != nil { + if err := r.processRoomEvent(r.ProcessContext.Context(), &inputRoomEvent); err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } @@ -127,11 +112,8 @@ func (r *Inputer) Start() error { "event_id": inputRoomEvent.Event.EventID(), "type": inputRoomEvent.Event.Type(), }).Warn("Roomserver failed to process async event") - } - switch action { - case retryLater: - _ = msg.Nak() - case doNotRetry: + _ = msg.Term() + } else { _ = msg.Ack() } }) @@ -153,37 +135,6 @@ func (r *Inputer) Start() error { return err } -// processRoomEventUsingUpdater opens up a room updater and tries to -// process the event. It returns whether or not we should positively -// or negatively acknowledge the event (i.e. for NATS) and an error -// if it occurred. -func (r *Inputer) processRoomEventUsingUpdater( - ctx context.Context, - roomID string, - inputRoomEvent *api.InputRoomEvent, -) (retryAction, error) { - roomInfo, err := r.DB.RoomInfo(ctx, roomID) - if err != nil { - return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err) - } - updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) - if err != nil { - return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err) - } - action, err := r.processRoomEvent(ctx, updater, inputRoomEvent) - switch action { - case commitTransaction: - if cerr := updater.Commit(); cerr != nil { - return retryLater, fmt.Errorf("updater.Commit: %w", cerr) - } - case rollbackTransaction: - if rerr := updater.Rollback(); rerr != nil { - return retryLater, fmt.Errorf("updater.Rollback: %w", rerr) - } - } - return doNotRetry, err -} - // InputRoomEvents implements api.RoomserverInternalAPI func (r *Inputer) InputRoomEvents( ctx context.Context, @@ -230,7 +181,7 @@ func (r *Inputer) InputRoomEvents( worker.Act(nil, func() { defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - _, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent) + err := r.processRoomEvent(ctx, &inputRoomEvent) if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 4e151699e..531d6959e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -26,10 +26,10 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/hooks" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -68,15 +68,14 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, - updater *shared.RoomUpdater, input *api.InputRoomEvent, -) (commitAction, error) { +) error { select { case <-ctx.Done(): // Before we do anything, make sure the context hasn't expired for this pending task. // If it has then we'll give up straight away — it's probably a synchronous input // request and the caller has already given up, but the inbox task was still queued. - return rollbackTransaction, context.DeadlineExceeded + return context.DeadlineExceeded default: } @@ -109,7 +108,7 @@ func (r *Inputer) processRoomEvent( // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. if input.Kind == api.KindOutlier { - evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()}) + evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) if err2 == nil && len(evs) == 1 { // check hash matches if we're on early room versions where the event ID was a random string idFormat, err2 := headered.RoomVersion.EventIDFormat() @@ -118,11 +117,11 @@ func (r *Inputer) processRoomEvent( case gomatrixserverlib.EventIDFormatV1: if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) { logger.Debugf("Already processed event; ignoring") - return rollbackTransaction, nil + return nil } default: logger.Debugf("Already processed event; ignoring") - return rollbackTransaction, nil + return nil } } } @@ -131,17 +130,21 @@ func (r *Inputer) processRoomEvent( // Don't waste time processing the event if the room doesn't exist. // A room entry locally will only be created in response to a create // event. + roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID()) + if rerr != nil { + return fmt.Errorf("r.DB.RoomInfo: %w", rerr) + } isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") - if !updater.RoomExists() && !isCreateEvent { - return rollbackTransaction, fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) + if roomInfo == nil && !isCreateEvent { + return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } var missingAuth, missingPrev bool serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} if !isCreateEvent { - missingAuthIDs, missingPrevIDs, err := updater.MissingAuthPrevEvents(ctx, event) + missingAuthIDs, missingPrevIDs, err := r.DB.MissingAuthPrevEvents(ctx, event) if err != nil { - return rollbackTransaction, fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) + return fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) } missingAuth = len(missingAuthIDs) > 0 missingPrev = !input.HasState && len(missingPrevIDs) > 0 @@ -153,7 +156,7 @@ func (r *Inputer) processRoomEvent( ExcludeSelf: true, } if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { - return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) } // Sort all of the servers into a map so that we can randomise // their order. Then make sure that the input origin and the @@ -182,8 +185,8 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err) + if err := r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return fmt.Errorf("r.fetchAuthEvents: %w", err) } // Check if the event is allowed by its auth events. If it isn't then @@ -205,12 +208,12 @@ func (r *Inputer) processRoomEvent( // but weren't found. if isRejected { if event.StateKey() != nil { - return commitTransaction, fmt.Errorf( + return fmt.Errorf( "missing auth event %s for state event %s (type %q, state key %q)", authEventID, event.EventID(), event.Type(), *event.StateKey(), ) } else { - return commitTransaction, fmt.Errorf( + return fmt.Errorf( "missing auth event %s for timeline event %s (type %q)", authEventID, event.EventID(), event.Type(), ) @@ -226,7 +229,7 @@ func (r *Inputer) processRoomEvent( // Check that the event passes authentication checks based on the // current room state. var err error - softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs) + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) if err != nil { logger.WithError(err).Warn("Error authing soft-failed event") } @@ -250,7 +253,8 @@ func (r *Inputer) processRoomEvent( missingState := missingStateReq{ origin: input.Origin, inputer: r, - db: updater, + db: r.DB, + roomInfo: roomInfo, federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), @@ -290,16 +294,16 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected) + _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) if err != nil { - return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err) + return fmt.Errorf("updater.StoreEvent: %w", err) } // if storing this event results in it being redacted then do so. if !isRejected && redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, event) if rerr != nil { - return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr) + return fmt.Errorf("eventutil.RedactEvent: %w", rerr) } event = r } @@ -310,23 +314,25 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindOutlier { logger.Debug("Stored outlier") hooks.Run(hooks.KindNewEventPersisted, headered) - return commitTransaction, nil + return nil } - roomInfo, err := updater.RoomInfo(ctx, event.RoomID()) + // Request the room info again — it's possible that the room has been + // created by now if it didn't exist already. + roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID()) if err != nil { - return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err) + return fmt.Errorf("updater.RoomInfo: %w", err) } if roomInfo == nil { - return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) + return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) } if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected) + err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) if err != nil { - return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err) + return fmt.Errorf("r.calculateAndSetState: %w", err) } } @@ -337,16 +343,15 @@ func (r *Inputer) processRoomEvent( "missing_prev": missingPrev, }).Warn("Stored rejected event") if rejectionErr != nil { - return commitTransaction, types.RejectedError(rejectionErr.Error()) + return types.RejectedError(rejectionErr.Error()) } - return commitTransaction, nil + return nil } switch input.Kind { case api.KindNew: if err = r.updateLatestEvents( ctx, // context - updater, // room updater roomInfo, // room info for the room being updated stateAtEvent, // state at event (below) event, // event @@ -354,7 +359,7 @@ func (r *Inputer) processRoomEvent( input.TransactionID, // transaction ID input.HasState, // rewrites state? ); err != nil { - return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err) + return fmt.Errorf("r.updateLatestEvents: %w", err) } case api.KindOld: err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{ @@ -366,7 +371,7 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err) + return fmt.Errorf("r.WriteOutputEvents (old): %w", err) } } @@ -385,14 +390,14 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) + return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) } } // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) - return commitTransaction, nil + return nil } // fetchAuthEvents will check to see if any of the @@ -404,7 +409,6 @@ func (r *Inputer) processRoomEvent( // they are now in the database. func (r *Inputer) fetchAuthEvents( ctx context.Context, - updater *shared.RoomUpdater, logger *logrus.Entry, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, @@ -418,7 +422,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID}) + authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -495,7 +499,7 @@ nextAuthEvent: } // Finally, store the event in the database. - eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -520,14 +524,18 @@ nextAuthEvent: func (r *Inputer) calculateAndSetState( ctx context.Context, - updater *shared.RoomUpdater, input *api.InputRoomEvent, roomInfo *types.RoomInfo, stateAtEvent *types.StateAtEvent, event *gomatrixserverlib.Event, isRejected bool, ) error { - var err error + var succeeded bool + updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return fmt.Errorf("r.DB.GetRoomUpdater: %w", err) + } + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) roomState := state.NewStateResolution(updater, roomInfo) if input.HasState { @@ -536,7 +544,7 @@ func (r *Inputer) calculateAndSetState( // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err) } entries = types.DeduplicateStateEntries(entries) @@ -557,5 +565,6 @@ func (r *Inputer) calculateAndSetState( if err != nil { return fmt.Errorf("r.DB.SetState: %w", err) } + succeeded = true return nil } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index ae28ebefa..f4a52031a 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -20,6 +20,7 @@ import ( "context" "fmt" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -47,7 +48,6 @@ import ( // Can only be called once at a time func (r *Inputer) updateLatestEvents( ctx context.Context, - updater *shared.RoomUpdater, roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, event *gomatrixserverlib.Event, @@ -55,6 +55,14 @@ func (r *Inputer) updateLatestEvents( transactionID *api.TransactionID, rewritesState bool, ) (err error) { + var succeeded bool + updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return fmt.Errorf("r.DB.GetRoomUpdater: %w", err) + } + + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) + u := latestEventsUpdater{ ctx: ctx, api: r, @@ -71,6 +79,7 @@ func (r *Inputer) updateLatestEvents( return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } + succeeded = true return } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index fc3be7987..a7da9b06d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -23,9 +23,25 @@ type parsedRespState struct { StateEvents []*gomatrixserverlib.Event } +func (p *parsedRespState) Events() []*gomatrixserverlib.Event { + eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents)) + for i, event := range p.AuthEvents { + eventsByID[event.EventID()] = p.AuthEvents[i] + } + for i, event := range p.StateEvents { + eventsByID[event.EventID()] = p.StateEvents[i] + } + allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID)) + for _, event := range eventsByID { + allEvents = append(allEvents, event) + } + return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents) +} + type missingStateReq struct { origin gomatrixserverlib.ServerName - db *shared.RoomUpdater + db storage.Database + roomInfo *types.RoomInfo inputer *Inputer keys gomatrixserverlib.JSONVerifier federation fedapi.FederationInternalAPI @@ -80,7 +96,7 @@ func (t *missingStateReq) processEventWithMissingState( // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -123,11 +139,8 @@ func (t *missingStateReq) processEventWithMissingState( t.hadEventsMutex.Unlock() sendOutliers := func(resolvedState *parsedRespState) error { - outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) - if oerr != nil { - return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr) - } - var outlierRoomEvents []api.InputRoomEvent + outliers := resolvedState.Events() + outlierRoomEvents := make([]api.InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if hadEvents[outlier.EventID()] { continue @@ -139,8 +152,7 @@ func (t *missingStateReq) processEventWithMissingState( }) } for _, ire := range outlierRoomEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &ire) - if err != nil { + if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { if _, ok := err.(types.RejectedError); !ok { return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) } @@ -163,7 +175,7 @@ func (t *missingStateReq) processEventWithMissingState( stateIDs = append(stateIDs, event.EventID()) } - _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ Kind: api.KindOld, Event: backwardsExtremity.Headered(roomVersion), Origin: t.origin, @@ -182,7 +194,7 @@ func (t *missingStateReq) processEventWithMissingState( // they will automatically fast-forward based on the room state at the // extremity in the last step. for _, newEvent := range newEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -473,8 +485,10 @@ retryAllowedState: // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - - latest := t.db.LatestEvents() + latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) + if err != nil { + return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err) + } latestEvents := make([]string, len(latest)) for i, ev := range latest { latestEvents[i] = ev.EventID diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index c8bbe7705..70cc5d62c 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -621,12 +621,25 @@ func (r *Queryer) QueryPublishedRooms( func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) for _, tuple := range req.StateTuples { - ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) - if err != nil { - return err - } - if ev != nil { - res.StateEvents[tuple] = ev + if tuple.StateKey == "*" && req.AllowWildcards { + events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType) + if err != nil { + return err + } + for _, e := range events { + res.StateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = e + } + } else { + ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) + if err != nil { + return err + } + if ev != nil { + res.StateEvents[tuple] = ev + } } } return nil @@ -696,7 +709,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser } roomIDs = roomIDs[:j] - users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs) + users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs) if err != nil { return err } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index e5f69521e..187b996cd 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -814,6 +814,7 @@ func (v *StateResolution) resolveConflictsV2( // events may be duplicated across these sets but that's OK. authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted)) authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) + gotAuthEvents := make(map[string]struct{}, estimate*3) authDifference := make([]*gomatrixserverlib.Event, 0, estimate) // For each conflicted event, let's try and get the needed auth events. @@ -850,9 +851,22 @@ func (v *StateResolution) resolveConflictsV2( if err != nil { return nil, err } - authEvents = append(authEvents, authSets[key]...) + + // Only add auth events into the authEvents slice once, otherwise the + // check for the auth difference can become expensive and produce + // duplicate entries, which just waste memory and CPU time. + for _, event := range authSets[key] { + if _, ok := gotAuthEvents[event.EventID()]; !ok { + authEvents = append(authEvents, event) + gotAuthEvents[event.EventID()] = struct{}{} + } + } } + // Kill the reference to this so that the GC may pick it up, since we no + // longer need this after this point. + gotAuthEvents = nil // nolint:ineffassign + // This function helps us to work out whether an event exists in one of the // auth sets. isInAuthList := func(k string, event *gomatrixserverlib.Event) bool { @@ -866,11 +880,12 @@ func (v *StateResolution) resolveConflictsV2( // This function works out if an event exists in all of the auth sets. isInAllAuthLists := func(event *gomatrixserverlib.Event) bool { - found := true for k := range authSets { - found = found && isInAuthList(k, event) + if !isInAuthList(k, event) { + return false + } } - return found + return true } // Look through all of the auth events that we've been given and work out if diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index a9851e05b..a2b22b401 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -35,6 +35,11 @@ type Database interface { stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (types.StateSnapshotNID, error) + + MissingAuthPrevEvents( + ctx context.Context, e *gomatrixserverlib.Event, + ) (missingAuth, missingPrev []string, err error) + // Look up the state of a room at each event for a list of string event IDs. // Returns an error if there is an error talking to the database. // The length of []types.StateAtEvent is guaranteed to equal the length of eventIDs if no error is returned. @@ -141,13 +146,14 @@ type Database interface { // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) - // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. - JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms. + JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) // GetServerInRoom returns true if we think a server is in a given room or false otherwise. diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 48c2c35cd..127178743 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" @@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectJoinedUsersSetForRooms( ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, + userNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]int, error) { - roomIDarray := make([]int64, len(roomNIDs)) - for i := range roomNIDs { - roomIDarray[i] = int64(roomNIDs[i]) - } stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) - rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 810a18ef2..d4a2ee3b9 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -103,25 +103,6 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { return u.currentStateSnapshotNID } -func (u *RoomUpdater) MissingAuthPrevEvents( - ctx context.Context, e *gomatrixserverlib.Event, -) (missingAuth, missingPrev []string, err error) { - for _, authEventID := range e.AuthEventIDs() { - if nids, err := u.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 { - missingAuth = append(missingAuth, authEventID) - } - } - - for _, prevEventID := range e.PrevEventIDs() { - state, err := u.StateAtEventIDs(ctx, []string{prevEventID}) - if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { - missingPrev = append(missingPrev, prevEventID) - } - } - - return -} - // StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { @@ -146,13 +127,6 @@ func (u *RoomUpdater) SnapshotNIDFromEventID( return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID) } -func (u *RoomUpdater) StoreEvent( - ctx context.Context, event *gomatrixserverlib.Event, - authEventNIDs []types.EventNID, isRejected bool, -) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { - return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected) -} - func (u *RoomUpdater) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { @@ -212,44 +186,16 @@ func (u *RoomUpdater) EventIDs( return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) } -func (u *RoomUpdater) EventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter) -} - -func (u *RoomUpdater) UnsentEventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly) -} - func (u *RoomUpdater) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) } -func (u *RoomUpdater) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateEntry, error) { - return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs) -} - -func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) -} - func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) } -func (u *RoomUpdater) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, -) ([]types.EventNID, error) { - return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly) -} - // IsReferenced implements types.RoomRecentEventsUpdater func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e270e121c..f87782776 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -674,6 +674,29 @@ func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) } +func (d *Database) MissingAuthPrevEvents( + ctx context.Context, e *gomatrixserverlib.Event, +) (missingAuth, missingPrev []string, err error) { + authEventNIDs, err := d.EventNIDs(ctx, e.AuthEventIDs()) + if err != nil { + return nil, nil, fmt.Errorf("d.EventNIDs: %w", err) + } + for _, authEventID := range e.AuthEventIDs() { + if _, ok := authEventNIDs[authEventID]; !ok { + missingAuth = append(missingAuth, authEventID) + } + } + + for _, prevEventID := range e.PrevEventIDs() { + state, err := d.StateAtEventIDs(ctx, []string{prevEventID}) + if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { + missingPrev = append(missingPrev, prevEventID) + } + } + + return +} + func (d *Database) assignRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, @@ -956,6 +979,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s return nil, nil } +// Same as GetStateEvent but returns all matching state events with this event type. Returns no error +// if there are no events with this event type. +func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + if roomInfo == nil { + return nil, fmt.Errorf("room %s doesn't exist", roomID) + } + // e.g invited rooms + if roomInfo.IsStub { + return nil, nil + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err == sql.ErrNoRows { + // No rooms have an event of this type, otherwise we'd have an event type NID + return nil, nil + } + if err != nil { + return nil, err + } + entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + if err != nil { + return nil, err + } + var eventNIDs []types.EventNID + for _, e := range entries { + if e.EventTypeNID == eventTypeNID { + eventNIDs = append(eventNIDs, e.EventNID) + } + } + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) + if err != nil { + eventIDs = map[types.EventNID]string{} + } + // return the events requested + eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) + if err != nil { + return nil, err + } + if len(eventPairs) == 0 { + return nil, nil + } + var result []*gomatrixserverlib.HeaderedEvent + for _, pair := range eventPairs { + ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion) + if err != nil { + return nil, err + } + result = append(result, ev.Headered(roomInfo.RoomVersion)) + } + + return result, nil +} + // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { var membershipState tables.MembershipState @@ -1081,13 +1160,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu return result, nil } -// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. -func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { +// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms. +func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) { roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) if err != nil { return nil, err } - userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) + userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs) + if err != nil { + return nil, err + } + userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap)) + nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap)) + for id, nid := range userNIDsMap { + userNIDs = append(userNIDs, nid) + nidToUserID[nid] = id + } + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs) if err != nil { return nil, err } @@ -1097,10 +1186,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) stateKeyNIDs[i] = nid i++ } - nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs) - if err != nil { - return nil, err - } if len(nidToUserID) != len(userNIDToCount) { logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID)) } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 181b4b4c9..43567a94c 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -42,7 +42,8 @@ const membershipSchema = ` ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid IN ($1) AND target_nid IN ($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" @@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { - iRoomNIDs := make([]interface{}, len(roomNIDs)) - for i, v := range roomNIDs { - iRoomNIDs[i] = v +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) { + params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs)) + for _, v := range roomNIDs { + params = append(params, v) } - query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) + for _, v := range userNIDs { + params = append(params, v) + } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) var rows *sql.Rows var err error if txn != nil { - rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) + rows, err = txn.QueryContext(ctx, query, params...) } else { - rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) + rows, err = s.db.QueryContext(ctx, query, params...) } if err != nil { return nil, err diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index e3fed700b..04e3c96cc 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -127,9 +127,8 @@ type Membership interface { SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) - // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the - // counts of how many rooms they are joined. - SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. + SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 0af22c19a..29c781a88 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -654,11 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo AuthEvents: res.AuthChain, StateEvents: stateEvents, } - eventsInOrder, err := respState.Events(rc.roomVersion) - if err != nil { - util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse") - return - } + eventsInOrder := respState.Events(rc.roomVersion) // everything gets sent as an outlier because auth chain events may be disjoint from the DAG // as may the threaded events. var ires []roomserver.InputRoomEvent @@ -669,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo }) } // we've got the data by this point so use a background context - err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) + err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") } diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 3824c99a2..3bb56f4b3 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -18,17 +18,18 @@ package msc2946 import ( "context" "encoding/json" - "fmt" "net/http" + "net/url" + "sort" + "strconv" "strings" "sync" "time" + "github.com/google/uuid" "github.com/gorilla/mux" - chttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/httputil" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -39,15 +40,15 @@ import ( ) const ( - ConstCreateEventContentKey = "type" - ConstSpaceChildEventType = "m.space.child" - ConstSpaceParentEventType = "m.space.parent" + ConstCreateEventContentKey = "type" + ConstCreateEventContentValueSpace = "m.space" + ConstSpaceChildEventType = "m.space.child" + ConstSpaceParentEventType = "m.space.parent" ) -// Defaults sets the request defaults -func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) { - r.Limit = 2000 - r.MaxRoomsPerSpace = -1 +type MSC2946ClientResponse struct { + Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"` + NextBatch string `json:"next_batch,omitempty"` } // Enable this MSC @@ -55,26 +56,11 @@ func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, ) error { - db, err := NewDatabase(&base.Cfg.MSCs.Database) - if err != nil { - return fmt.Errorf("cannot enable MSC2946: %w", err) - } - hooks.Enable() - hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { - he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) - hookErr := db.StoreReference(context.Background(), he) - if hookErr != nil { - util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error( - "failed to StoreReference", - ) - } - }) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, base.Cfg.Global.ServerName)) + base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) + base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) - base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces", - httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)), - ).Methods(http.MethodPost, http.MethodOptions) - - base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI( + fedAPI := httputil.MakeExternalAPI( "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( req, time.Now(), base.Cfg.Global.ServerName, keyRing, @@ -88,105 +74,99 @@ func Enable( return util.ErrorResponse(err) } roomID := params["roomID"] - return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName) + return federatedSpacesHandler(req.Context(), fedReq, roomID, rsAPI, fsAPI, base.Cfg.Global.ServerName) }, - )).Methods(http.MethodPost, http.MethodOptions) + ) + base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) + base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) return nil } func federatedSpacesHandler( - ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database, + ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) - var r gomatrixserverlib.MSC2946SpacesRequest - Defaults(&r) - if err := json.Unmarshal(fedReq.Content(), &r); err != nil { + u, err := url.Parse(fedReq.RequestURI()) + if err != nil { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + Code: 400, + JSON: jsonerror.InvalidParam("bad request uri"), } } - w := walker{ - req: &r, - rootRoomID: roomID, - serverName: fedReq.Origin(), - thisServer: thisServer, - ctx: ctx, - db: db, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, - } - res := w.walk() - return util.JSONResponse{ - Code: 200, - JSON: res, + w := walker{ + rootRoomID: roomID, + serverName: fedReq.Origin(), + thisServer: thisServer, + ctx: ctx, + suggestedOnly: u.Query().Get("suggested_only") == "true", + limit: 1000, + // The main difference is that it does not recurse into spaces and does not support pagination. + // This is somewhat equivalent to a Client-Server request with a max_depth=1. + maxDepth: 1, + + rsAPI: rsAPI, + fsAPI: fsAPI, + // inline cache as we don't have pagination in federation mode + paginationCache: make(map[string]paginationInfo), } + return w.walk() } func spacesHandler( - db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, + rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) func(*http.Request, *userapi.Device) util.JSONResponse { + // declared outside the returned handler so it persists between calls + // TODO: clear based on... time? + paginationCache := make(map[string]paginationInfo) + return func(req *http.Request, device *userapi.Device) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) // Extract the room ID from the request. Sanity check request data. params, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } roomID := params["roomID"] - var r gomatrixserverlib.MSC2946SpacesRequest - Defaults(&r) - if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil { - return *resErr - } w := walker{ - req: &r, - rootRoomID: roomID, - caller: device, - thisServer: thisServer, - ctx: req.Context(), + suggestedOnly: req.URL.Query().Get("suggested_only") == "true", + limit: parseInt(req.URL.Query().Get("limit"), 1000), + maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1), + paginationToken: req.URL.Query().Get("from"), + rootRoomID: roomID, + caller: device, + thisServer: thisServer, + ctx: req.Context(), - db: db, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, - } - res := w.walk() - return util.JSONResponse{ - Code: 200, - JSON: res, + rsAPI: rsAPI, + fsAPI: fsAPI, + paginationCache: paginationCache, } + return w.walk() } } +type paginationInfo struct { + processed set + unvisited []roomVisit +} + type walker struct { - req *gomatrixserverlib.MSC2946SpacesRequest - rootRoomID string - caller *userapi.Device - serverName gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName - db Database - rsAPI roomserver.RoomserverInternalAPI - fsAPI fs.FederationInternalAPI - ctx context.Context + rootRoomID string + caller *userapi.Device + serverName gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName + rsAPI roomserver.RoomserverInternalAPI + fsAPI fs.FederationInternalAPI + ctx context.Context + suggestedOnly bool + limit int + maxDepth int + paginationToken string - // user ID|device ID|batch_num => event/room IDs sent to client - inMemoryBatchCache map[string]set - mu sync.Mutex -} - -func (w *walker) roomIsExcluded(roomID string) bool { - for _, exclRoom := range w.req.ExcludeRooms { - if exclRoom == roomID { - return true - } - } - return false + paginationCache map[string]paginationInfo + mu sync.Mutex } func (w *walker) callerID() string { @@ -196,144 +176,207 @@ func (w *walker) callerID() string { return string(w.serverName) } -func (w *walker) alreadySent(id string) bool { - w.mu.Lock() - defer w.mu.Unlock() - m, ok := w.inMemoryBatchCache[w.callerID()] - if !ok { - return false +func (w *walker) newPaginationCache() (string, paginationInfo) { + p := paginationInfo{ + processed: make(set), + unvisited: nil, } - return m[id] + tok := uuid.NewString() + return tok, p } -func (w *walker) markSent(id string) { +func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo { w.mu.Lock() defer w.mu.Unlock() - m := w.inMemoryBatchCache[w.callerID()] - if m == nil { - m = make(set) - } - m[id] = true - w.inMemoryBatchCache[w.callerID()] = m + p := w.paginationCache[paginationToken] + return &p } -func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse { - var res gomatrixserverlib.MSC2946SpacesResponse - // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms - unvisited := []string{w.rootRoomID} - processed := make(set) +func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) { + w.mu.Lock() + defer w.mu.Unlock() + w.paginationCache[paginationToken] = cache +} + +type roomVisit struct { + roomID string + depth int + vias []string // vias to query this room by +} + +func (w *walker) walk() util.JSONResponse { + if !w.authorised(w.rootRoomID) { + if w.caller != nil { + // CS API format + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("room is unknown/forbidden"), + } + } else { + // SS API format + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("room is unknown/forbidden"), + } + } + } + + var discoveredRooms []gomatrixserverlib.MSC2946Room + + var cache *paginationInfo + if w.paginationToken != "" { + cache = w.loadPaginationCache(w.paginationToken) + if cache == nil { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("invalid from"), + } + } + } else { + tok, c := w.newPaginationCache() + cache = &c + w.paginationToken = tok + // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms + c.unvisited = append(c.unvisited, roomVisit{ + roomID: w.rootRoomID, + depth: 0, + }) + } + + processed := cache.processed + unvisited := cache.unvisited + + // Depth first -> stack data structure for len(unvisited) > 0 { - roomID := unvisited[0] - unvisited = unvisited[1:] - // If this room has already been processed, skip. NB: do not remember this between calls - if processed[roomID] || roomID == "" { + if len(discoveredRooms) >= w.limit { + break + } + + // pop the stack + rv := unvisited[len(unvisited)-1] + unvisited = unvisited[:len(unvisited)-1] + // If this room has already been processed, skip. + // If this room exceeds the specified depth, skip. + if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) { continue } + // Mark this room as processed. - processed[roomID] = true + processed.set(rv.roomID) + + // if this room is not a space room, skip. + var roomType string + create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "") + if create != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + } // Collect rooms/events to send back (either locally or fetched via federation) - var discoveredRooms []gomatrixserverlib.MSC2946Room - var discoveredEvents []gomatrixserverlib.MSC2946StrippedEvent + var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent // If we know about this room and the caller is authorised (joined/world_readable) then pull // events locally - if w.roomExists(roomID) && w.authorised(roomID) { - // Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get - // all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`) - // this room. This requires servers to store reverse lookups. - events, err := w.references(roomID) + if w.roomExists(rv.roomID) && w.authorised(rv.roomID) { + // Get all `m.space.child` state events for this room + events, err := w.childReferences(rv.roomID) if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room") + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room") continue } - discoveredEvents = events + discoveredChildEvents = events - pubRoom := w.publicRoomsChunk(roomID) - roomType := "" - create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "") - if create != nil { - // escape the `.`s so gjson doesn't think it's nested - roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str - } + pubRoom := w.publicRoomsChunk(rv.roomID) - // Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`. discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ - PublicRoom: *pubRoom, - NumRefs: len(discoveredEvents), - RoomType: roomType, + PublicRoom: *pubRoom, + RoomType: roomType, + ChildrenState: events, }) } else { // attempt to query this room over federation, as either we've never heard of it before // or we've left it and hence are not authorised (but info may be exposed regardless) - fedRes, err := w.federatedRoomInfo(roomID) + fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias) if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces") + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Errorf("failed to query federated spaces") continue } if fedRes != nil { - discoveredRooms = fedRes.Rooms - discoveredEvents = fedRes.Events + discoveredChildEvents = fedRes.Room.ChildrenState + discoveredRooms = append(discoveredRooms, fedRes.Room) + if len(fedRes.Children) > 0 { + discoveredRooms = append(discoveredRooms, fedRes.Children...) + } + // mark this room as a space room as the federated server responded. + // we need to do this so we add the children of this room to the unvisited stack + // as these children may be rooms we do know about. + roomType = ConstCreateEventContentValueSpace } } - // If this room has not ever been in `rooms` (across multiple requests), send it now - for _, room := range discoveredRooms { - if !w.alreadySent(room.RoomID) && !w.roomIsExcluded(room.RoomID) { - res.Rooms = append(res.Rooms, room) - w.markSent(room.RoomID) - } + // don't walk the children + // if the parent is not a space room + if roomType != ConstCreateEventContentValueSpace { + continue } - uniqueRooms := make(set) - - // If this is the root room from the original request, insert all these events into `events` if - // they haven't been added before (across multiple requests). - if w.rootRoomID == roomID { - for _, ev := range discoveredEvents { - if !w.alreadySent(eventKey(&ev)) { - res.Events = append(res.Events, ev) - uniqueRooms[ev.RoomID] = true - uniqueRooms[spaceTargetStripped(&ev)] = true - w.markSent(eventKey(&ev)) - } - } - } else { - // Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either - // are exceeded, stop adding events. If the event has already been added, do not add it again. - numAdded := 0 - for _, ev := range discoveredEvents { - if w.req.Limit > 0 && len(res.Events) >= w.req.Limit { - break - } - if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace { - break - } - if w.alreadySent(eventKey(&ev)) { - continue - } - // Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still - // want to catch arrows which point to excluded rooms. - if w.roomIsExcluded(ev.RoomID) { - continue - } - res.Events = append(res.Events, ev) - uniqueRooms[ev.RoomID] = true - uniqueRooms[spaceTargetStripped(&ev)] = true - w.markSent(eventKey(&ev)) - // we don't distinguish between child state events and parent state events for the purposes of - // max_rooms_per_space, maybe we should? - numAdded++ - } - } - - // For each referenced room ID in the events being returned to the caller (both parent and child) + // For each referenced room ID in the child events being returned to the caller // add the room ID to the queue of unvisited rooms. Loop from the beginning. - for roomID := range uniqueRooms { - unvisited = append(unvisited, roomID) + // We need to invert the order here because the child events are lo->hi on the timestamp, + // so we need to ensure we pop in the same lo->hi order, which won't be the case if we + // insert the highest timestamp last in a stack. + for i := len(discoveredChildEvents) - 1; i >= 0; i-- { + spaceContent := struct { + Via []string `json:"via"` + }{} + ev := discoveredChildEvents[i] + _ = json.Unmarshal(ev.Content, &spaceContent) + unvisited = append(unvisited, roomVisit{ + roomID: ev.StateKey, + depth: rv.depth + 1, + vias: spaceContent.Via, + }) } } - return &res + + if len(unvisited) > 0 { + // we still have more rooms so we need to send back a pagination token, + // we probably hit a room limit + cache.processed = processed + cache.unvisited = unvisited + w.storePaginationCache(w.paginationToken, *cache) + } else { + // clear the pagination token so we don't send it back to the client + // Note we do NOT nuke the cache just in case this response is lost + // and the client retries it. + w.paginationToken = "" + } + + if w.caller != nil { + // return CS API format + return util.JSONResponse{ + Code: 200, + JSON: MSC2946ClientResponse{ + Rooms: discoveredRooms, + NextBatch: w.paginationToken, + }, + } + } + // return SS API format + // the first discovered room will be the room asked for, and subsequent ones the depth=1 children + if len(discoveredRooms) == 0 { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("room is unknown/forbidden"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: gomatrixserverlib.MSC2946SpacesResponse{ + Room: discoveredRooms[0], + Children: discoveredRooms[1:], + }, + } } func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent { @@ -366,42 +409,19 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom { // federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was // unsuccessful. -func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { +func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { // only do federated requests for client requests if w.caller == nil { return nil, nil } - // extract events which point to this room ID and extract their vias - events, err := w.db.References(w.ctx, roomID) - if err != nil { - return nil, fmt.Errorf("failed to get References events: %w", err) - } - vias := make(set) - for _, ev := range events { - if ev.StateKeyEquals(roomID) { - // event points at this room, extract vias - content := struct { - Vias []string `json:"via"` - }{} - if err = json.Unmarshal(ev.Content(), &content); err != nil { - continue // silently ignore corrupted state events - } - for _, v := range content.Vias { - vias[v] = true - } - } - } - util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias) + util.GetLogger(w.ctx).Infof("Querying %s via %+v", roomID, vias) ctx := context.Background() // query more of the spaces graph using these servers - for serverName := range vias { + for _, serverName := range vias { if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{ - Limit: w.req.Limit, - MaxRoomsPerSpace: w.req.MaxRoomsPerSpace, - }) + res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue @@ -501,7 +521,7 @@ func (w *walker) authorisedUser(roomID string) bool { hisVisEv := queryRes.StateEvents[hisVisTuple] if memberEv != nil { membership, _ := memberEv.Membership() - if membership == gomatrixserverlib.Join { + if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { return true } } @@ -514,40 +534,85 @@ func (w *walker) authorisedUser(roomID string) bool { return false } -// references returns all references pointing to or from this room. -func (w *walker) references(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { - events, err := w.db.References(w.ctx, roomID) +// references returns all child references pointing to or from this room. +func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { + createTuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomCreate, + StateKey: "", + } + var res roomserver.QueryCurrentStateResponse + err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{ + RoomID: roomID, + AllowWildcards: true, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + createTuple, { + EventType: ConstSpaceChildEventType, + StateKey: "*", + }, + }, + }, &res) if err != nil { return nil, err } - el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events)) - for _, ev := range events { + + // don't return any child refs if the room is not a space room + if res.StateEvents[createTuple] != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + if roomType != ConstCreateEventContentValueSpace { + return nil, nil + } + } + delete(res.StateEvents, createTuple) + + el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents)) + for _, ev := range res.StateEvents { + content := gjson.ParseBytes(ev.Content()) // only return events that have a `via` key as per MSC1772 // else we'll incorrectly walk redacted events (as the link // is in the state_key) - if gjson.GetBytes(ev.Content(), "via").Exists() { + if content.Get("via").Exists() { strip := stripped(ev.Event) if strip == nil { continue } + // if suggested only and this child isn't suggested, skip it. + // if suggested only = false we include everything so don't need to check the content. + if w.suggestedOnly && !content.Get("suggested").Bool() { + continue + } el = append(el, *strip) } } + // sort by origin_server_ts as per MSC2946 + sort.Slice(el, func(i, j int) bool { + return el[i].OriginServerTS < el[j].OriginServerTS + }) + return el, nil } -type set map[string]bool +type set map[string]struct{} + +func (s set) set(val string) { + s[val] = struct{}{} +} +func (s set) isSet(val string) bool { + _, ok := s[val] + return ok +} func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent { if ev.StateKey() == nil { return nil } return &gomatrixserverlib.MSC2946StrippedEvent{ - Type: ev.Type(), - StateKey: *ev.StateKey(), - Content: ev.Content(), - Sender: ev.Sender(), - RoomID: ev.RoomID(), + Type: ev.Type(), + StateKey: *ev.StateKey(), + Content: ev.Content(), + Sender: ev.Sender(), + RoomID: ev.RoomID(), + OriginServerTS: ev.OriginServerTS(), } } @@ -567,3 +632,11 @@ func spaceTargetStripped(event *gomatrixserverlib.MSC2946StrippedEvent) string { } return "" } + +func parseInt(intstr string, defaultVal int) int { + i, err := strconv.ParseInt(intstr, 10, 32) + if err != nil { + return defaultVal + } + return int(i) +} diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go deleted file mode 100644 index e8066c34d..000000000 --- a/setup/mscs/msc2946/msc2946_test.go +++ /dev/null @@ -1,464 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msc2946_test - -import ( - "bytes" - "context" - "crypto/ed25519" - "encoding/json" - "io/ioutil" - "net/http" - "net/url" - "testing" - "time" - - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal/hooks" - "github.com/matrix-org/dendrite/internal/httputil" - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/mscs/msc2946" - userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" -) - -var ( - client = &http.Client{ - Timeout: 10 * time.Second, - } - roomVer = gomatrixserverlib.RoomVersionV6 -) - -// Basic sanity check of MSC2946 logic. Tests a single room with a few state events -// and a bit of recursion to subspaces. Makes a graph like: -// Root -// ____|_____ -// | | | -// R1 R2 S1 -// |_________ -// | | | -// R3 R4 S2 -// | <-- this link is just a parent, not a child -// R5 -// -// Alice is not joined to R4, but R4 is "world_readable". -func TestMSC2946(t *testing.T) { - alice := "@alice:localhost" - // give access token to alice - nopUserAPI := &testUserAPI{ - accessTokens: make(map[string]userapi.Device), - } - nopUserAPI.accessTokens["alice"] = userapi.Device{ - AccessToken: "alice", - DisplayName: "Alice", - UserID: alice, - } - rootSpace := "!rootspace:localhost" - subSpaceS1 := "!subspaceS1:localhost" - subSpaceS2 := "!subspaceS2:localhost" - room1 := "!room1:localhost" - room2 := "!room2:localhost" - room3 := "!room3:localhost" - room4 := "!room4:localhost" - empty := "" - room5 := "!room5:localhost" - allRooms := []string{ - rootSpace, subSpaceS1, subSpaceS2, - room1, room2, room3, room4, room5, - } - rootToR1 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room1, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - rootToR2 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - rootToS1 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &subSpaceS1, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToR3 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room3, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToR4 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room4, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToS2 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &subSpaceS2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - // This is a parent link only - s2ToR5 := mustCreateEvent(t, fledglingEvent{ - RoomID: room5, - Sender: alice, - Type: msc2946.ConstSpaceParentEventType, - StateKey: &subSpaceS2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - // history visibility for R4 - r4HisVis := mustCreateEvent(t, fledglingEvent{ - RoomID: room4, - Sender: "@someone:localhost", - Type: gomatrixserverlib.MRoomHistoryVisibility, - StateKey: &empty, - Content: map[string]interface{}{ - "history_visibility": "world_readable", - }, - }) - var joinEvents []*gomatrixserverlib.HeaderedEvent - for _, roomID := range allRooms { - if roomID == room4 { - continue // not joined to that room - } - joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{ - RoomID: roomID, - Sender: alice, - StateKey: &alice, - Type: gomatrixserverlib.MRoomMember, - Content: map[string]interface{}{ - "membership": "join", - }, - })) - } - roomNameTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.name", - StateKey: "", - } - hisVisTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.history_visibility", - StateKey: "", - } - nopRsAPI := &testRoomserverAPI{ - joinEvents: joinEvents, - events: map[string]*gomatrixserverlib.HeaderedEvent{ - rootToR1.EventID(): rootToR1, - rootToR2.EventID(): rootToR2, - rootToS1.EventID(): rootToS1, - s1ToR3.EventID(): s1ToR3, - s1ToR4.EventID(): s1ToR4, - s1ToS2.EventID(): s1ToS2, - s2ToR5.EventID(): s2ToR5, - r4HisVis.EventID(): r4HisVis, - }, - pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{ - rootSpace: { - roomNameTuple: "Root", - hisVisTuple: "shared", - }, - subSpaceS1: { - roomNameTuple: "Sub-Space 1", - hisVisTuple: "joined", - }, - subSpaceS2: { - roomNameTuple: "Sub-Space 2", - hisVisTuple: "shared", - }, - room1: { - hisVisTuple: "joined", - }, - room2: { - hisVisTuple: "joined", - }, - room3: { - hisVisTuple: "joined", - }, - room4: { - hisVisTuple: "world_readable", - }, - room5: { - hisVisTuple: "joined", - }, - }, - } - allEvents := []*gomatrixserverlib.HeaderedEvent{ - rootToR1, rootToR2, rootToS1, - s1ToR3, s1ToR4, s1ToS2, - s2ToR5, r4HisVis, - } - allEvents = append(allEvents, joinEvents...) - router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents) - cancel := runServer(t, router) - defer cancel() - - t.Run("returns no events for unknown rooms", func(t *testing.T) { - res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{})) - if len(res.Events) > 0 { - t.Errorf("got %d events, want 0", len(res.Events)) - } - if len(res.Rooms) > 0 { - t.Errorf("got %d rooms, want 0", len(res.Rooms)) - } - }) - t.Run("returns the entire graph", func(t *testing.T) { - res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) - if len(res.Events) != 7 { - t.Errorf("got %d events, want 7", len(res.Events)) - } - if len(res.Rooms) != len(allRooms) { - t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)) - } - }) - t.Run("can update the graph", func(t *testing.T) { - // remove R3 from the graph - rmS1ToR3 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room3, - Content: map[string]interface{}{}, // redacted - }) - nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3 - hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3) - - res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) - if len(res.Events) != 6 { // one less since we don't return redacted events - t.Errorf("got %d events, want 6", len(res.Events)) - } - if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3 - t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1) - } - }) -} - -func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest { - t.Helper() - b, err := json.Marshal(jsonBody) - if err != nil { - t.Fatalf("Failed to marshal request: %s", err) - } - var r gomatrixserverlib.MSC2946SpacesRequest - if err := json.Unmarshal(b, &r); err != nil { - t.Fatalf("Failed to unmarshal request: %s", err) - } - return &r -} - -func runServer(t *testing.T, router *mux.Router) func() { - t.Helper() - externalServ := &http.Server{ - Addr: string(":8010"), - WriteTimeout: 60 * time.Second, - Handler: router, - } - go func() { - externalServ.ListenAndServe() - }() - // wait to listen on the port - time.Sleep(500 * time.Millisecond) - return func() { - externalServ.Shutdown(context.TODO()) - } -} - -func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse { - t.Helper() - var r gomatrixserverlib.MSC2946SpacesRequest - msc2946.Defaults(&r) - data, err := json.Marshal(req) - if err != nil { - t.Fatalf("failed to marshal request: %s", err) - } - httpReq, err := http.NewRequest( - "POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces", - bytes.NewBuffer(data), - ) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - if err != nil { - t.Fatalf("failed to prepare request: %s", err) - } - res, err := client.Do(httpReq) - if err != nil { - t.Fatalf("failed to do request: %s", err) - } - if res.StatusCode != expectCode { - body, _ := ioutil.ReadAll(res.Body) - t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body)) - } - if res.StatusCode == 200 { - var result gomatrixserverlib.MSC2946SpacesResponse - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("response 200 OK but failed to read response body: %s", err) - } - t.Logf("Body: %s", string(body)) - if err := json.Unmarshal(body, &result); err != nil { - t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body)) - } - return &result - } - return nil -} - -type testUserAPI struct { - userapi.UserInternalAPITrace - accessTokens map[string]userapi.Device -} - -func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { - dev, ok := u.accessTokens[req.AccessToken] - if !ok { - res.Err = "unknown token" - return nil - } - res.Device = &dev - return nil -} - -type testRoomserverAPI struct { - // use a trace API as it implements method stubs so we don't need to have them here. - // We'll override the functions we care about. - roomserver.RoomserverInternalAPITrace - joinEvents []*gomatrixserverlib.HeaderedEvent - events map[string]*gomatrixserverlib.HeaderedEvent - pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string -} - -func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error { - res.IsInRoom = true - res.RoomExists = true - return nil -} - -func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error { - res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) - for _, roomID := range req.RoomIDs { - pubRoomData, ok := r.pubRoomState[roomID] - if ok { - res.Rooms[roomID] = pubRoomData - } - } - return nil -} - -func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error { - res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) - checkEvent := func(he *gomatrixserverlib.HeaderedEvent) { - if he.RoomID() != req.RoomID { - return - } - if he.StateKey() == nil { - return - } - tuple := gomatrixserverlib.StateKeyTuple{ - EventType: he.Type(), - StateKey: *he.StateKey(), - } - for _, t := range req.StateTuples { - if t == tuple { - res.StateEvents[t] = he - } - } - } - for _, he := range r.joinEvents { - checkEvent(he) - } - for _, he := range r.events { - checkEvent(he) - } - return nil -} - -func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router { - t.Helper() - cfg := &config.Dendrite{} - cfg.Defaults(true) - cfg.Global.ServerName = "localhost" - cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db" - cfg.MSCs.MSCs = []string{"msc2946"} - base := &base.BaseDendrite{ - Cfg: cfg, - PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(), - PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(), - } - - err := msc2946.Enable(base, rsAPI, userAPI, nil, nil) - if err != nil { - t.Fatalf("failed to enable MSC2946: %s", err) - } - for _, ev := range events { - hooks.Run(hooks.KindNewEventPersisted, ev) - } - return base.PublicClientAPIMux -} - -type fledglingEvent struct { - Type string - StateKey *string - Content interface{} - Sender string - RoomID string -} - -func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { - t.Helper() - seed := make([]byte, ed25519.SeedSize) // zero seed - key := ed25519.NewKeyFromSeed(seed) - eb := gomatrixserverlib.EventBuilder{ - Sender: ev.Sender, - Depth: 999, - Type: ev.Type, - StateKey: ev.StateKey, - RoomID: ev.RoomID, - } - err := eb.SetContent(ev.Content) - if err != nil { - t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) - } - // make sure the origin_server_ts changes so we can test recency - time.Sleep(1 * time.Millisecond) - signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer) - if err != nil { - t.Fatalf("mustCreateEvent: failed to sign event: %s", err) - } - h := signedEvent.Headered(roomVer) - return h -} diff --git a/setup/mscs/msc2946/storage.go b/setup/mscs/msc2946/storage.go deleted file mode 100644 index 20db18594..000000000 --- a/setup/mscs/msc2946/storage.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msc2946 - -import ( - "context" - "database/sql" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" -) - -var ( - relTypes = map[string]int{ - ConstSpaceChildEventType: 1, - ConstSpaceParentEventType: 2, - } -) - -type Database interface { - // StoreReference persists a child or parent space mapping. - StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error - // References returns all events which have the given roomID as a parent or child space. - References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) -} - -type DB struct { - db *sql.DB - writer sqlutil.Writer - insertEdgeStmt *sql.Stmt - selectEdgesStmt *sql.Stmt -} - -// NewDatabase loads the database for msc2836 -func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - if dbOpts.ConnectionString.IsPostgres() { - return newPostgresDatabase(dbOpts) - } - return newSQLiteDatabase(dbOpts) -} - -func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewDummyWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { - return nil, err - } - _, err = d.db.Exec(` - CREATE TABLE IF NOT EXISTS msc2946_edges ( - room_version TEXT NOT NULL, - -- the room ID of the event, the source of the arrow - source_room_id TEXT NOT NULL, - -- the target room ID, the arrow destination - dest_room_id TEXT NOT NULL, - -- the kind of relation, either child or parent (1,2) - rel_type SMALLINT NOT NULL, - event_json TEXT NOT NULL, - CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type) - ); - `) - if err != nil { - return nil, err - } - if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5 - `); err != nil { - return nil, err - } - if d.selectEdgesStmt, err = d.db.Prepare(` - SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 - `); err != nil { - return nil, err - } - return &d, err -} - -func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewExclusiveWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { - return nil, err - } - _, err = d.db.Exec(` - CREATE TABLE IF NOT EXISTS msc2946_edges ( - room_version TEXT NOT NULL, - -- the room ID of the event, the source of the arrow - source_room_id TEXT NOT NULL, - -- the target room ID, the arrow destination - dest_room_id TEXT NOT NULL, - -- the kind of relation, either child or parent (1,2) - rel_type SMALLINT NOT NULL, - event_json TEXT NOT NULL, - UNIQUE (source_room_id, dest_room_id, rel_type) - ); - `) - if err != nil { - return nil, err - } - if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5 - `); err != nil { - return nil, err - } - if d.selectEdgesStmt, err = d.db.Prepare(` - SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 - `); err != nil { - return nil, err - } - return &d, err -} - -func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error { - target := SpaceTarget(he) - if target == "" { - return nil // malformed event - } - relType := relTypes[he.Type()] - _, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON()) - return err -} - -func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) { - rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "failed to close References") - refs := make([]*gomatrixserverlib.HeaderedEvent, 0) - for rows.Next() { - var roomVer string - var jsonBytes []byte - if err := rows.Scan(&roomVer, &jsonBytes); err != nil { - return nil, err - } - ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer)) - if err != nil { - return nil, err - } - he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer)) - refs = append(refs, he) - } - return refs, nil -} - -// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent -// depending on the event type. -func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string { - if he.StateKey() == nil { - return "" // no-op - } - switch he.Type() { - case ConstSpaceParentEventType: - return *he.StateKey() - case ConstSpaceChildEventType: - return *he.StateKey() - } - return "" -} diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 37a9e2d39..dc4acd8da 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -82,7 +82,16 @@ func DeviceListCatchup( util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") return to, hasNew, nil } - // QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user. + + // Work out which user IDs we care about — that includes those in the original request, + // the response from QueryKeyChanges (which includes ALL users who have changed keys) + // as well as every user who has a join or leave event in the current sync response. We + // will request information about which rooms these users are joined to, so that we can + // see if we still share any rooms with them. + joinUserIDs, leaveUserIDs := membershipEvents(res) + queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...) + queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...) + queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs) var sharedUsersMap map[string]int sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) util.GetLogger(ctx).Debugf( @@ -100,9 +109,8 @@ func DeviceListCatchup( userSet[userID] = true } } - // if the response has any join/leave events, add them now. + // Finally, add in users who have joined or left. // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. - joinUserIDs, leaveUserIDs := membershipEvents(res) for _, userID := range joinUserIDs { if !userSet[userID] { res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) @@ -213,7 +221,8 @@ func filterSharedUsers( var result []string var sharedUsersRes roomserverAPI.QuerySharedUsersResponse err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ - UserID: userID, + UserID: userID, + OtherUserIDs: usersWithChangedKeys, }, &sharedUsersRes) if err != nil { // default to all users so we do needless queries rather than miss some important device update diff --git a/sytest-blacklist b/sytest-blacklist index 16abce8da..e8617dcdf 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -24,6 +24,7 @@ Local device key changes get to remote servers with correct prev_id # Flakey Local device key changes appear in /keys/changes +/context/ with lazy_load_members filter works # we don't support groups Remove group category diff --git a/sytest-whitelist b/sytest-whitelist index 187a0f475..12522cfb3 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -596,4 +596,12 @@ Device list doesn't change if remote server is down /context/ on joined room works /context/ on non world readable room does not work /context/ returns correct number of events -/context/ with lazy_load_members filter works \ No newline at end of file +/context/ with lazy_load_members filter works +Can query remote device keys using POST after notification +Device deletion propagates over federation +Get left notifs in sync and /keys/changes when other user leaves +Remote banned user is kicked and may not rejoin until unbanned +registration remembers parameters +registration accepts non-ascii passwords +registration with inhibit_login inhibits login +