diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index 839ca9e54..abfe830fb 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -162,7 +162,7 @@ func AuthFallback( } // Success. Add recaptcha as a completed login flow - AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) serveSuccess() return nil diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 7ecab9d4e..4426b7fdc 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -70,7 +70,7 @@ func UploadCrossSigningDeviceKeys( if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { return *authErr } - AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 499510193..acac60fa5 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -74,7 +74,7 @@ func Password( if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { return *authErr } - AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. if resErr = validatePassword(r.NewPassword); resErr != nil { diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index d00d9886e..10cfa4325 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -72,14 +72,19 @@ func init() { // sessionsDict keeps track of completed auth stages for each session. // It shouldn't be passed by value because it contains a mutex. type sessionsDict struct { - sync.Mutex + sync.RWMutex sessions map[string][]authtypes.LoginType + params map[string]registerRequest + timer map[string]*time.Timer } -// GetCompletedStages returns the completed stages for a session. -func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { - d.Lock() - defer d.Unlock() +// defaultTimeout is the timeout used to clean up sessions +const defaultTimeOut = time.Minute * 5 + +// getCompletedStages returns the completed stages for a session. +func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType { + d.RLock() + defer d.RUnlock() if completedStages, ok := d.sessions[sessionID]; ok { return completedStages @@ -88,28 +93,79 @@ func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginTyp return make([]authtypes.LoginType, 0) } -func newSessionsDict() *sessionsDict { - return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), +// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest +func (d *sessionsDict) addParams(sessionID string, r registerRequest) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + d.params[sessionID] = r +} + +func (d *sessionsDict) getParams(sessionID string) (registerRequest, bool) { + d.RLock() + defer d.RUnlock() + r, ok := d.params[sessionID] + return r, ok +} + +// deleteSession cleans up a given session, either because the registration completed +// successfully, or because a given timeout (default: 5min) was reached. +func (d *sessionsDict) deleteSession(sessionID string) { + d.Lock() + defer d.Unlock() + delete(d.params, sessionID) + delete(d.sessions, sessionID) + // stop the timer, e.g. because the registration was completed + if t, ok := d.timer[sessionID]; ok { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + delete(d.timer, sessionID) } } -// AddCompletedSessionStage records that a session has completed an auth stage. -func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) { - sessions.Lock() - defer sessions.Unlock() +func newSessionsDict() *sessionsDict { + return &sessionsDict{ + sessions: make(map[string][]authtypes.LoginType), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + } +} - for _, completedStage := range sessions.sessions[sessionID] { +func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) { + d.Lock() + defer d.Unlock() + t, ok := d.timer[sessionID] + if ok { + if !t.Stop() { + <-t.C + } + t.Reset(duration) + return + } + d.timer[sessionID] = time.AfterFunc(duration, func() { + d.deleteSession(sessionID) + }) +} + +// addCompletedSessionStage records that a session has completed an auth stage +// also starts a timer to delete the session once done. +func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + for _, completedStage := range d.sessions[sessionID] { if completedStage == stage { return } } - sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage) + d.sessions[sessionID] = append(sessions.sessions[sessionID], stage) } var ( - // TODO: Remove old sessions. Need to do so on a session-specific timeout. - // sessions stores the completed flow stages for all sessions. Referenced using their sessionID. sessions = newSessionsDict() validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) ) @@ -167,7 +223,7 @@ func newUserInteractiveResponse( params map[string]interface{}, ) userInteractiveResponse { return userInteractiveResponse{ - fs, sessions.GetCompletedStages(sessionID), params, sessionID, + fs, sessions.getCompletedStages(sessionID), params, sessionID, } } @@ -645,12 +701,12 @@ func handleRegistrationFlow( } // Add Recaptcha to the list of completed registration stages - AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) case authtypes.LoginTypeDummy: // there is nothing to do // Add Dummy to the list of completed registration stages - AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) case "": // An empty auth type means that we want to fetch the available @@ -666,7 +722,7 @@ func handleRegistrationFlow( // Check if the user's registration flow has been completed successfully // A response with current registration flow and remaining available methods // will be returned if a flow has not been successfully completed yet - return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID), + return checkAndCompleteFlow(sessions.getCompletedStages(sessionID), req, r, sessionID, cfg, userAPI) } @@ -708,7 +764,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), + req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) } @@ -727,11 +783,11 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), + req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) } - + sessions.addParams(sessionID, r) // There are still more stages to complete. // Return the flows and those that have been completed. return util.JSONResponse{ @@ -750,11 +806,25 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.UserInternalAPI, - username, password, appserviceID, ipAddr, userAgent string, + username, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, accType userapi.AccountType, ) util.JSONResponse { + var registrationOK bool + defer func() { + if registrationOK { + sessions.deleteSession(sessionID) + } + }() + + if data, ok := sessions.getParams(sessionID); ok { + username = data.Username + password = data.Password + deviceID = data.DeviceID + displayName = data.InitialDisplayName + inhibitLogin = data.InhibitLogin + } if username == "" { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -795,6 +865,7 @@ func completeRegistration( // Check whether inhibit_login option is set. If so, don't create an access // token or a device for this user if inhibitLogin { + registrationOK = true return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ @@ -828,6 +899,7 @@ func completeRegistration( } } + registrationOK = true return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ @@ -976,5 +1048,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 1f615dc26..c6b7e61cf 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -17,6 +17,7 @@ package routing import ( "regexp" "testing" + "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/setup/config" @@ -140,7 +141,7 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) { func TestEmptyCompletedFlows(t *testing.T) { fakeEmptySessions := newSessionsDict() fakeSessionID := "aRandomSessionIDWhichDoesNotExist" - ret := fakeEmptySessions.GetCompletedStages(fakeSessionID) + ret := fakeEmptySessions.getCompletedStages(fakeSessionID) // check for [] if ret == nil || len(ret) != 0 { @@ -208,3 +209,45 @@ func TestValidationOfApplicationServices(t *testing.T) { t.Errorf("user_id should not have been valid: @_something_else:localhost") } } + +func TestSessionCleanUp(t *testing.T) { + s := newSessionsDict() + + t.Run("session is cleaned up after a while", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld" + // manually added, as s.addParams() would start the timer with the default timeout + s.params[dummySession] = registerRequest{Username: "Testing"} + s.startTimer(time.Millisecond, dummySession) + time.Sleep(time.Millisecond * 2) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) + + t.Run("session is deleted, once the registration completed", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld2" + s.startTimer(time.Minute, dummySession) + s.deleteSession(dummySession) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) + + t.Run("session timer is restarted after second call", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld3" + // the following will start a timer with the default timeout of 5min + s.addParams(dummySession, registerRequest{Username: "Testing"}) + s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha) + s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy) + s.getCompletedStages(dummySession) + // reset the timer with a lower timeout + s.startTimer(time.Millisecond, dummySession) + time.Sleep(time.Millisecond * 2) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) +} \ No newline at end of file diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 3003896c8..307fa17e6 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -48,24 +48,27 @@ Example: # read password from stdin %s --config dendrite.yaml -username alice -passwordstdin < my.pass cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin + # reset password for a user, can be used with a combination above to read the password + %s --config dendrite.yaml -reset-password -username alice -password foobarbaz Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - askPass = flag.Bool("ask-pass", false, "Ask for the password to use") - isAdmin = flag.Bool("admin", false, "Create an admin account") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + askPass = flag.Bool("ask-pass", false, "Ask for the password to use") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username") ) func main() { name := os.Args[0] flag.Usage = func() { - _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name) + _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name) flag.PrintDefaults() } cfg := setup.ParseFlags(true) @@ -93,6 +96,19 @@ func main() { if *isAdmin { accType = api.AccountTypeAdmin } + + if *resetPassword { + err = accountDB.SetPassword(context.Background(), *username, pass) + if err != nil { + logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error()) + } + if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil { + logrus.Fatalf("Failed to remove all devices: %s", err.Error()) + } + logrus.Infof("Updated password for user %s and invalidated all logins\n", *username) + return + } + _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType) if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) diff --git a/docs/systemd/monolith-example.service b/docs/systemd/monolith-example.service index 731c6159b..237120ffb 100644 --- a/docs/systemd/monolith-example.service +++ b/docs/systemd/monolith-example.service @@ -13,6 +13,7 @@ Group=dendrite WorkingDirectory=/opt/dendrite/ ExecStart=/opt/dendrite/bin/dendrite-monolith-server Restart=always +LimitNOFILE=65535 [Install] WantedBy=multi-user.target diff --git a/federationapi/api/api.go b/federationapi/api/api.go index f5ee75b4b..4d6b0211c 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -21,7 +21,7 @@ type FederationClient interface { QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index b31db466c..b8bd5beda 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2946Spaces(ctx, s, roomID, r) + return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) }) if err != nil { return res, err diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index f9b2a33d2..01ca6595d 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -526,23 +526,23 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships( } type spacesReq struct { - S gomatrixserverlib.ServerName - Req gomatrixserverlib.MSC2946SpacesRequest - RoomID string - Res gomatrixserverlib.MSC2946SpacesResponse - Err *api.FederationClientError + S gomatrixserverlib.ServerName + SuggestedOnly bool + RoomID string + Res gomatrixserverlib.MSC2946SpacesResponse + Err *api.FederationClientError } func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, + ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces") defer span.Finish() request := spacesReq{ - S: dst, - Req: r, - RoomID: roomID, + S: dst, + SuggestedOnly: suggestedOnly, + RoomID: roomID, } var response spacesReq apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 8d193d9c9..ca4930f20 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req) + res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly) if err != nil { ferr, ok := err.(*api.FederationClientError) if ok { diff --git a/go.mod b/go.mod index a416ec98f..e44e7393d 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 + github.com/google/uuid v1.2.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.3 // indirect @@ -39,7 +40,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 @@ -61,11 +62,11 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.2 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/mobile v0.0.0-20220112015953-858099ff7816 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd - golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect + golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index ecb0496bc..43b363d30 100644 --- a/go.sum +++ b/go.sum @@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 h1:+4Q/JQ3fGgA7sIHaLMlqREX8yEpsI+HlVoW9WId7SNc= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 h1:CqbM5tPbF1mV2h1J6N0NSF8et6HvFlwA98TLCUcjiSI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1510,8 +1510,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5 golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= -golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1737,8 +1737,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= -golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index bfb2037f8..5124f37e6 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -166,26 +166,53 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } // We can't have a self-signing or user-signing key without a master - // key, so make sure we have one of those. - if !hasMasterKey { - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) - if err != nil { - res.Error = &api.KeyError{ - Err: "Retrieving cross-signing keys from database failed: " + err.Error(), - } - return + // key, so make sure we have one of those. We will also only actually do + // something if any of the specified keys in the request are different + // to what we've got in the database, to avoid generating key change + // notifications unnecessarily. + existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) + if err != nil { + res.Error = &api.KeyError{ + Err: "Retrieving cross-signing keys from database failed: " + err.Error(), } - - _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster] + return } // If we still can't find a master key for the user then stop the upload. // This satisfies the "Fails to upload self-signing key without master key" test. if !hasMasterKey { - res.Error = &api.KeyError{ - Err: "No master key was found", - IsMissingParam: true, + if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey { + res.Error = &api.KeyError{ + Err: "No master key was found", + IsMissingParam: true, + } + return } + } + + // Check if anything actually changed compared to what we have in the database. + changed := false + for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{ + gomatrixserverlib.CrossSigningKeyPurposeMaster, + gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, + gomatrixserverlib.CrossSigningKeyPurposeUserSigning, + } { + old, gotOld := existingKeys[purpose] + new, gotNew := toStore[purpose] + if gotOld != gotNew { + // A new key purpose has been specified that we didn't know before, + // or one has been removed. + changed = true + break + } + if !bytes.Equal(old, new) { + // One of the existing keys for a purpose we already knew about has + // changed. + changed = true + break + } + } + if !changed { return } diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 96d6711c6..66e85f2f3 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -269,6 +269,7 @@ type QueryAuthChainResponse struct { type QuerySharedUsersRequest struct { UserID string + OtherUserIDs []string ExcludeRoomIDs []string IncludeRoomIDs []string } @@ -312,7 +313,10 @@ type QueryBulkStateContentResponse struct { } type QueryCurrentStateRequest struct { - RoomID string + RoomID string + AllowWildcards bool + // State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching + // state events with that event type. StateTuples []gomatrixserverlib.StateKeyTuple } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 012094c62..5491d36b3 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -51,12 +51,8 @@ func SendEventWithState( state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { - outliers, err := state.Events(event.RoomVersion) - if err != nil { - return err - } - - var ires []InputRoomEvent + outliers := state.Events(event.RoomVersion) + ires := make([]InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if haveEventIDs[outlier.EventID()] { continue diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 4655e92a9..a7da9b06d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -23,6 +23,21 @@ type parsedRespState struct { StateEvents []*gomatrixserverlib.Event } +func (p *parsedRespState) Events() []*gomatrixserverlib.Event { + eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents)) + for i, event := range p.AuthEvents { + eventsByID[event.EventID()] = p.AuthEvents[i] + } + for i, event := range p.StateEvents { + eventsByID[event.EventID()] = p.StateEvents[i] + } + allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID)) + for _, event := range eventsByID { + allEvents = append(allEvents, event) + } + return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents) +} + type missingStateReq struct { origin gomatrixserverlib.ServerName db storage.Database @@ -124,11 +139,8 @@ func (t *missingStateReq) processEventWithMissingState( t.hadEventsMutex.Unlock() sendOutliers := func(resolvedState *parsedRespState) error { - outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) - if oerr != nil { - return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr) - } - var outlierRoomEvents []api.InputRoomEvent + outliers := resolvedState.Events() + outlierRoomEvents := make([]api.InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if hadEvents[outlier.EventID()] { continue diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index c8bbe7705..70cc5d62c 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -621,12 +621,25 @@ func (r *Queryer) QueryPublishedRooms( func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) for _, tuple := range req.StateTuples { - ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) - if err != nil { - return err - } - if ev != nil { - res.StateEvents[tuple] = ev + if tuple.StateKey == "*" && req.AllowWildcards { + events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType) + if err != nil { + return err + } + for _, e := range events { + res.StateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = e + } + } else { + ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) + if err != nil { + return err + } + if ev != nil { + res.StateEvents[tuple] = ev + } } } return nil @@ -696,7 +709,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser } roomIDs = roomIDs[:j] - users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs) + users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs) if err != nil { return err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 685505d52..a2b22b401 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -146,13 +146,14 @@ type Database interface { // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) - // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. - JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms. + JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) // GetServerInRoom returns true if we think a server is in a given room or false otherwise. diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 48c2c35cd..127178743 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" @@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectJoinedUsersSetForRooms( ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, + userNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]int, error) { - roomIDarray := make([]int64, len(roomNIDs)) - for i := range roomNIDs { - roomIDarray[i] = int64(roomNIDs[i]) - } stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) - rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 6e84b2832..f87782776 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -979,6 +979,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s return nil, nil } +// Same as GetStateEvent but returns all matching state events with this event type. Returns no error +// if there are no events with this event type. +func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + if roomInfo == nil { + return nil, fmt.Errorf("room %s doesn't exist", roomID) + } + // e.g invited rooms + if roomInfo.IsStub { + return nil, nil + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err == sql.ErrNoRows { + // No rooms have an event of this type, otherwise we'd have an event type NID + return nil, nil + } + if err != nil { + return nil, err + } + entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + if err != nil { + return nil, err + } + var eventNIDs []types.EventNID + for _, e := range entries { + if e.EventTypeNID == eventTypeNID { + eventNIDs = append(eventNIDs, e.EventNID) + } + } + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) + if err != nil { + eventIDs = map[types.EventNID]string{} + } + // return the events requested + eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) + if err != nil { + return nil, err + } + if len(eventPairs) == 0 { + return nil, nil + } + var result []*gomatrixserverlib.HeaderedEvent + for _, pair := range eventPairs { + ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion) + if err != nil { + return nil, err + } + result = append(result, ev.Headered(roomInfo.RoomVersion)) + } + + return result, nil +} + // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { var membershipState tables.MembershipState @@ -1104,13 +1160,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu return result, nil } -// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. -func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { +// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms. +func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) { roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) if err != nil { return nil, err } - userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) + userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs) + if err != nil { + return nil, err + } + userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap)) + nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap)) + for id, nid := range userNIDsMap { + userNIDs = append(userNIDs, nid) + nidToUserID[nid] = id + } + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs) if err != nil { return nil, err } @@ -1120,10 +1186,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) stateKeyNIDs[i] = nid i++ } - nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs) - if err != nil { - return nil, err - } if len(nidToUserID) != len(userNIDToCount) { logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID)) } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 181b4b4c9..43567a94c 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -42,7 +42,8 @@ const membershipSchema = ` ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid IN ($1) AND target_nid IN ($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" @@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { - iRoomNIDs := make([]interface{}, len(roomNIDs)) - for i, v := range roomNIDs { - iRoomNIDs[i] = v +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) { + params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs)) + for _, v := range roomNIDs { + params = append(params, v) } - query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) + for _, v := range userNIDs { + params = append(params, v) + } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) var rows *sql.Rows var err error if txn != nil { - rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) + rows, err = txn.QueryContext(ctx, query, params...) } else { - rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) + rows, err = s.db.QueryContext(ctx, query, params...) } if err != nil { return nil, err diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index e3fed700b..04e3c96cc 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -127,9 +127,8 @@ type Membership interface { SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) - // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the - // counts of how many rooms they are joined. - SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. + SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 0af22c19a..29c781a88 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -654,11 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo AuthEvents: res.AuthChain, StateEvents: stateEvents, } - eventsInOrder, err := respState.Events(rc.roomVersion) - if err != nil { - util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse") - return - } + eventsInOrder := respState.Events(rc.roomVersion) // everything gets sent as an outlier because auth chain events may be disjoint from the DAG // as may the threaded events. var ires []roomserver.InputRoomEvent @@ -669,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo }) } // we've got the data by this point so use a background context - err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) + err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") } diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 3824c99a2..3bb56f4b3 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -18,17 +18,18 @@ package msc2946 import ( "context" "encoding/json" - "fmt" "net/http" + "net/url" + "sort" + "strconv" "strings" "sync" "time" + "github.com/google/uuid" "github.com/gorilla/mux" - chttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/httputil" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -39,15 +40,15 @@ import ( ) const ( - ConstCreateEventContentKey = "type" - ConstSpaceChildEventType = "m.space.child" - ConstSpaceParentEventType = "m.space.parent" + ConstCreateEventContentKey = "type" + ConstCreateEventContentValueSpace = "m.space" + ConstSpaceChildEventType = "m.space.child" + ConstSpaceParentEventType = "m.space.parent" ) -// Defaults sets the request defaults -func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) { - r.Limit = 2000 - r.MaxRoomsPerSpace = -1 +type MSC2946ClientResponse struct { + Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"` + NextBatch string `json:"next_batch,omitempty"` } // Enable this MSC @@ -55,26 +56,11 @@ func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, ) error { - db, err := NewDatabase(&base.Cfg.MSCs.Database) - if err != nil { - return fmt.Errorf("cannot enable MSC2946: %w", err) - } - hooks.Enable() - hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { - he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) - hookErr := db.StoreReference(context.Background(), he) - if hookErr != nil { - util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error( - "failed to StoreReference", - ) - } - }) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, base.Cfg.Global.ServerName)) + base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) + base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) - base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces", - httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)), - ).Methods(http.MethodPost, http.MethodOptions) - - base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI( + fedAPI := httputil.MakeExternalAPI( "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( req, time.Now(), base.Cfg.Global.ServerName, keyRing, @@ -88,105 +74,99 @@ func Enable( return util.ErrorResponse(err) } roomID := params["roomID"] - return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName) + return federatedSpacesHandler(req.Context(), fedReq, roomID, rsAPI, fsAPI, base.Cfg.Global.ServerName) }, - )).Methods(http.MethodPost, http.MethodOptions) + ) + base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) + base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet) return nil } func federatedSpacesHandler( - ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database, + ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) - var r gomatrixserverlib.MSC2946SpacesRequest - Defaults(&r) - if err := json.Unmarshal(fedReq.Content(), &r); err != nil { + u, err := url.Parse(fedReq.RequestURI()) + if err != nil { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + Code: 400, + JSON: jsonerror.InvalidParam("bad request uri"), } } - w := walker{ - req: &r, - rootRoomID: roomID, - serverName: fedReq.Origin(), - thisServer: thisServer, - ctx: ctx, - db: db, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, - } - res := w.walk() - return util.JSONResponse{ - Code: 200, - JSON: res, + w := walker{ + rootRoomID: roomID, + serverName: fedReq.Origin(), + thisServer: thisServer, + ctx: ctx, + suggestedOnly: u.Query().Get("suggested_only") == "true", + limit: 1000, + // The main difference is that it does not recurse into spaces and does not support pagination. + // This is somewhat equivalent to a Client-Server request with a max_depth=1. + maxDepth: 1, + + rsAPI: rsAPI, + fsAPI: fsAPI, + // inline cache as we don't have pagination in federation mode + paginationCache: make(map[string]paginationInfo), } + return w.walk() } func spacesHandler( - db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, + rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) func(*http.Request, *userapi.Device) util.JSONResponse { + // declared outside the returned handler so it persists between calls + // TODO: clear based on... time? + paginationCache := make(map[string]paginationInfo) + return func(req *http.Request, device *userapi.Device) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) // Extract the room ID from the request. Sanity check request data. params, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } roomID := params["roomID"] - var r gomatrixserverlib.MSC2946SpacesRequest - Defaults(&r) - if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil { - return *resErr - } w := walker{ - req: &r, - rootRoomID: roomID, - caller: device, - thisServer: thisServer, - ctx: req.Context(), + suggestedOnly: req.URL.Query().Get("suggested_only") == "true", + limit: parseInt(req.URL.Query().Get("limit"), 1000), + maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1), + paginationToken: req.URL.Query().Get("from"), + rootRoomID: roomID, + caller: device, + thisServer: thisServer, + ctx: req.Context(), - db: db, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, - } - res := w.walk() - return util.JSONResponse{ - Code: 200, - JSON: res, + rsAPI: rsAPI, + fsAPI: fsAPI, + paginationCache: paginationCache, } + return w.walk() } } +type paginationInfo struct { + processed set + unvisited []roomVisit +} + type walker struct { - req *gomatrixserverlib.MSC2946SpacesRequest - rootRoomID string - caller *userapi.Device - serverName gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName - db Database - rsAPI roomserver.RoomserverInternalAPI - fsAPI fs.FederationInternalAPI - ctx context.Context + rootRoomID string + caller *userapi.Device + serverName gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName + rsAPI roomserver.RoomserverInternalAPI + fsAPI fs.FederationInternalAPI + ctx context.Context + suggestedOnly bool + limit int + maxDepth int + paginationToken string - // user ID|device ID|batch_num => event/room IDs sent to client - inMemoryBatchCache map[string]set - mu sync.Mutex -} - -func (w *walker) roomIsExcluded(roomID string) bool { - for _, exclRoom := range w.req.ExcludeRooms { - if exclRoom == roomID { - return true - } - } - return false + paginationCache map[string]paginationInfo + mu sync.Mutex } func (w *walker) callerID() string { @@ -196,144 +176,207 @@ func (w *walker) callerID() string { return string(w.serverName) } -func (w *walker) alreadySent(id string) bool { - w.mu.Lock() - defer w.mu.Unlock() - m, ok := w.inMemoryBatchCache[w.callerID()] - if !ok { - return false +func (w *walker) newPaginationCache() (string, paginationInfo) { + p := paginationInfo{ + processed: make(set), + unvisited: nil, } - return m[id] + tok := uuid.NewString() + return tok, p } -func (w *walker) markSent(id string) { +func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo { w.mu.Lock() defer w.mu.Unlock() - m := w.inMemoryBatchCache[w.callerID()] - if m == nil { - m = make(set) - } - m[id] = true - w.inMemoryBatchCache[w.callerID()] = m + p := w.paginationCache[paginationToken] + return &p } -func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse { - var res gomatrixserverlib.MSC2946SpacesResponse - // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms - unvisited := []string{w.rootRoomID} - processed := make(set) +func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) { + w.mu.Lock() + defer w.mu.Unlock() + w.paginationCache[paginationToken] = cache +} + +type roomVisit struct { + roomID string + depth int + vias []string // vias to query this room by +} + +func (w *walker) walk() util.JSONResponse { + if !w.authorised(w.rootRoomID) { + if w.caller != nil { + // CS API format + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("room is unknown/forbidden"), + } + } else { + // SS API format + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("room is unknown/forbidden"), + } + } + } + + var discoveredRooms []gomatrixserverlib.MSC2946Room + + var cache *paginationInfo + if w.paginationToken != "" { + cache = w.loadPaginationCache(w.paginationToken) + if cache == nil { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("invalid from"), + } + } + } else { + tok, c := w.newPaginationCache() + cache = &c + w.paginationToken = tok + // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms + c.unvisited = append(c.unvisited, roomVisit{ + roomID: w.rootRoomID, + depth: 0, + }) + } + + processed := cache.processed + unvisited := cache.unvisited + + // Depth first -> stack data structure for len(unvisited) > 0 { - roomID := unvisited[0] - unvisited = unvisited[1:] - // If this room has already been processed, skip. NB: do not remember this between calls - if processed[roomID] || roomID == "" { + if len(discoveredRooms) >= w.limit { + break + } + + // pop the stack + rv := unvisited[len(unvisited)-1] + unvisited = unvisited[:len(unvisited)-1] + // If this room has already been processed, skip. + // If this room exceeds the specified depth, skip. + if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) { continue } + // Mark this room as processed. - processed[roomID] = true + processed.set(rv.roomID) + + // if this room is not a space room, skip. + var roomType string + create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "") + if create != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + } // Collect rooms/events to send back (either locally or fetched via federation) - var discoveredRooms []gomatrixserverlib.MSC2946Room - var discoveredEvents []gomatrixserverlib.MSC2946StrippedEvent + var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent // If we know about this room and the caller is authorised (joined/world_readable) then pull // events locally - if w.roomExists(roomID) && w.authorised(roomID) { - // Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get - // all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`) - // this room. This requires servers to store reverse lookups. - events, err := w.references(roomID) + if w.roomExists(rv.roomID) && w.authorised(rv.roomID) { + // Get all `m.space.child` state events for this room + events, err := w.childReferences(rv.roomID) if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room") + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room") continue } - discoveredEvents = events + discoveredChildEvents = events - pubRoom := w.publicRoomsChunk(roomID) - roomType := "" - create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "") - if create != nil { - // escape the `.`s so gjson doesn't think it's nested - roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str - } + pubRoom := w.publicRoomsChunk(rv.roomID) - // Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`. discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ - PublicRoom: *pubRoom, - NumRefs: len(discoveredEvents), - RoomType: roomType, + PublicRoom: *pubRoom, + RoomType: roomType, + ChildrenState: events, }) } else { // attempt to query this room over federation, as either we've never heard of it before // or we've left it and hence are not authorised (but info may be exposed regardless) - fedRes, err := w.federatedRoomInfo(roomID) + fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias) if err != nil { - util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces") + util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Errorf("failed to query federated spaces") continue } if fedRes != nil { - discoveredRooms = fedRes.Rooms - discoveredEvents = fedRes.Events + discoveredChildEvents = fedRes.Room.ChildrenState + discoveredRooms = append(discoveredRooms, fedRes.Room) + if len(fedRes.Children) > 0 { + discoveredRooms = append(discoveredRooms, fedRes.Children...) + } + // mark this room as a space room as the federated server responded. + // we need to do this so we add the children of this room to the unvisited stack + // as these children may be rooms we do know about. + roomType = ConstCreateEventContentValueSpace } } - // If this room has not ever been in `rooms` (across multiple requests), send it now - for _, room := range discoveredRooms { - if !w.alreadySent(room.RoomID) && !w.roomIsExcluded(room.RoomID) { - res.Rooms = append(res.Rooms, room) - w.markSent(room.RoomID) - } + // don't walk the children + // if the parent is not a space room + if roomType != ConstCreateEventContentValueSpace { + continue } - uniqueRooms := make(set) - - // If this is the root room from the original request, insert all these events into `events` if - // they haven't been added before (across multiple requests). - if w.rootRoomID == roomID { - for _, ev := range discoveredEvents { - if !w.alreadySent(eventKey(&ev)) { - res.Events = append(res.Events, ev) - uniqueRooms[ev.RoomID] = true - uniqueRooms[spaceTargetStripped(&ev)] = true - w.markSent(eventKey(&ev)) - } - } - } else { - // Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either - // are exceeded, stop adding events. If the event has already been added, do not add it again. - numAdded := 0 - for _, ev := range discoveredEvents { - if w.req.Limit > 0 && len(res.Events) >= w.req.Limit { - break - } - if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace { - break - } - if w.alreadySent(eventKey(&ev)) { - continue - } - // Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still - // want to catch arrows which point to excluded rooms. - if w.roomIsExcluded(ev.RoomID) { - continue - } - res.Events = append(res.Events, ev) - uniqueRooms[ev.RoomID] = true - uniqueRooms[spaceTargetStripped(&ev)] = true - w.markSent(eventKey(&ev)) - // we don't distinguish between child state events and parent state events for the purposes of - // max_rooms_per_space, maybe we should? - numAdded++ - } - } - - // For each referenced room ID in the events being returned to the caller (both parent and child) + // For each referenced room ID in the child events being returned to the caller // add the room ID to the queue of unvisited rooms. Loop from the beginning. - for roomID := range uniqueRooms { - unvisited = append(unvisited, roomID) + // We need to invert the order here because the child events are lo->hi on the timestamp, + // so we need to ensure we pop in the same lo->hi order, which won't be the case if we + // insert the highest timestamp last in a stack. + for i := len(discoveredChildEvents) - 1; i >= 0; i-- { + spaceContent := struct { + Via []string `json:"via"` + }{} + ev := discoveredChildEvents[i] + _ = json.Unmarshal(ev.Content, &spaceContent) + unvisited = append(unvisited, roomVisit{ + roomID: ev.StateKey, + depth: rv.depth + 1, + vias: spaceContent.Via, + }) } } - return &res + + if len(unvisited) > 0 { + // we still have more rooms so we need to send back a pagination token, + // we probably hit a room limit + cache.processed = processed + cache.unvisited = unvisited + w.storePaginationCache(w.paginationToken, *cache) + } else { + // clear the pagination token so we don't send it back to the client + // Note we do NOT nuke the cache just in case this response is lost + // and the client retries it. + w.paginationToken = "" + } + + if w.caller != nil { + // return CS API format + return util.JSONResponse{ + Code: 200, + JSON: MSC2946ClientResponse{ + Rooms: discoveredRooms, + NextBatch: w.paginationToken, + }, + } + } + // return SS API format + // the first discovered room will be the room asked for, and subsequent ones the depth=1 children + if len(discoveredRooms) == 0 { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("room is unknown/forbidden"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: gomatrixserverlib.MSC2946SpacesResponse{ + Room: discoveredRooms[0], + Children: discoveredRooms[1:], + }, + } } func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent { @@ -366,42 +409,19 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom { // federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was // unsuccessful. -func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { +func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserverlib.MSC2946SpacesResponse, error) { // only do federated requests for client requests if w.caller == nil { return nil, nil } - // extract events which point to this room ID and extract their vias - events, err := w.db.References(w.ctx, roomID) - if err != nil { - return nil, fmt.Errorf("failed to get References events: %w", err) - } - vias := make(set) - for _, ev := range events { - if ev.StateKeyEquals(roomID) { - // event points at this room, extract vias - content := struct { - Vias []string `json:"via"` - }{} - if err = json.Unmarshal(ev.Content(), &content); err != nil { - continue // silently ignore corrupted state events - } - for _, v := range content.Vias { - vias[v] = true - } - } - } - util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias) + util.GetLogger(w.ctx).Infof("Querying %s via %+v", roomID, vias) ctx := context.Background() // query more of the spaces graph using these servers - for serverName := range vias { + for _, serverName := range vias { if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{ - Limit: w.req.Limit, - MaxRoomsPerSpace: w.req.MaxRoomsPerSpace, - }) + res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue @@ -501,7 +521,7 @@ func (w *walker) authorisedUser(roomID string) bool { hisVisEv := queryRes.StateEvents[hisVisTuple] if memberEv != nil { membership, _ := memberEv.Membership() - if membership == gomatrixserverlib.Join { + if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { return true } } @@ -514,40 +534,85 @@ func (w *walker) authorisedUser(roomID string) bool { return false } -// references returns all references pointing to or from this room. -func (w *walker) references(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { - events, err := w.db.References(w.ctx, roomID) +// references returns all child references pointing to or from this room. +func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { + createTuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomCreate, + StateKey: "", + } + var res roomserver.QueryCurrentStateResponse + err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{ + RoomID: roomID, + AllowWildcards: true, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + createTuple, { + EventType: ConstSpaceChildEventType, + StateKey: "*", + }, + }, + }, &res) if err != nil { return nil, err } - el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events)) - for _, ev := range events { + + // don't return any child refs if the room is not a space room + if res.StateEvents[createTuple] != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + if roomType != ConstCreateEventContentValueSpace { + return nil, nil + } + } + delete(res.StateEvents, createTuple) + + el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents)) + for _, ev := range res.StateEvents { + content := gjson.ParseBytes(ev.Content()) // only return events that have a `via` key as per MSC1772 // else we'll incorrectly walk redacted events (as the link // is in the state_key) - if gjson.GetBytes(ev.Content(), "via").Exists() { + if content.Get("via").Exists() { strip := stripped(ev.Event) if strip == nil { continue } + // if suggested only and this child isn't suggested, skip it. + // if suggested only = false we include everything so don't need to check the content. + if w.suggestedOnly && !content.Get("suggested").Bool() { + continue + } el = append(el, *strip) } } + // sort by origin_server_ts as per MSC2946 + sort.Slice(el, func(i, j int) bool { + return el[i].OriginServerTS < el[j].OriginServerTS + }) + return el, nil } -type set map[string]bool +type set map[string]struct{} + +func (s set) set(val string) { + s[val] = struct{}{} +} +func (s set) isSet(val string) bool { + _, ok := s[val] + return ok +} func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent { if ev.StateKey() == nil { return nil } return &gomatrixserverlib.MSC2946StrippedEvent{ - Type: ev.Type(), - StateKey: *ev.StateKey(), - Content: ev.Content(), - Sender: ev.Sender(), - RoomID: ev.RoomID(), + Type: ev.Type(), + StateKey: *ev.StateKey(), + Content: ev.Content(), + Sender: ev.Sender(), + RoomID: ev.RoomID(), + OriginServerTS: ev.OriginServerTS(), } } @@ -567,3 +632,11 @@ func spaceTargetStripped(event *gomatrixserverlib.MSC2946StrippedEvent) string { } return "" } + +func parseInt(intstr string, defaultVal int) int { + i, err := strconv.ParseInt(intstr, 10, 32) + if err != nil { + return defaultVal + } + return int(i) +} diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go deleted file mode 100644 index e8066c34d..000000000 --- a/setup/mscs/msc2946/msc2946_test.go +++ /dev/null @@ -1,464 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msc2946_test - -import ( - "bytes" - "context" - "crypto/ed25519" - "encoding/json" - "io/ioutil" - "net/http" - "net/url" - "testing" - "time" - - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal/hooks" - "github.com/matrix-org/dendrite/internal/httputil" - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/mscs/msc2946" - userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" -) - -var ( - client = &http.Client{ - Timeout: 10 * time.Second, - } - roomVer = gomatrixserverlib.RoomVersionV6 -) - -// Basic sanity check of MSC2946 logic. Tests a single room with a few state events -// and a bit of recursion to subspaces. Makes a graph like: -// Root -// ____|_____ -// | | | -// R1 R2 S1 -// |_________ -// | | | -// R3 R4 S2 -// | <-- this link is just a parent, not a child -// R5 -// -// Alice is not joined to R4, but R4 is "world_readable". -func TestMSC2946(t *testing.T) { - alice := "@alice:localhost" - // give access token to alice - nopUserAPI := &testUserAPI{ - accessTokens: make(map[string]userapi.Device), - } - nopUserAPI.accessTokens["alice"] = userapi.Device{ - AccessToken: "alice", - DisplayName: "Alice", - UserID: alice, - } - rootSpace := "!rootspace:localhost" - subSpaceS1 := "!subspaceS1:localhost" - subSpaceS2 := "!subspaceS2:localhost" - room1 := "!room1:localhost" - room2 := "!room2:localhost" - room3 := "!room3:localhost" - room4 := "!room4:localhost" - empty := "" - room5 := "!room5:localhost" - allRooms := []string{ - rootSpace, subSpaceS1, subSpaceS2, - room1, room2, room3, room4, room5, - } - rootToR1 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room1, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - rootToR2 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - rootToS1 := mustCreateEvent(t, fledglingEvent{ - RoomID: rootSpace, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &subSpaceS1, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToR3 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room3, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToR4 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room4, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - s1ToS2 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &subSpaceS2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - // This is a parent link only - s2ToR5 := mustCreateEvent(t, fledglingEvent{ - RoomID: room5, - Sender: alice, - Type: msc2946.ConstSpaceParentEventType, - StateKey: &subSpaceS2, - Content: map[string]interface{}{ - "via": []string{"localhost"}, - }, - }) - // history visibility for R4 - r4HisVis := mustCreateEvent(t, fledglingEvent{ - RoomID: room4, - Sender: "@someone:localhost", - Type: gomatrixserverlib.MRoomHistoryVisibility, - StateKey: &empty, - Content: map[string]interface{}{ - "history_visibility": "world_readable", - }, - }) - var joinEvents []*gomatrixserverlib.HeaderedEvent - for _, roomID := range allRooms { - if roomID == room4 { - continue // not joined to that room - } - joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{ - RoomID: roomID, - Sender: alice, - StateKey: &alice, - Type: gomatrixserverlib.MRoomMember, - Content: map[string]interface{}{ - "membership": "join", - }, - })) - } - roomNameTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.name", - StateKey: "", - } - hisVisTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.history_visibility", - StateKey: "", - } - nopRsAPI := &testRoomserverAPI{ - joinEvents: joinEvents, - events: map[string]*gomatrixserverlib.HeaderedEvent{ - rootToR1.EventID(): rootToR1, - rootToR2.EventID(): rootToR2, - rootToS1.EventID(): rootToS1, - s1ToR3.EventID(): s1ToR3, - s1ToR4.EventID(): s1ToR4, - s1ToS2.EventID(): s1ToS2, - s2ToR5.EventID(): s2ToR5, - r4HisVis.EventID(): r4HisVis, - }, - pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{ - rootSpace: { - roomNameTuple: "Root", - hisVisTuple: "shared", - }, - subSpaceS1: { - roomNameTuple: "Sub-Space 1", - hisVisTuple: "joined", - }, - subSpaceS2: { - roomNameTuple: "Sub-Space 2", - hisVisTuple: "shared", - }, - room1: { - hisVisTuple: "joined", - }, - room2: { - hisVisTuple: "joined", - }, - room3: { - hisVisTuple: "joined", - }, - room4: { - hisVisTuple: "world_readable", - }, - room5: { - hisVisTuple: "joined", - }, - }, - } - allEvents := []*gomatrixserverlib.HeaderedEvent{ - rootToR1, rootToR2, rootToS1, - s1ToR3, s1ToR4, s1ToS2, - s2ToR5, r4HisVis, - } - allEvents = append(allEvents, joinEvents...) - router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents) - cancel := runServer(t, router) - defer cancel() - - t.Run("returns no events for unknown rooms", func(t *testing.T) { - res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{})) - if len(res.Events) > 0 { - t.Errorf("got %d events, want 0", len(res.Events)) - } - if len(res.Rooms) > 0 { - t.Errorf("got %d rooms, want 0", len(res.Rooms)) - } - }) - t.Run("returns the entire graph", func(t *testing.T) { - res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) - if len(res.Events) != 7 { - t.Errorf("got %d events, want 7", len(res.Events)) - } - if len(res.Rooms) != len(allRooms) { - t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)) - } - }) - t.Run("can update the graph", func(t *testing.T) { - // remove R3 from the graph - rmS1ToR3 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS1, - Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room3, - Content: map[string]interface{}{}, // redacted - }) - nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3 - hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3) - - res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) - if len(res.Events) != 6 { // one less since we don't return redacted events - t.Errorf("got %d events, want 6", len(res.Events)) - } - if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3 - t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1) - } - }) -} - -func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest { - t.Helper() - b, err := json.Marshal(jsonBody) - if err != nil { - t.Fatalf("Failed to marshal request: %s", err) - } - var r gomatrixserverlib.MSC2946SpacesRequest - if err := json.Unmarshal(b, &r); err != nil { - t.Fatalf("Failed to unmarshal request: %s", err) - } - return &r -} - -func runServer(t *testing.T, router *mux.Router) func() { - t.Helper() - externalServ := &http.Server{ - Addr: string(":8010"), - WriteTimeout: 60 * time.Second, - Handler: router, - } - go func() { - externalServ.ListenAndServe() - }() - // wait to listen on the port - time.Sleep(500 * time.Millisecond) - return func() { - externalServ.Shutdown(context.TODO()) - } -} - -func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse { - t.Helper() - var r gomatrixserverlib.MSC2946SpacesRequest - msc2946.Defaults(&r) - data, err := json.Marshal(req) - if err != nil { - t.Fatalf("failed to marshal request: %s", err) - } - httpReq, err := http.NewRequest( - "POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces", - bytes.NewBuffer(data), - ) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - if err != nil { - t.Fatalf("failed to prepare request: %s", err) - } - res, err := client.Do(httpReq) - if err != nil { - t.Fatalf("failed to do request: %s", err) - } - if res.StatusCode != expectCode { - body, _ := ioutil.ReadAll(res.Body) - t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body)) - } - if res.StatusCode == 200 { - var result gomatrixserverlib.MSC2946SpacesResponse - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("response 200 OK but failed to read response body: %s", err) - } - t.Logf("Body: %s", string(body)) - if err := json.Unmarshal(body, &result); err != nil { - t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body)) - } - return &result - } - return nil -} - -type testUserAPI struct { - userapi.UserInternalAPITrace - accessTokens map[string]userapi.Device -} - -func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { - dev, ok := u.accessTokens[req.AccessToken] - if !ok { - res.Err = "unknown token" - return nil - } - res.Device = &dev - return nil -} - -type testRoomserverAPI struct { - // use a trace API as it implements method stubs so we don't need to have them here. - // We'll override the functions we care about. - roomserver.RoomserverInternalAPITrace - joinEvents []*gomatrixserverlib.HeaderedEvent - events map[string]*gomatrixserverlib.HeaderedEvent - pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string -} - -func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error { - res.IsInRoom = true - res.RoomExists = true - return nil -} - -func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error { - res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) - for _, roomID := range req.RoomIDs { - pubRoomData, ok := r.pubRoomState[roomID] - if ok { - res.Rooms[roomID] = pubRoomData - } - } - return nil -} - -func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error { - res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) - checkEvent := func(he *gomatrixserverlib.HeaderedEvent) { - if he.RoomID() != req.RoomID { - return - } - if he.StateKey() == nil { - return - } - tuple := gomatrixserverlib.StateKeyTuple{ - EventType: he.Type(), - StateKey: *he.StateKey(), - } - for _, t := range req.StateTuples { - if t == tuple { - res.StateEvents[t] = he - } - } - } - for _, he := range r.joinEvents { - checkEvent(he) - } - for _, he := range r.events { - checkEvent(he) - } - return nil -} - -func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router { - t.Helper() - cfg := &config.Dendrite{} - cfg.Defaults(true) - cfg.Global.ServerName = "localhost" - cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db" - cfg.MSCs.MSCs = []string{"msc2946"} - base := &base.BaseDendrite{ - Cfg: cfg, - PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(), - PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(), - } - - err := msc2946.Enable(base, rsAPI, userAPI, nil, nil) - if err != nil { - t.Fatalf("failed to enable MSC2946: %s", err) - } - for _, ev := range events { - hooks.Run(hooks.KindNewEventPersisted, ev) - } - return base.PublicClientAPIMux -} - -type fledglingEvent struct { - Type string - StateKey *string - Content interface{} - Sender string - RoomID string -} - -func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { - t.Helper() - seed := make([]byte, ed25519.SeedSize) // zero seed - key := ed25519.NewKeyFromSeed(seed) - eb := gomatrixserverlib.EventBuilder{ - Sender: ev.Sender, - Depth: 999, - Type: ev.Type, - StateKey: ev.StateKey, - RoomID: ev.RoomID, - } - err := eb.SetContent(ev.Content) - if err != nil { - t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) - } - // make sure the origin_server_ts changes so we can test recency - time.Sleep(1 * time.Millisecond) - signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer) - if err != nil { - t.Fatalf("mustCreateEvent: failed to sign event: %s", err) - } - h := signedEvent.Headered(roomVer) - return h -} diff --git a/setup/mscs/msc2946/storage.go b/setup/mscs/msc2946/storage.go deleted file mode 100644 index 20db18594..000000000 --- a/setup/mscs/msc2946/storage.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msc2946 - -import ( - "context" - "database/sql" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" -) - -var ( - relTypes = map[string]int{ - ConstSpaceChildEventType: 1, - ConstSpaceParentEventType: 2, - } -) - -type Database interface { - // StoreReference persists a child or parent space mapping. - StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error - // References returns all events which have the given roomID as a parent or child space. - References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) -} - -type DB struct { - db *sql.DB - writer sqlutil.Writer - insertEdgeStmt *sql.Stmt - selectEdgesStmt *sql.Stmt -} - -// NewDatabase loads the database for msc2836 -func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - if dbOpts.ConnectionString.IsPostgres() { - return newPostgresDatabase(dbOpts) - } - return newSQLiteDatabase(dbOpts) -} - -func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewDummyWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { - return nil, err - } - _, err = d.db.Exec(` - CREATE TABLE IF NOT EXISTS msc2946_edges ( - room_version TEXT NOT NULL, - -- the room ID of the event, the source of the arrow - source_room_id TEXT NOT NULL, - -- the target room ID, the arrow destination - dest_room_id TEXT NOT NULL, - -- the kind of relation, either child or parent (1,2) - rel_type SMALLINT NOT NULL, - event_json TEXT NOT NULL, - CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type) - ); - `) - if err != nil { - return nil, err - } - if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5 - `); err != nil { - return nil, err - } - if d.selectEdgesStmt, err = d.db.Prepare(` - SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 - `); err != nil { - return nil, err - } - return &d, err -} - -func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - d := DB{ - writer: sqlutil.NewExclusiveWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbOpts); err != nil { - return nil, err - } - _, err = d.db.Exec(` - CREATE TABLE IF NOT EXISTS msc2946_edges ( - room_version TEXT NOT NULL, - -- the room ID of the event, the source of the arrow - source_room_id TEXT NOT NULL, - -- the target room ID, the arrow destination - dest_room_id TEXT NOT NULL, - -- the kind of relation, either child or parent (1,2) - rel_type SMALLINT NOT NULL, - event_json TEXT NOT NULL, - UNIQUE (source_room_id, dest_room_id, rel_type) - ); - `) - if err != nil { - return nil, err - } - if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5 - `); err != nil { - return nil, err - } - if d.selectEdgesStmt, err = d.db.Prepare(` - SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 - `); err != nil { - return nil, err - } - return &d, err -} - -func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error { - target := SpaceTarget(he) - if target == "" { - return nil // malformed event - } - relType := relTypes[he.Type()] - _, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON()) - return err -} - -func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) { - rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "failed to close References") - refs := make([]*gomatrixserverlib.HeaderedEvent, 0) - for rows.Next() { - var roomVer string - var jsonBytes []byte - if err := rows.Scan(&roomVer, &jsonBytes); err != nil { - return nil, err - } - ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer)) - if err != nil { - return nil, err - } - he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer)) - refs = append(refs, he) - } - return refs, nil -} - -// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent -// depending on the event type. -func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string { - if he.StateKey() == nil { - return "" // no-op - } - switch he.Type() { - case ConstSpaceParentEventType: - return *he.StateKey() - case ConstSpaceChildEventType: - return *he.StateKey() - } - return "" -} diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 37a9e2d39..dc4acd8da 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -82,7 +82,16 @@ func DeviceListCatchup( util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") return to, hasNew, nil } - // QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user. + + // Work out which user IDs we care about — that includes those in the original request, + // the response from QueryKeyChanges (which includes ALL users who have changed keys) + // as well as every user who has a join or leave event in the current sync response. We + // will request information about which rooms these users are joined to, so that we can + // see if we still share any rooms with them. + joinUserIDs, leaveUserIDs := membershipEvents(res) + queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...) + queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...) + queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs) var sharedUsersMap map[string]int sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) util.GetLogger(ctx).Debugf( @@ -100,9 +109,8 @@ func DeviceListCatchup( userSet[userID] = true } } - // if the response has any join/leave events, add them now. + // Finally, add in users who have joined or left. // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. - joinUserIDs, leaveUserIDs := membershipEvents(res) for _, userID := range joinUserIDs { if !userSet[userID] { res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) @@ -213,7 +221,8 @@ func filterSharedUsers( var result []string var sharedUsersRes roomserverAPI.QuerySharedUsersResponse err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ - UserID: userID, + UserID: userID, + OtherUserIDs: usersWithChangedKeys, }, &sharedUsersRes) if err != nil { // default to all users so we do needless queries rather than miss some important device update diff --git a/sytest-blacklist b/sytest-blacklist index 16abce8da..e8617dcdf 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -24,6 +24,7 @@ Local device key changes get to remote servers with correct prev_id # Flakey Local device key changes appear in /keys/changes +/context/ with lazy_load_members filter works # we don't support groups Remove group category diff --git a/sytest-whitelist b/sytest-whitelist index 6e76edd69..2dd56d260 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -602,3 +602,7 @@ Can query remote device keys using POST after notification Device deletion propagates over federation Get left notifs in sync and /keys/changes when other user leaves Remote banned user is kicked and may not rejoin until unbanned +registration remembers parameters +registration accepts non-ascii passwords +registration with inhibit_login inhibits login +