Merge branch 'main' into kegan/dlu-debounce2

This commit is contained in:
Neil Alexander 2022-03-01 13:41:32 +00:00 committed by GitHub
commit 129052c485
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
42 changed files with 893 additions and 1188 deletions

View file

@ -1,5 +1,39 @@
# Changelog
## Dendrite 0.6.4 (2022-02-21)
### Features
* All Client-Server API endpoints are now available under the `/v3` namespace
* The `/whoami` response format now matches the latest Matrix spec version
* Support added for the `/context` endpoint, which should help clients to render quote-replies correctly
* Accounts now have an optional account type field, allowing admin accounts to be created
* Server notices are now supported
* Refactored the user API storage to deduplicate a significant amount of code, as well as merging both user API databases into a single database
* The account database is now used for all user API storage and the device database is now obsolete
* For some installations that have separate account and device databases, this may result in access tokens being revoked and client sessions being logged out — users may need to log in again
* The above can be avoided by moving the `device_devices` table into the account database manually
* Guest registration can now be separately disabled with the new `client_api.guests_disabled` configuration option
* Outbound connections now obey proxy settings from the environment, deprecating the `federation_api.proxy_outbound` configuration options
### Fixes
* The roomserver input API will now strictly consume only one database transaction per room, which should prevent situations where the roomserver can deadlock waiting for database connections to become available
* Room joins will now fall back to federation if the local room state is insufficient to create a membership event
* Create events are now correctly filtered from federation `/send` transactions
* Excessive logging when federation is disabled should now be fixed
* Dendrite will no longer panic if trying to retire an invite event that has not been seen yet
* The device list updater will now wait for longer after a connection issue, rather than flooding the logs with errors
* The device list updater will no longer produce unnecessary output events for federated key updates with no changes, which should help to reduce CPU usage
* Local device name changes will now generate key change events correctly
* The sync API will now try to share device list update notifications even if all state key NIDs cannot be fetched
* An off-by-one error in the sync stream token handling which could result in a crash has been fixed
* State events will no longer be re-sent unnecessary by the roomserver to other components if they have already been sent, which should help to reduce the NATS message sizes on the roomserver output topic in some cases
* The roomserver input API now uses the process context and should handle graceful shutdowns better
* Guest registration is now correctly disabled when the `client_api.registration_disabled` configuration option is set
* One-time encryption keys are now cleaned up correctly when a device is logged out or removed
* Invalid state snapshots in the state storage refactoring migration are now reset rather than causing a panic at startup
## Dendrite 0.6.3 (2022-02-10)
### Features

View file

@ -162,7 +162,7 @@ func AuthFallback(
}
// Success. Add recaptcha as a completed login flow
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
serveSuccess()
return nil

View file

@ -70,7 +70,7 @@ func UploadCrossSigningDeviceKeys(
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
return *authErr
}
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)

View file

@ -74,7 +74,7 @@ func Password(
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
return *authErr
}
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil {

View file

@ -72,14 +72,19 @@ func init() {
// sessionsDict keeps track of completed auth stages for each session.
// It shouldn't be passed by value because it contains a mutex.
type sessionsDict struct {
sync.Mutex
sync.RWMutex
sessions map[string][]authtypes.LoginType
params map[string]registerRequest
timer map[string]*time.Timer
}
// GetCompletedStages returns the completed stages for a session.
func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType {
d.Lock()
defer d.Unlock()
// defaultTimeout is the timeout used to clean up sessions
const defaultTimeOut = time.Minute * 5
// getCompletedStages returns the completed stages for a session.
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
d.RLock()
defer d.RUnlock()
if completedStages, ok := d.sessions[sessionID]; ok {
return completedStages
@ -88,28 +93,79 @@ func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginTyp
return make([]authtypes.LoginType, 0)
}
func newSessionsDict() *sessionsDict {
return &sessionsDict{
sessions: make(map[string][]authtypes.LoginType),
// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest
func (d *sessionsDict) addParams(sessionID string, r registerRequest) {
d.startTimer(defaultTimeOut, sessionID)
d.Lock()
defer d.Unlock()
d.params[sessionID] = r
}
func (d *sessionsDict) getParams(sessionID string) (registerRequest, bool) {
d.RLock()
defer d.RUnlock()
r, ok := d.params[sessionID]
return r, ok
}
// deleteSession cleans up a given session, either because the registration completed
// successfully, or because a given timeout (default: 5min) was reached.
func (d *sessionsDict) deleteSession(sessionID string) {
d.Lock()
defer d.Unlock()
delete(d.params, sessionID)
delete(d.sessions, sessionID)
// stop the timer, e.g. because the registration was completed
if t, ok := d.timer[sessionID]; ok {
if !t.Stop() {
select {
case <-t.C:
default:
}
}
delete(d.timer, sessionID)
}
}
// AddCompletedSessionStage records that a session has completed an auth stage.
func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
sessions.Lock()
defer sessions.Unlock()
func newSessionsDict() *sessionsDict {
return &sessionsDict{
sessions: make(map[string][]authtypes.LoginType),
params: make(map[string]registerRequest),
timer: make(map[string]*time.Timer),
}
}
for _, completedStage := range sessions.sessions[sessionID] {
func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) {
d.Lock()
defer d.Unlock()
t, ok := d.timer[sessionID]
if ok {
if !t.Stop() {
<-t.C
}
t.Reset(duration)
return
}
d.timer[sessionID] = time.AfterFunc(duration, func() {
d.deleteSession(sessionID)
})
}
// addCompletedSessionStage records that a session has completed an auth stage
// also starts a timer to delete the session once done.
func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
d.startTimer(defaultTimeOut, sessionID)
d.Lock()
defer d.Unlock()
for _, completedStage := range d.sessions[sessionID] {
if completedStage == stage {
return
}
}
sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
}
var (
// TODO: Remove old sessions. Need to do so on a session-specific timeout.
// sessions stores the completed flow stages for all sessions. Referenced using their sessionID.
sessions = newSessionsDict()
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
)
@ -167,7 +223,7 @@ func newUserInteractiveResponse(
params map[string]interface{},
) userInteractiveResponse {
return userInteractiveResponse{
fs, sessions.GetCompletedStages(sessionID), params, sessionID,
fs, sessions.getCompletedStages(sessionID), params, sessionID,
}
}
@ -645,12 +701,12 @@ func handleRegistrationFlow(
}
// Add Recaptcha to the list of completed registration stages
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
case authtypes.LoginTypeDummy:
// there is nothing to do
// Add Dummy to the list of completed registration stages
AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
case "":
// An empty auth type means that we want to fetch the available
@ -666,7 +722,7 @@ func handleRegistrationFlow(
// Check if the user's registration flow has been completed successfully
// A response with current registration flow and remaining available methods
// will be returned if a flow has not been successfully completed yet
return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID),
return checkAndCompleteFlow(sessions.getCompletedStages(sessionID),
req, r, sessionID, cfg, userAPI)
}
@ -708,7 +764,7 @@ func handleApplicationServiceRegistration(
// Don't need to worry about appending to registration stages as
// application service registration is entirely separate.
return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(),
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
)
}
@ -727,11 +783,11 @@ func checkAndCompleteFlow(
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue
return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(),
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
)
}
sessions.addParams(sessionID, r)
// There are still more stages to complete.
// Return the flows and those that have been completed.
return util.JSONResponse{
@ -750,11 +806,25 @@ func checkAndCompleteFlow(
func completeRegistration(
ctx context.Context,
userAPI userapi.UserInternalAPI,
username, password, appserviceID, ipAddr, userAgent string,
username, password, appserviceID, ipAddr, userAgent, sessionID string,
inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string,
accType userapi.AccountType,
) util.JSONResponse {
var registrationOK bool
defer func() {
if registrationOK {
sessions.deleteSession(sessionID)
}
}()
if data, ok := sessions.getParams(sessionID); ok {
username = data.Username
password = data.Password
deviceID = data.DeviceID
displayName = data.InitialDisplayName
inhibitLogin = data.InhibitLogin
}
if username == "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
@ -795,6 +865,7 @@ func completeRegistration(
// Check whether inhibit_login option is set. If so, don't create an access
// token or a device for this user
if inhibitLogin {
registrationOK = true
return util.JSONResponse{
Code: http.StatusOK,
JSON: registerResponse{
@ -828,6 +899,7 @@ func completeRegistration(
}
}
registrationOK = true
return util.JSONResponse{
Code: http.StatusOK,
JSON: registerResponse{
@ -976,5 +1048,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
if ssrr.Admin {
accType = userapi.AccountTypeAdmin
}
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType)
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
}

View file

@ -17,6 +17,7 @@ package routing
import (
"regexp"
"testing"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/setup/config"
@ -140,7 +141,7 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) {
func TestEmptyCompletedFlows(t *testing.T) {
fakeEmptySessions := newSessionsDict()
fakeSessionID := "aRandomSessionIDWhichDoesNotExist"
ret := fakeEmptySessions.GetCompletedStages(fakeSessionID)
ret := fakeEmptySessions.getCompletedStages(fakeSessionID)
// check for []
if ret == nil || len(ret) != 0 {
@ -208,3 +209,45 @@ func TestValidationOfApplicationServices(t *testing.T) {
t.Errorf("user_id should not have been valid: @_something_else:localhost")
}
}
func TestSessionCleanUp(t *testing.T) {
s := newSessionsDict()
t.Run("session is cleaned up after a while", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld"
// manually added, as s.addParams() would start the timer with the default timeout
s.params[dummySession] = registerRequest{Username: "Testing"}
s.startTimer(time.Millisecond, dummySession)
time.Sleep(time.Millisecond * 2)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld2"
s.startTimer(time.Minute, dummySession)
s.deleteSession(dummySession)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
t.Run("session timer is restarted after second call", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld3"
// the following will start a timer with the default timeout of 5min
s.addParams(dummySession, registerRequest{Username: "Testing"})
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
s.getCompletedStages(dummySession)
// reset the timer with a lower timeout
s.startTimer(time.Millisecond, dummySession)
time.Sleep(time.Millisecond * 2)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
}

View file

@ -235,7 +235,7 @@ func OnIncomingStateTypeRequest(
}
// If the user has never been in the room then stop at this point.
// We won't tell the user about a room they have never joined.
if !membershipRes.HasBeenInRoom {
if !membershipRes.HasBeenInRoom || membershipRes.Membership == gomatrixserverlib.Ban {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)),

View file

@ -48,24 +48,27 @@ Example:
# read password from stdin
%s --config dendrite.yaml -username alice -passwordstdin < my.pass
cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin
# reset password for a user, can be used with a combination above to read the password
%s --config dendrite.yaml -reset-password -username alice -password foobarbaz
Arguments:
`
var (
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
isAdmin = flag.Bool("admin", false, "Create an admin account")
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
isAdmin = flag.Bool("admin", false, "Create an admin account")
resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username")
)
func main() {
name := os.Args[0]
flag.Usage = func() {
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name)
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name, name)
flag.PrintDefaults()
}
cfg := setup.ParseFlags(true)
@ -93,6 +96,19 @@ func main() {
if *isAdmin {
accType = api.AccountTypeAdmin
}
if *resetPassword {
err = accountDB.SetPassword(context.Background(), *username, pass)
if err != nil {
logrus.Fatalf("Failed to update password for user %s: %s", *username, err.Error())
}
if _, err = accountDB.RemoveAllDevices(context.Background(), *username, ""); err != nil {
logrus.Fatalf("Failed to remove all devices: %s", err.Error())
}
logrus.Infof("Updated password for user %s and invalidated all logins\n", *username)
return
}
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType)
if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error())

View file

@ -13,6 +13,7 @@ Group=dendrite
WorkingDirectory=/opt/dendrite/
ExecStart=/opt/dendrite/bin/dendrite-monolith-server
Restart=always
LimitNOFILE=65535
[Install]
WantedBy=multi-user.target

View file

@ -21,7 +21,7 @@ type FederationClient interface {
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)

View file

@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
}
func (a *FederationInternalAPI) MSC2946Spaces(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2946Spaces(ctx, s, roomID, r)
return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly)
})
if err != nil {
return res, err

View file

@ -526,23 +526,23 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships(
}
type spacesReq struct {
S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2946SpacesRequest
RoomID string
Res gomatrixserverlib.MSC2946SpacesResponse
Err *api.FederationClientError
S gomatrixserverlib.ServerName
SuggestedOnly bool
RoomID string
Res gomatrixserverlib.MSC2946SpacesResponse
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
defer span.Finish()
request := spacesReq{
S: dst,
Req: r,
RoomID: roomID,
S: dst,
SuggestedOnly: suggestedOnly,
RoomID: roomID,
}
var response spacesReq
apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath

View file

@ -378,7 +378,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req)
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {

9
go.mod
View file

@ -18,6 +18,7 @@ require (
github.com/frankban/quicktest v1.14.0 // indirect
github.com/getsentry/sentry-go v0.12.0
github.com/gologme/log v1.3.0
github.com/google/uuid v1.2.0
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.4.2
github.com/h2non/filetype v1.1.3 // indirect
@ -39,8 +40,8 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed
github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.10
github.com/morikuni/aec v1.0.0 // indirect
@ -61,11 +62,11 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.2
go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a
golang.org/x/crypto v0.0.0-20220214200702-86341886e292
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
golang.org/x/mobile v0.0.0-20220112015953-858099ff7816
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
gopkg.in/h2non/bimg.v1 v1.1.5
gopkg.in/yaml.v2 v2.4.0

16
go.sum
View file

@ -983,10 +983,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed h1:R8EiLWArq7KT96DrUq1xq9scPh8vLwKKeCTnORPyjhU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY=
github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE=
github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6 h1:CqbM5tPbF1mV2h1J6N0NSF8et6HvFlwA98TLCUcjiSI=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220301125114-e6012a13a6e6/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
@ -1510,8 +1510,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo=
golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE=
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -1737,8 +1737,8 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc=
golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs=
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=

View file

@ -17,7 +17,7 @@ var build string
const (
VersionMajor = 0
VersionMinor = 6
VersionPatch = 3
VersionPatch = 4
VersionTag = "" // example: "rc1"
)

View file

@ -15,6 +15,7 @@
package api
import (
"bytes"
"context"
"encoding/json"
"strings"
@ -73,6 +74,26 @@ type DeviceMessage struct {
DeviceChangeID int64
}
// DeviceKeysEqual returns true if the device keys updates contain the
// same display name and key JSON. This will return false if either of
// the updates is not a device keys update, or if the user ID/device ID
// differ between the two.
func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool {
if m1.DeviceKeys == nil || m2.DeviceKeys == nil {
return false
}
if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID {
return false
}
if m1.DisplayName != m2.DisplayName {
return false // different display names
}
if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 {
return false // either is empty
}
return bytes.Equal(m1.KeyJSON, m2.KeyJSON)
}
// DeviceKeys represents a set of device keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type DeviceKeys struct {

View file

@ -166,26 +166,53 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
}
// We can't have a self-signing or user-signing key without a master
// key, so make sure we have one of those.
if !hasMasterKey {
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
if err != nil {
res.Error = &api.KeyError{
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
}
return
// key, so make sure we have one of those. We will also only actually do
// something if any of the specified keys in the request are different
// to what we've got in the database, to avoid generating key change
// notifications unnecessarily.
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
if err != nil {
res.Error = &api.KeyError{
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
}
_, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]
return
}
// If we still can't find a master key for the user then stop the upload.
// This satisfies the "Fails to upload self-signing key without master key" test.
if !hasMasterKey {
res.Error = &api.KeyError{
Err: "No master key was found",
IsMissingParam: true,
if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey {
res.Error = &api.KeyError{
Err: "No master key was found",
IsMissingParam: true,
}
return
}
}
// Check if anything actually changed compared to what we have in the database.
changed := false
for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{
gomatrixserverlib.CrossSigningKeyPurposeMaster,
gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
} {
old, gotOld := existingKeys[purpose]
new, gotNew := toStore[purpose]
if gotOld != gotNew {
// A new key purpose has been specified that we didn't know before,
// or one has been removed.
changed = true
break
}
if !bytes.Equal(old, new) {
// One of the existing keys for a purpose we already knew about has
// changed.
changed = true
break
}
}
if !changed {
return
}

View file

@ -224,7 +224,7 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
}).Info("DeviceListUpdater.Update")
// if we haven't missed anything update the database and notify users
if exists {
if exists || event.Deleted {
k := event.Keys
if event.Deleted {
k = nil
@ -267,7 +267,7 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
}
if err = emitDeviceKeyChanges(u.producer, existingKeys, keys); err != nil {
if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil {
return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err)
}
return false, nil
@ -473,7 +473,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
if err != nil {
return fmt.Errorf("failed to mark device list as fresh: %w", err)
}
err = emitDeviceKeyChanges(u.producer, existingKeys, keys)
err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false)
if err != nil {
return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
}

View file

@ -648,7 +648,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}
return
}
err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore)
err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
if err != nil {
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
}
@ -710,7 +710,11 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
}
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage) error {
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
// if we only want to update the display names, we can skip the checks below
if onlyUpdateDisplayName {
return producer.ProduceKeyChanges(new)
}
// find keys in new that are not in existing
var keysAdded []api.DeviceMessage
for _, newKey := range new {
@ -718,7 +722,7 @@ func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.Device
for _, existingKey := range existing {
// Do not treat the absence of keys as equal, or else we will not emit key changes
// when users delete devices which never had a key to begin with as both KeyJSONs are nil.
if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) && len(existingKey.KeyJSON) > 0 {
if existingKey.DeviceKeysEqual(&newKey) {
exists = true
break
}

View file

@ -269,6 +269,7 @@ type QueryAuthChainResponse struct {
type QuerySharedUsersRequest struct {
UserID string
OtherUserIDs []string
ExcludeRoomIDs []string
IncludeRoomIDs []string
}
@ -312,7 +313,10 @@ type QueryBulkStateContentResponse struct {
}
type QueryCurrentStateRequest struct {
RoomID string
RoomID string
AllowWildcards bool
// State key tuples. If a state_key has '*' and AllowWidlcards is true, returns all matching
// state events with that event type.
StateTuples []gomatrixserverlib.StateKeyTuple
}

View file

@ -51,12 +51,8 @@ func SendEventWithState(
state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent,
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
) error {
outliers, err := state.Events(event.RoomVersion)
if err != nil {
return err
}
var ires []InputRoomEvent
outliers := state.Events(event.RoomVersion)
ires := make([]InputRoomEvent, 0, len(outliers))
for _, outlier := range outliers {
if haveEventIDs[outlier.EventID()] {
continue

View file

@ -20,22 +20,17 @@ import (
"sort"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type checkForAuthAndSoftFailStorage interface {
state.StateResolutionStorage
StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
}
// CheckForSoftFail returns true if the event should be soft-failed
// and false otherwise. The return error value should be checked before
// the soft-fail bool.
func CheckForSoftFail(
ctx context.Context,
db checkForAuthAndSoftFailStorage,
db storage.Database,
event *gomatrixserverlib.HeaderedEvent,
stateEventIDs []string,
) (bool, error) {
@ -97,7 +92,7 @@ func CheckForSoftFail(
// Returns the numeric IDs for the auth events.
func CheckAuthEvents(
ctx context.Context,
db checkForAuthAndSoftFailStorage,
db storage.Database,
event *gomatrixserverlib.HeaderedEvent,
authEventIDs []string,
) ([]types.EventNID, error) {

View file

@ -19,7 +19,6 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
@ -40,19 +39,6 @@ import (
"github.com/tidwall/gjson"
)
type retryAction int
type commitAction int
const (
doNotRetry retryAction = iota
retryLater
)
const (
commitTransaction commitAction = iota
rollbackTransaction
)
var keyContentFields = map[string]string{
"m.room.join_rules": "join_rule",
"m.room.history_visibility": "history_visibility",
@ -117,8 +103,7 @@ func (r *Inputer) Start() error {
_ = msg.InProgress() // resets the acknowledgement wait timer
defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent)
if err != nil {
if err := r.processRoomEvent(r.ProcessContext.Context(), &inputRoomEvent); err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err)
}
@ -127,11 +112,8 @@ func (r *Inputer) Start() error {
"event_id": inputRoomEvent.Event.EventID(),
"type": inputRoomEvent.Event.Type(),
}).Warn("Roomserver failed to process async event")
}
switch action {
case retryLater:
_ = msg.Nak()
case doNotRetry:
_ = msg.Term()
} else {
_ = msg.Ack()
}
})
@ -153,37 +135,6 @@ func (r *Inputer) Start() error {
return err
}
// processRoomEventUsingUpdater opens up a room updater and tries to
// process the event. It returns whether or not we should positively
// or negatively acknowledge the event (i.e. for NATS) and an error
// if it occurred.
func (r *Inputer) processRoomEventUsingUpdater(
ctx context.Context,
roomID string,
inputRoomEvent *api.InputRoomEvent,
) (retryAction, error) {
roomInfo, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err)
}
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
}
action, err := r.processRoomEvent(ctx, updater, inputRoomEvent)
switch action {
case commitTransaction:
if cerr := updater.Commit(); cerr != nil {
return retryLater, fmt.Errorf("updater.Commit: %w", cerr)
}
case rollbackTransaction:
if rerr := updater.Rollback(); rerr != nil {
return retryLater, fmt.Errorf("updater.Rollback: %w", rerr)
}
}
return doNotRetry, err
}
// InputRoomEvents implements api.RoomserverInternalAPI
func (r *Inputer) InputRoomEvents(
ctx context.Context,
@ -230,7 +181,7 @@ func (r *Inputer) InputRoomEvents(
worker.Act(nil, func() {
defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
_, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent)
err := r.processRoomEvent(ctx, &inputRoomEvent)
if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err)

View file

@ -26,10 +26,10 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -68,15 +68,14 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
// nolint:gocyclo
func (r *Inputer) processRoomEvent(
ctx context.Context,
updater *shared.RoomUpdater,
input *api.InputRoomEvent,
) (commitAction, error) {
) error {
select {
case <-ctx.Done():
// Before we do anything, make sure the context hasn't expired for this pending task.
// If it has then we'll give up straight away — it's probably a synchronous input
// request and the caller has already given up, but the inbox task was still queued.
return rollbackTransaction, context.DeadlineExceeded
return context.DeadlineExceeded
default:
}
@ -109,7 +108,7 @@ func (r *Inputer) processRoomEvent(
// if we have already got this event then do not process it again, if the input kind is an outlier.
// Outliers contain no extra information which may warrant a re-processing.
if input.Kind == api.KindOutlier {
evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()})
evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()})
if err2 == nil && len(evs) == 1 {
// check hash matches if we're on early room versions where the event ID was a random string
idFormat, err2 := headered.RoomVersion.EventIDFormat()
@ -118,11 +117,11 @@ func (r *Inputer) processRoomEvent(
case gomatrixserverlib.EventIDFormatV1:
if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) {
logger.Debugf("Already processed event; ignoring")
return rollbackTransaction, nil
return nil
}
default:
logger.Debugf("Already processed event; ignoring")
return rollbackTransaction, nil
return nil
}
}
}
@ -131,17 +130,21 @@ func (r *Inputer) processRoomEvent(
// Don't waste time processing the event if the room doesn't exist.
// A room entry locally will only be created in response to a create
// event.
roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID())
if rerr != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", rerr)
}
isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("")
if !updater.RoomExists() && !isCreateEvent {
return rollbackTransaction, fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
}
var missingAuth, missingPrev bool
serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{}
if !isCreateEvent {
missingAuthIDs, missingPrevIDs, err := updater.MissingAuthPrevEvents(ctx, event)
missingAuthIDs, missingPrevIDs, err := r.DB.MissingAuthPrevEvents(ctx, event)
if err != nil {
return rollbackTransaction, fmt.Errorf("updater.MissingAuthPrevEvents: %w", err)
return fmt.Errorf("updater.MissingAuthPrevEvents: %w", err)
}
missingAuth = len(missingAuthIDs) > 0
missingPrev = !input.HasState && len(missingPrevIDs) > 0
@ -153,7 +156,7 @@ func (r *Inputer) processRoomEvent(
ExcludeSelf: true,
}
if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
}
// Sort all of the servers into a map so that we can randomise
// their order. Then make sure that the input origin and the
@ -182,8 +185,8 @@ func (r *Inputer) processRoomEvent(
isRejected := false
authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{}
if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err)
if err := r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
return fmt.Errorf("r.fetchAuthEvents: %w", err)
}
// Check if the event is allowed by its auth events. If it isn't then
@ -205,12 +208,12 @@ func (r *Inputer) processRoomEvent(
// but weren't found.
if isRejected {
if event.StateKey() != nil {
return commitTransaction, fmt.Errorf(
return fmt.Errorf(
"missing auth event %s for state event %s (type %q, state key %q)",
authEventID, event.EventID(), event.Type(), *event.StateKey(),
)
} else {
return commitTransaction, fmt.Errorf(
return fmt.Errorf(
"missing auth event %s for timeline event %s (type %q)",
authEventID, event.EventID(), event.Type(),
)
@ -226,7 +229,7 @@ func (r *Inputer) processRoomEvent(
// Check that the event passes authentication checks based on the
// current room state.
var err error
softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs)
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
if err != nil {
logger.WithError(err).Warn("Error authing soft-failed event")
}
@ -250,7 +253,8 @@ func (r *Inputer) processRoomEvent(
missingState := missingStateReq{
origin: input.Origin,
inputer: r,
db: updater,
db: r.DB,
roomInfo: roomInfo,
federation: r.FSAPI,
keys: r.KeyRing,
roomsMu: internal.NewMutexByRoom(),
@ -290,16 +294,16 @@ func (r *Inputer) processRoomEvent(
}
// Store the event.
_, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected)
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected)
if err != nil {
return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err)
return fmt.Errorf("updater.StoreEvent: %w", err)
}
// if storing this event results in it being redacted then do so.
if !isRejected && redactedEventID == event.EventID() {
r, rerr := eventutil.RedactEvent(redactionEvent, event)
if rerr != nil {
return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr)
return fmt.Errorf("eventutil.RedactEvent: %w", rerr)
}
event = r
}
@ -310,23 +314,25 @@ func (r *Inputer) processRoomEvent(
if input.Kind == api.KindOutlier {
logger.Debug("Stored outlier")
hooks.Run(hooks.KindNewEventPersisted, headered)
return commitTransaction, nil
return nil
}
roomInfo, err := updater.RoomInfo(ctx, event.RoomID())
// Request the room info again — it's possible that the room has been
// created by now if it didn't exist already.
roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID())
if err != nil {
return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err)
return fmt.Errorf("updater.RoomInfo: %w", err)
}
if roomInfo == nil {
return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID())
return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID())
}
if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) {
// We haven't calculated a state for this event yet.
// Lets calculate one.
err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected)
err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected)
if err != nil {
return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err)
return fmt.Errorf("r.calculateAndSetState: %w", err)
}
}
@ -337,16 +343,15 @@ func (r *Inputer) processRoomEvent(
"missing_prev": missingPrev,
}).Warn("Stored rejected event")
if rejectionErr != nil {
return commitTransaction, types.RejectedError(rejectionErr.Error())
return types.RejectedError(rejectionErr.Error())
}
return commitTransaction, nil
return nil
}
switch input.Kind {
case api.KindNew:
if err = r.updateLatestEvents(
ctx, // context
updater, // room updater
roomInfo, // room info for the room being updated
stateAtEvent, // state at event (below)
event, // event
@ -354,7 +359,7 @@ func (r *Inputer) processRoomEvent(
input.TransactionID, // transaction ID
input.HasState, // rewrites state?
); err != nil {
return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err)
return fmt.Errorf("r.updateLatestEvents: %w", err)
}
case api.KindOld:
err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{
@ -366,7 +371,7 @@ func (r *Inputer) processRoomEvent(
},
})
if err != nil {
return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err)
return fmt.Errorf("r.WriteOutputEvents (old): %w", err)
}
}
@ -385,14 +390,14 @@ func (r *Inputer) processRoomEvent(
},
})
if err != nil {
return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
}
}
// Everything was OK — the latest events updater didn't error and
// we've sent output events. Finally, generate a hook call.
hooks.Run(hooks.KindNewEventPersisted, headered)
return commitTransaction, nil
return nil
}
// fetchAuthEvents will check to see if any of the
@ -404,7 +409,6 @@ func (r *Inputer) processRoomEvent(
// they are now in the database.
func (r *Inputer) fetchAuthEvents(
ctx context.Context,
updater *shared.RoomUpdater,
logger *logrus.Entry,
event *gomatrixserverlib.HeaderedEvent,
auth *gomatrixserverlib.AuthEvents,
@ -418,7 +422,7 @@ func (r *Inputer) fetchAuthEvents(
}
for _, authEventID := range authEventIDs {
authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID})
authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{}
continue
@ -495,7 +499,7 @@ nextAuthEvent:
}
// Finally, store the event in the database.
eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err)
}
@ -520,14 +524,18 @@ nextAuthEvent:
func (r *Inputer) calculateAndSetState(
ctx context.Context,
updater *shared.RoomUpdater,
input *api.InputRoomEvent,
roomInfo *types.RoomInfo,
stateAtEvent *types.StateAtEvent,
event *gomatrixserverlib.Event,
isRejected bool,
) error {
var err error
var succeeded bool
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
roomState := state.NewStateResolution(updater, roomInfo)
if input.HasState {
@ -536,7 +544,7 @@ func (r *Inputer) calculateAndSetState(
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.
var entries []types.StateEntry
if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
}
entries = types.DeduplicateStateEntries(entries)
@ -557,5 +565,6 @@ func (r *Inputer) calculateAndSetState(
if err != nil {
return fmt.Errorf("r.DB.SetState: %w", err)
}
succeeded = true
return nil
}

View file

@ -20,6 +20,7 @@ import (
"context"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
@ -47,7 +48,6 @@ import (
// Can only be called once at a time
func (r *Inputer) updateLatestEvents(
ctx context.Context,
updater *shared.RoomUpdater,
roomInfo *types.RoomInfo,
stateAtEvent types.StateAtEvent,
event *gomatrixserverlib.Event,
@ -55,6 +55,14 @@ func (r *Inputer) updateLatestEvents(
transactionID *api.TransactionID,
rewritesState bool,
) (err error) {
var succeeded bool
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
u := latestEventsUpdater{
ctx: ctx,
api: r,
@ -71,6 +79,7 @@ func (r *Inputer) updateLatestEvents(
return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
}
succeeded = true
return
}

View file

@ -11,7 +11,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -23,9 +23,25 @@ type parsedRespState struct {
StateEvents []*gomatrixserverlib.Event
}
func (p *parsedRespState) Events() []*gomatrixserverlib.Event {
eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents))
for i, event := range p.AuthEvents {
eventsByID[event.EventID()] = p.AuthEvents[i]
}
for i, event := range p.StateEvents {
eventsByID[event.EventID()] = p.StateEvents[i]
}
allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID))
for _, event := range eventsByID {
allEvents = append(allEvents, event)
}
return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents)
}
type missingStateReq struct {
origin gomatrixserverlib.ServerName
db *shared.RoomUpdater
db storage.Database
roomInfo *types.RoomInfo
inputer *Inputer
keys gomatrixserverlib.JSONVerifier
federation fedapi.FederationInternalAPI
@ -80,7 +96,7 @@ func (t *missingStateReq) processEventWithMissingState(
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
// in the gap in the DAG
for _, newEvent := range newEvents {
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
Kind: api.KindOld,
Event: newEvent.Headered(roomVersion),
Origin: t.origin,
@ -123,11 +139,8 @@ func (t *missingStateReq) processEventWithMissingState(
t.hadEventsMutex.Unlock()
sendOutliers := func(resolvedState *parsedRespState) error {
outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion)
if oerr != nil {
return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr)
}
var outlierRoomEvents []api.InputRoomEvent
outliers := resolvedState.Events()
outlierRoomEvents := make([]api.InputRoomEvent, 0, len(outliers))
for _, outlier := range outliers {
if hadEvents[outlier.EventID()] {
continue
@ -139,8 +152,7 @@ func (t *missingStateReq) processEventWithMissingState(
})
}
for _, ire := range outlierRoomEvents {
_, err = t.inputer.processRoomEvent(ctx, t.db, &ire)
if err != nil {
if err = t.inputer.processRoomEvent(ctx, &ire); err != nil {
if _, ok := err.(types.RejectedError); !ok {
return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err)
}
@ -163,7 +175,7 @@ func (t *missingStateReq) processEventWithMissingState(
stateIDs = append(stateIDs, event.EventID())
}
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
Kind: api.KindOld,
Event: backwardsExtremity.Headered(roomVersion),
Origin: t.origin,
@ -182,7 +194,7 @@ func (t *missingStateReq) processEventWithMissingState(
// they will automatically fast-forward based on the room state at the
// extremity in the last step.
for _, newEvent := range newEvents {
_, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
Kind: api.KindOld,
Event: newEvent.Headered(roomVersion),
Origin: t.origin,
@ -473,8 +485,10 @@ retryAllowedState:
// without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events
func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) {
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
latest := t.db.LatestEvents()
latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID)
if err != nil {
return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err)
}
latestEvents := make([]string, len(latest))
for i, ev := range latest {
latestEvents[i] = ev.EventID

View file

@ -621,12 +621,25 @@ func (r *Queryer) QueryPublishedRooms(
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
for _, tuple := range req.StateTuples {
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
if err != nil {
return err
}
if ev != nil {
res.StateEvents[tuple] = ev
if tuple.StateKey == "*" && req.AllowWildcards {
events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType)
if err != nil {
return err
}
for _, e := range events {
res.StateEvents[gomatrixserverlib.StateKeyTuple{
EventType: e.Type(),
StateKey: *e.StateKey(),
}] = e
}
} else {
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
if err != nil {
return err
}
if ev != nil {
res.StateEvents[tuple] = ev
}
}
}
return nil
@ -696,7 +709,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
}
roomIDs = roomIDs[:j]
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs)
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
if err != nil {
return err
}

View file

@ -814,6 +814,7 @@ func (v *StateResolution) resolveConflictsV2(
// events may be duplicated across these sets but that's OK.
authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted))
authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3)
gotAuthEvents := make(map[string]struct{}, estimate*3)
authDifference := make([]*gomatrixserverlib.Event, 0, estimate)
// For each conflicted event, let's try and get the needed auth events.
@ -850,9 +851,22 @@ func (v *StateResolution) resolveConflictsV2(
if err != nil {
return nil, err
}
authEvents = append(authEvents, authSets[key]...)
// Only add auth events into the authEvents slice once, otherwise the
// check for the auth difference can become expensive and produce
// duplicate entries, which just waste memory and CPU time.
for _, event := range authSets[key] {
if _, ok := gotAuthEvents[event.EventID()]; !ok {
authEvents = append(authEvents, event)
gotAuthEvents[event.EventID()] = struct{}{}
}
}
}
// Kill the reference to this so that the GC may pick it up, since we no
// longer need this after this point.
gotAuthEvents = nil // nolint:ineffassign
// This function helps us to work out whether an event exists in one of the
// auth sets.
isInAuthList := func(k string, event *gomatrixserverlib.Event) bool {
@ -866,11 +880,12 @@ func (v *StateResolution) resolveConflictsV2(
// This function works out if an event exists in all of the auth sets.
isInAllAuthLists := func(event *gomatrixserverlib.Event) bool {
found := true
for k := range authSets {
found = found && isInAuthList(k, event)
if !isInAuthList(k, event) {
return false
}
}
return found
return true
}
// Look through all of the auth events that we've been given and work out if

View file

@ -35,6 +35,11 @@ type Database interface {
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (types.StateSnapshotNID, error)
MissingAuthPrevEvents(
ctx context.Context, e *gomatrixserverlib.Event,
) (missingAuth, missingPrev []string, err error)
// Look up the state of a room at each event for a list of string event IDs.
// Returns an error if there is an error talking to the database.
// The length of []types.StateAtEvent is guaranteed to equal the length of eventIDs if no error is returned.
@ -141,13 +146,14 @@ type Database interface {
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error)
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.

View file

@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
`
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership(
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i])
}
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
if err != nil {
return nil, err
}

View file

@ -103,25 +103,6 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}
func (u *RoomUpdater) MissingAuthPrevEvents(
ctx context.Context, e *gomatrixserverlib.Event,
) (missingAuth, missingPrev []string, err error) {
for _, authEventID := range e.AuthEventIDs() {
if nids, err := u.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 {
missingAuth = append(missingAuth, authEventID)
}
}
for _, prevEventID := range e.PrevEventIDs() {
state, err := u.StateAtEventIDs(ctx, []string{prevEventID})
if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) {
missingPrev = append(missingPrev, prevEventID)
}
}
return
}
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
@ -146,13 +127,6 @@ func (u *RoomUpdater) SnapshotNIDFromEventID(
return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID)
}
func (u *RoomUpdater) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected)
}
func (u *RoomUpdater) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
@ -212,44 +186,16 @@ func (u *RoomUpdater) EventIDs(
return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs)
}
func (u *RoomUpdater) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter)
}
func (u *RoomUpdater) UnsentEventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly)
}
func (u *RoomUpdater) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false)
}
func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true)
}
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly)
}
// IsReferenced implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)

View file

@ -674,6 +674,29 @@ func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
}
func (d *Database) MissingAuthPrevEvents(
ctx context.Context, e *gomatrixserverlib.Event,
) (missingAuth, missingPrev []string, err error) {
authEventNIDs, err := d.EventNIDs(ctx, e.AuthEventIDs())
if err != nil {
return nil, nil, fmt.Errorf("d.EventNIDs: %w", err)
}
for _, authEventID := range e.AuthEventIDs() {
if _, ok := authEventNIDs[authEventID]; !ok {
missingAuth = append(missingAuth, authEventID)
}
}
for _, prevEventID := range e.PrevEventIDs() {
state, err := d.StateAtEventIDs(ctx, []string{prevEventID})
if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) {
missingPrev = append(missingPrev, prevEventID)
}
}
return
}
func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion,
@ -956,6 +979,62 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
return nil, nil
}
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error
// if there are no events with this event type.
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
roomInfo, err := d.RoomInfo(ctx, roomID)
if err != nil {
return nil, err
}
if roomInfo == nil {
return nil, fmt.Errorf("room %s doesn't exist", roomID)
}
// e.g invited rooms
if roomInfo.IsStub {
return nil, nil
}
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
if err == sql.ErrNoRows {
// No rooms have an event of this type, otherwise we'd have an event type NID
return nil, nil
}
if err != nil {
return nil, err
}
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
if err != nil {
return nil, err
}
var eventNIDs []types.EventNID
for _, e := range entries {
if e.EventTypeNID == eventTypeNID {
eventNIDs = append(eventNIDs, e.EventNID)
}
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
// return the events requested
eventPairs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil {
return nil, err
}
if len(eventPairs) == 0 {
return nil, nil
}
var result []*gomatrixserverlib.HeaderedEvent
for _, pair := range eventPairs {
ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion)
if err != nil {
return nil, err
}
result = append(result, ev.Headered(roomInfo.RoomVersion))
}
return result, nil
}
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
var membershipState tables.MembershipState
@ -1081,13 +1160,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
return result, nil
}
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil {
return nil, err
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
if err != nil {
return nil, err
}
userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap))
nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap))
for id, nid := range userNIDsMap {
userNIDs = append(userNIDs, nid)
nidToUserID[nid] = id
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
if err != nil {
return nil, err
}
@ -1097,10 +1186,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid
i++
}
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil {
return nil, err
}
if len(nidToUserID) != len(userNIDToCount) {
logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID))
}

View file

@ -42,7 +42,8 @@ const membershipSchema = `
`
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
for _, v := range roomNIDs {
params = append(params, v)
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
for _, v := range userNIDs {
params = append(params, v)
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
rows, err = txn.QueryContext(ctx, query, params...)
} else {
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
rows, err = s.db.QueryContext(ctx, query, params...)
}
if err != nil {
return nil, err

View file

@ -127,9 +127,8 @@ type Membership interface {
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)

View file

@ -654,11 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
AuthEvents: res.AuthChain,
StateEvents: stateEvents,
}
eventsInOrder, err := respState.Events(rc.roomVersion)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse")
return
}
eventsInOrder := respState.Events(rc.roomVersion)
// everything gets sent as an outlier because auth chain events may be disjoint from the DAG
// as may the threaded events.
var ires []roomserver.InputRoomEvent
@ -669,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
})
}
// we've got the data by this point so use a background context
err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false)
err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver")
}

View file

@ -18,17 +18,18 @@ package msc2946
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
chttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
fs "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
@ -39,15 +40,15 @@ import (
)
const (
ConstCreateEventContentKey = "type"
ConstSpaceChildEventType = "m.space.child"
ConstSpaceParentEventType = "m.space.parent"
ConstCreateEventContentKey = "type"
ConstCreateEventContentValueSpace = "m.space"
ConstSpaceChildEventType = "m.space.child"
ConstSpaceParentEventType = "m.space.parent"
)
// Defaults sets the request defaults
func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) {
r.Limit = 2000
r.MaxRoomsPerSpace = -1
type MSC2946ClientResponse struct {
Rooms []gomatrixserverlib.MSC2946Room `json:"rooms"`
NextBatch string `json:"next_batch,omitempty"`
}
// Enable this MSC
@ -55,26 +56,11 @@ func Enable(
base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
) error {
db, err := NewDatabase(&base.Cfg.MSCs.Database)
if err != nil {
return fmt.Errorf("cannot enable MSC2946: %w", err)
}
hooks.Enable()
hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
hookErr := db.StoreReference(context.Background(), he)
if hookErr != nil {
util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error(
"failed to StoreReference",
)
}
})
clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, base.Cfg.Global.ServerName))
base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces",
httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)),
).Methods(http.MethodPost, http.MethodOptions)
base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI(
fedAPI := httputil.MakeExternalAPI(
"msc2946_fed_spaces", func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), base.Cfg.Global.ServerName, keyRing,
@ -88,105 +74,99 @@ func Enable(
return util.ErrorResponse(err)
}
roomID := params["roomID"]
return federatedSpacesHandler(req.Context(), fedReq, roomID, db, rsAPI, fsAPI, base.Cfg.Global.ServerName)
return federatedSpacesHandler(req.Context(), fedReq, roomID, rsAPI, fsAPI, base.Cfg.Global.ServerName)
},
)).Methods(http.MethodPost, http.MethodOptions)
)
base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
base.PublicFederationAPIMux.Handle("/v1/hierarchy/{roomID}", fedAPI).Methods(http.MethodGet)
return nil
}
func federatedSpacesHandler(
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, db Database,
ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string,
rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
thisServer gomatrixserverlib.ServerName,
) util.JSONResponse {
inMemoryBatchCache := make(map[string]set)
var r gomatrixserverlib.MSC2946SpacesRequest
Defaults(&r)
if err := json.Unmarshal(fedReq.Content(), &r); err != nil {
u, err := url.Parse(fedReq.RequestURI())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
Code: 400,
JSON: jsonerror.InvalidParam("bad request uri"),
}
}
w := walker{
req: &r,
rootRoomID: roomID,
serverName: fedReq.Origin(),
thisServer: thisServer,
ctx: ctx,
db: db,
rsAPI: rsAPI,
fsAPI: fsAPI,
inMemoryBatchCache: inMemoryBatchCache,
}
res := w.walk()
return util.JSONResponse{
Code: 200,
JSON: res,
w := walker{
rootRoomID: roomID,
serverName: fedReq.Origin(),
thisServer: thisServer,
ctx: ctx,
suggestedOnly: u.Query().Get("suggested_only") == "true",
limit: 1000,
// The main difference is that it does not recurse into spaces and does not support pagination.
// This is somewhat equivalent to a Client-Server request with a max_depth=1.
maxDepth: 1,
rsAPI: rsAPI,
fsAPI: fsAPI,
// inline cache as we don't have pagination in federation mode
paginationCache: make(map[string]paginationInfo),
}
return w.walk()
}
func spacesHandler(
db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
thisServer gomatrixserverlib.ServerName,
) func(*http.Request, *userapi.Device) util.JSONResponse {
// declared outside the returned handler so it persists between calls
// TODO: clear based on... time?
paginationCache := make(map[string]paginationInfo)
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
inMemoryBatchCache := make(map[string]set)
// Extract the room ID from the request. Sanity check request data.
params, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
roomID := params["roomID"]
var r gomatrixserverlib.MSC2946SpacesRequest
Defaults(&r)
if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr
}
w := walker{
req: &r,
rootRoomID: roomID,
caller: device,
thisServer: thisServer,
ctx: req.Context(),
suggestedOnly: req.URL.Query().Get("suggested_only") == "true",
limit: parseInt(req.URL.Query().Get("limit"), 1000),
maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1),
paginationToken: req.URL.Query().Get("from"),
rootRoomID: roomID,
caller: device,
thisServer: thisServer,
ctx: req.Context(),
db: db,
rsAPI: rsAPI,
fsAPI: fsAPI,
inMemoryBatchCache: inMemoryBatchCache,
}
res := w.walk()
return util.JSONResponse{
Code: 200,
JSON: res,
rsAPI: rsAPI,
fsAPI: fsAPI,
paginationCache: paginationCache,
}
return w.walk()
}
}
type paginationInfo struct {
processed set
unvisited []roomVisit
}
type walker struct {
req *gomatrixserverlib.MSC2946SpacesRequest
rootRoomID string
caller *userapi.Device
serverName gomatrixserverlib.ServerName
thisServer gomatrixserverlib.ServerName
db Database
rsAPI roomserver.RoomserverInternalAPI
fsAPI fs.FederationInternalAPI
ctx context.Context
rootRoomID string
caller *userapi.Device
serverName gomatrixserverlib.ServerName
thisServer gomatrixserverlib.ServerName
rsAPI roomserver.RoomserverInternalAPI
fsAPI fs.FederationInternalAPI
ctx context.Context
suggestedOnly bool
limit int
maxDepth int
paginationToken string
// user ID|device ID|batch_num => event/room IDs sent to client
inMemoryBatchCache map[string]set
mu sync.Mutex
}
func (w *walker) roomIsExcluded(roomID string) bool {
for _, exclRoom := range w.req.ExcludeRooms {
if exclRoom == roomID {
return true
}
}
return false
paginationCache map[string]paginationInfo
mu sync.Mutex
}
func (w *walker) callerID() string {
@ -196,144 +176,207 @@ func (w *walker) callerID() string {
return string(w.serverName)
}
func (w *walker) alreadySent(id string) bool {
w.mu.Lock()
defer w.mu.Unlock()
m, ok := w.inMemoryBatchCache[w.callerID()]
if !ok {
return false
func (w *walker) newPaginationCache() (string, paginationInfo) {
p := paginationInfo{
processed: make(set),
unvisited: nil,
}
return m[id]
tok := uuid.NewString()
return tok, p
}
func (w *walker) markSent(id string) {
func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo {
w.mu.Lock()
defer w.mu.Unlock()
m := w.inMemoryBatchCache[w.callerID()]
if m == nil {
m = make(set)
}
m[id] = true
w.inMemoryBatchCache[w.callerID()] = m
p := w.paginationCache[paginationToken]
return &p
}
func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse {
var res gomatrixserverlib.MSC2946SpacesResponse
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
unvisited := []string{w.rootRoomID}
processed := make(set)
func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) {
w.mu.Lock()
defer w.mu.Unlock()
w.paginationCache[paginationToken] = cache
}
type roomVisit struct {
roomID string
depth int
vias []string // vias to query this room by
}
func (w *walker) walk() util.JSONResponse {
if !w.authorised(w.rootRoomID) {
if w.caller != nil {
// CS API format
return util.JSONResponse{
Code: 403,
JSON: jsonerror.Forbidden("room is unknown/forbidden"),
}
} else {
// SS API format
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("room is unknown/forbidden"),
}
}
}
var discoveredRooms []gomatrixserverlib.MSC2946Room
var cache *paginationInfo
if w.paginationToken != "" {
cache = w.loadPaginationCache(w.paginationToken)
if cache == nil {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue("invalid from"),
}
}
} else {
tok, c := w.newPaginationCache()
cache = &c
w.paginationToken = tok
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
c.unvisited = append(c.unvisited, roomVisit{
roomID: w.rootRoomID,
depth: 0,
})
}
processed := cache.processed
unvisited := cache.unvisited
// Depth first -> stack data structure
for len(unvisited) > 0 {
roomID := unvisited[0]
unvisited = unvisited[1:]
// If this room has already been processed, skip. NB: do not remember this between calls
if processed[roomID] || roomID == "" {
if len(discoveredRooms) >= w.limit {
break
}
// pop the stack
rv := unvisited[len(unvisited)-1]
unvisited = unvisited[:len(unvisited)-1]
// If this room has already been processed, skip.
// If this room exceeds the specified depth, skip.
if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) {
continue
}
// Mark this room as processed.
processed[roomID] = true
processed.set(rv.roomID)
// if this room is not a space room, skip.
var roomType string
create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "")
if create != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
}
// Collect rooms/events to send back (either locally or fetched via federation)
var discoveredRooms []gomatrixserverlib.MSC2946Room
var discoveredEvents []gomatrixserverlib.MSC2946StrippedEvent
var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent
// If we know about this room and the caller is authorised (joined/world_readable) then pull
// events locally
if w.roomExists(roomID) && w.authorised(roomID) {
// Get all `m.space.child` and `m.space.parent` state events for the room. *In addition*, get
// all `m.space.child` and `m.space.parent` state events which *point to* (via `state_key` or `content.room_id`)
// this room. This requires servers to store reverse lookups.
events, err := w.references(roomID)
if w.roomExists(rv.roomID) && w.authorised(rv.roomID) {
// Get all `m.space.child` state events for this room
events, err := w.childReferences(rv.roomID)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Error("failed to extract references for room")
util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Error("failed to extract references for room")
continue
}
discoveredEvents = events
discoveredChildEvents = events
pubRoom := w.publicRoomsChunk(roomID)
roomType := ""
create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "")
if create != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
}
pubRoom := w.publicRoomsChunk(rv.roomID)
// Add the total number of events to `PublicRoomsChunk` under `num_refs`. Add `PublicRoomsChunk` to `rooms`.
discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{
PublicRoom: *pubRoom,
NumRefs: len(discoveredEvents),
RoomType: roomType,
PublicRoom: *pubRoom,
RoomType: roomType,
ChildrenState: events,
})
} else {
// attempt to query this room over federation, as either we've never heard of it before
// or we've left it and hence are not authorised (but info may be exposed regardless)
fedRes, err := w.federatedRoomInfo(roomID)
fedRes, err := w.federatedRoomInfo(rv.roomID, rv.vias)
if err != nil {
util.GetLogger(w.ctx).WithError(err).WithField("room_id", roomID).Errorf("failed to query federated spaces")
util.GetLogger(w.ctx).WithError(err).WithField("room_id", rv.roomID).Errorf("failed to query federated spaces")
continue
}
if fedRes != nil {
discoveredRooms = fedRes.Rooms
discoveredEvents = fedRes.Events
discoveredChildEvents = fedRes.Room.ChildrenState
discoveredRooms = append(discoveredRooms, fedRes.Room)
if len(fedRes.Children) > 0 {
discoveredRooms = append(discoveredRooms, fedRes.Children...)
}
// mark this room as a space room as the federated server responded.
// we need to do this so we add the children of this room to the unvisited stack
// as these children may be rooms we do know about.
roomType = ConstCreateEventContentValueSpace
}
}
// If this room has not ever been in `rooms` (across multiple requests), send it now
for _, room := range discoveredRooms {
if !w.alreadySent(room.RoomID) && !w.roomIsExcluded(room.RoomID) {
res.Rooms = append(res.Rooms, room)
w.markSent(room.RoomID)
}
// don't walk the children
// if the parent is not a space room
if roomType != ConstCreateEventContentValueSpace {
continue
}
uniqueRooms := make(set)
// If this is the root room from the original request, insert all these events into `events` if
// they haven't been added before (across multiple requests).
if w.rootRoomID == roomID {
for _, ev := range discoveredEvents {
if !w.alreadySent(eventKey(&ev)) {
res.Events = append(res.Events, ev)
uniqueRooms[ev.RoomID] = true
uniqueRooms[spaceTargetStripped(&ev)] = true
w.markSent(eventKey(&ev))
}
}
} else {
// Else add them to `events` honouring the `limit` and `max_rooms_per_space` values. If either
// are exceeded, stop adding events. If the event has already been added, do not add it again.
numAdded := 0
for _, ev := range discoveredEvents {
if w.req.Limit > 0 && len(res.Events) >= w.req.Limit {
break
}
if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace {
break
}
if w.alreadySent(eventKey(&ev)) {
continue
}
// Skip the room if it's part of exclude_rooms but ONLY IF the source matches, as we still
// want to catch arrows which point to excluded rooms.
if w.roomIsExcluded(ev.RoomID) {
continue
}
res.Events = append(res.Events, ev)
uniqueRooms[ev.RoomID] = true
uniqueRooms[spaceTargetStripped(&ev)] = true
w.markSent(eventKey(&ev))
// we don't distinguish between child state events and parent state events for the purposes of
// max_rooms_per_space, maybe we should?
numAdded++
}
}
// For each referenced room ID in the events being returned to the caller (both parent and child)
// For each referenced room ID in the child events being returned to the caller
// add the room ID to the queue of unvisited rooms. Loop from the beginning.
for roomID := range uniqueRooms {
unvisited = append(unvisited, roomID)
// We need to invert the order here because the child events are lo->hi on the timestamp,
// so we need to ensure we pop in the same lo->hi order, which won't be the case if we
// insert the highest timestamp last in a stack.
for i := len(discoveredChildEvents) - 1; i >= 0; i-- {
spaceContent := struct {
Via []string `json:"via"`
}{}
ev := discoveredChildEvents[i]
_ = json.Unmarshal(ev.Content, &spaceContent)
unvisited = append(unvisited, roomVisit{
roomID: ev.StateKey,
depth: rv.depth + 1,
vias: spaceContent.Via,
})
}
}
return &res
if len(unvisited) > 0 {
// we still have more rooms so we need to send back a pagination token,
// we probably hit a room limit
cache.processed = processed
cache.unvisited = unvisited
w.storePaginationCache(w.paginationToken, *cache)
} else {
// clear the pagination token so we don't send it back to the client
// Note we do NOT nuke the cache just in case this response is lost
// and the client retries it.
w.paginationToken = ""
}
if w.caller != nil {
// return CS API format
return util.JSONResponse{
Code: 200,
JSON: MSC2946ClientResponse{
Rooms: discoveredRooms,
NextBatch: w.paginationToken,
},
}
}
// return SS API format
// the first discovered room will be the room asked for, and subsequent ones the depth=1 children
if len(discoveredRooms) == 0 {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("room is unknown/forbidden"),
}
}
return util.JSONResponse{
Code: 200,
JSON: gomatrixserverlib.MSC2946SpacesResponse{
Room: discoveredRooms[0],
Children: discoveredRooms[1:],
},
}
}
func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent {
@ -366,42 +409,19 @@ func (w *walker) publicRoomsChunk(roomID string) *gomatrixserverlib.PublicRoom {
// federatedRoomInfo returns more of the spaces graph from another server. Returns nil if this was
// unsuccessful.
func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
func (w *walker) federatedRoomInfo(roomID string, vias []string) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
// only do federated requests for client requests
if w.caller == nil {
return nil, nil
}
// extract events which point to this room ID and extract their vias
events, err := w.db.References(w.ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to get References events: %w", err)
}
vias := make(set)
for _, ev := range events {
if ev.StateKeyEquals(roomID) {
// event points at this room, extract vias
content := struct {
Vias []string `json:"via"`
}{}
if err = json.Unmarshal(ev.Content(), &content); err != nil {
continue // silently ignore corrupted state events
}
for _, v := range content.Vias {
vias[v] = true
}
}
}
util.GetLogger(w.ctx).Infof("Querying federatedRoomInfo via %+v", vias)
util.GetLogger(w.ctx).Infof("Querying %s via %+v", roomID, vias)
ctx := context.Background()
// query more of the spaces graph using these servers
for serverName := range vias {
for _, serverName := range vias {
if serverName == string(w.thisServer) {
continue
}
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, gomatrixserverlib.MSC2946SpacesRequest{
Limit: w.req.Limit,
MaxRoomsPerSpace: w.req.MaxRoomsPerSpace,
})
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName)
continue
@ -501,7 +521,7 @@ func (w *walker) authorisedUser(roomID string) bool {
hisVisEv := queryRes.StateEvents[hisVisTuple]
if memberEv != nil {
membership, _ := memberEv.Membership()
if membership == gomatrixserverlib.Join {
if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite {
return true
}
}
@ -514,40 +534,85 @@ func (w *walker) authorisedUser(roomID string) bool {
return false
}
// references returns all references pointing to or from this room.
func (w *walker) references(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) {
events, err := w.db.References(w.ctx, roomID)
// references returns all child references pointing to or from this room.
func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) {
createTuple := gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomCreate,
StateKey: "",
}
var res roomserver.QueryCurrentStateResponse
err := w.rsAPI.QueryCurrentState(context.Background(), &roomserver.QueryCurrentStateRequest{
RoomID: roomID,
AllowWildcards: true,
StateTuples: []gomatrixserverlib.StateKeyTuple{
createTuple, {
EventType: ConstSpaceChildEventType,
StateKey: "*",
},
},
}, &res)
if err != nil {
return nil, err
}
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events))
for _, ev := range events {
// don't return any child refs if the room is not a space room
if res.StateEvents[createTuple] != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType := gjson.GetBytes(res.StateEvents[createTuple].Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
if roomType != ConstCreateEventContentValueSpace {
return nil, nil
}
}
delete(res.StateEvents, createTuple)
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(res.StateEvents))
for _, ev := range res.StateEvents {
content := gjson.ParseBytes(ev.Content())
// only return events that have a `via` key as per MSC1772
// else we'll incorrectly walk redacted events (as the link
// is in the state_key)
if gjson.GetBytes(ev.Content(), "via").Exists() {
if content.Get("via").Exists() {
strip := stripped(ev.Event)
if strip == nil {
continue
}
// if suggested only and this child isn't suggested, skip it.
// if suggested only = false we include everything so don't need to check the content.
if w.suggestedOnly && !content.Get("suggested").Bool() {
continue
}
el = append(el, *strip)
}
}
// sort by origin_server_ts as per MSC2946
sort.Slice(el, func(i, j int) bool {
return el[i].OriginServerTS < el[j].OriginServerTS
})
return el, nil
}
type set map[string]bool
type set map[string]struct{}
func (s set) set(val string) {
s[val] = struct{}{}
}
func (s set) isSet(val string) bool {
_, ok := s[val]
return ok
}
func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent {
if ev.StateKey() == nil {
return nil
}
return &gomatrixserverlib.MSC2946StrippedEvent{
Type: ev.Type(),
StateKey: *ev.StateKey(),
Content: ev.Content(),
Sender: ev.Sender(),
RoomID: ev.RoomID(),
Type: ev.Type(),
StateKey: *ev.StateKey(),
Content: ev.Content(),
Sender: ev.Sender(),
RoomID: ev.RoomID(),
OriginServerTS: ev.OriginServerTS(),
}
}
@ -567,3 +632,11 @@ func spaceTargetStripped(event *gomatrixserverlib.MSC2946StrippedEvent) string {
}
return ""
}
func parseInt(intstr string, defaultVal int) int {
i, err := strconv.ParseInt(intstr, 10, 32)
if err != nil {
return defaultVal
}
return int(i)
}

View file

@ -1,464 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msc2946_test
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/json"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/mscs/msc2946"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
var (
client = &http.Client{
Timeout: 10 * time.Second,
}
roomVer = gomatrixserverlib.RoomVersionV6
)
// Basic sanity check of MSC2946 logic. Tests a single room with a few state events
// and a bit of recursion to subspaces. Makes a graph like:
// Root
// ____|_____
// | | |
// R1 R2 S1
// |_________
// | | |
// R3 R4 S2
// | <-- this link is just a parent, not a child
// R5
//
// Alice is not joined to R4, but R4 is "world_readable".
func TestMSC2946(t *testing.T) {
alice := "@alice:localhost"
// give access token to alice
nopUserAPI := &testUserAPI{
accessTokens: make(map[string]userapi.Device),
}
nopUserAPI.accessTokens["alice"] = userapi.Device{
AccessToken: "alice",
DisplayName: "Alice",
UserID: alice,
}
rootSpace := "!rootspace:localhost"
subSpaceS1 := "!subspaceS1:localhost"
subSpaceS2 := "!subspaceS2:localhost"
room1 := "!room1:localhost"
room2 := "!room2:localhost"
room3 := "!room3:localhost"
room4 := "!room4:localhost"
empty := ""
room5 := "!room5:localhost"
allRooms := []string{
rootSpace, subSpaceS1, subSpaceS2,
room1, room2, room3, room4, room5,
}
rootToR1 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room1,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
rootToR2 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
rootToS1 := mustCreateEvent(t, fledglingEvent{
RoomID: rootSpace,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &subSpaceS1,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToR3 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room3,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToR4 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room4,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
s1ToS2 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &subSpaceS2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
// This is a parent link only
s2ToR5 := mustCreateEvent(t, fledglingEvent{
RoomID: room5,
Sender: alice,
Type: msc2946.ConstSpaceParentEventType,
StateKey: &subSpaceS2,
Content: map[string]interface{}{
"via": []string{"localhost"},
},
})
// history visibility for R4
r4HisVis := mustCreateEvent(t, fledglingEvent{
RoomID: room4,
Sender: "@someone:localhost",
Type: gomatrixserverlib.MRoomHistoryVisibility,
StateKey: &empty,
Content: map[string]interface{}{
"history_visibility": "world_readable",
},
})
var joinEvents []*gomatrixserverlib.HeaderedEvent
for _, roomID := range allRooms {
if roomID == room4 {
continue // not joined to that room
}
joinEvents = append(joinEvents, mustCreateEvent(t, fledglingEvent{
RoomID: roomID,
Sender: alice,
StateKey: &alice,
Type: gomatrixserverlib.MRoomMember,
Content: map[string]interface{}{
"membership": "join",
},
}))
}
roomNameTuple := gomatrixserverlib.StateKeyTuple{
EventType: "m.room.name",
StateKey: "",
}
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: "m.room.history_visibility",
StateKey: "",
}
nopRsAPI := &testRoomserverAPI{
joinEvents: joinEvents,
events: map[string]*gomatrixserverlib.HeaderedEvent{
rootToR1.EventID(): rootToR1,
rootToR2.EventID(): rootToR2,
rootToS1.EventID(): rootToS1,
s1ToR3.EventID(): s1ToR3,
s1ToR4.EventID(): s1ToR4,
s1ToS2.EventID(): s1ToS2,
s2ToR5.EventID(): s2ToR5,
r4HisVis.EventID(): r4HisVis,
},
pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{
rootSpace: {
roomNameTuple: "Root",
hisVisTuple: "shared",
},
subSpaceS1: {
roomNameTuple: "Sub-Space 1",
hisVisTuple: "joined",
},
subSpaceS2: {
roomNameTuple: "Sub-Space 2",
hisVisTuple: "shared",
},
room1: {
hisVisTuple: "joined",
},
room2: {
hisVisTuple: "joined",
},
room3: {
hisVisTuple: "joined",
},
room4: {
hisVisTuple: "world_readable",
},
room5: {
hisVisTuple: "joined",
},
},
}
allEvents := []*gomatrixserverlib.HeaderedEvent{
rootToR1, rootToR2, rootToS1,
s1ToR3, s1ToR4, s1ToS2,
s2ToR5, r4HisVis,
}
allEvents = append(allEvents, joinEvents...)
router := injectEvents(t, nopUserAPI, nopRsAPI, allEvents)
cancel := runServer(t, router)
defer cancel()
t.Run("returns no events for unknown rooms", func(t *testing.T) {
res := postSpaces(t, 200, "alice", "!unknown:localhost", newReq(t, map[string]interface{}{}))
if len(res.Events) > 0 {
t.Errorf("got %d events, want 0", len(res.Events))
}
if len(res.Rooms) > 0 {
t.Errorf("got %d rooms, want 0", len(res.Rooms))
}
})
t.Run("returns the entire graph", func(t *testing.T) {
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
if len(res.Events) != 7 {
t.Errorf("got %d events, want 7", len(res.Events))
}
if len(res.Rooms) != len(allRooms) {
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms))
}
})
t.Run("can update the graph", func(t *testing.T) {
// remove R3 from the graph
rmS1ToR3 := mustCreateEvent(t, fledglingEvent{
RoomID: subSpaceS1,
Sender: alice,
Type: msc2946.ConstSpaceChildEventType,
StateKey: &room3,
Content: map[string]interface{}{}, // redacted
})
nopRsAPI.events[rmS1ToR3.EventID()] = rmS1ToR3
hooks.Run(hooks.KindNewEventPersisted, rmS1ToR3)
res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{}))
if len(res.Events) != 6 { // one less since we don't return redacted events
t.Errorf("got %d events, want 6", len(res.Events))
}
if len(res.Rooms) != (len(allRooms) - 1) { // one less due to lack of R3
t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)-1)
}
})
}
func newReq(t *testing.T, jsonBody map[string]interface{}) *gomatrixserverlib.MSC2946SpacesRequest {
t.Helper()
b, err := json.Marshal(jsonBody)
if err != nil {
t.Fatalf("Failed to marshal request: %s", err)
}
var r gomatrixserverlib.MSC2946SpacesRequest
if err := json.Unmarshal(b, &r); err != nil {
t.Fatalf("Failed to unmarshal request: %s", err)
}
return &r
}
func runServer(t *testing.T, router *mux.Router) func() {
t.Helper()
externalServ := &http.Server{
Addr: string(":8010"),
WriteTimeout: 60 * time.Second,
Handler: router,
}
go func() {
externalServ.ListenAndServe()
}()
// wait to listen on the port
time.Sleep(500 * time.Millisecond)
return func() {
externalServ.Shutdown(context.TODO())
}
}
func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *gomatrixserverlib.MSC2946SpacesRequest) *gomatrixserverlib.MSC2946SpacesResponse {
t.Helper()
var r gomatrixserverlib.MSC2946SpacesRequest
msc2946.Defaults(&r)
data, err := json.Marshal(req)
if err != nil {
t.Fatalf("failed to marshal request: %s", err)
}
httpReq, err := http.NewRequest(
"POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces",
bytes.NewBuffer(data),
)
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
if err != nil {
t.Fatalf("failed to prepare request: %s", err)
}
res, err := client.Do(httpReq)
if err != nil {
t.Fatalf("failed to do request: %s", err)
}
if res.StatusCode != expectCode {
body, _ := ioutil.ReadAll(res.Body)
t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
}
if res.StatusCode == 200 {
var result gomatrixserverlib.MSC2946SpacesResponse
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("response 200 OK but failed to read response body: %s", err)
}
t.Logf("Body: %s", string(body))
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body))
}
return &result
}
return nil
}
type testUserAPI struct {
userapi.UserInternalAPITrace
accessTokens map[string]userapi.Device
}
func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error {
dev, ok := u.accessTokens[req.AccessToken]
if !ok {
res.Err = "unknown token"
return nil
}
res.Device = &dev
return nil
}
type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here.
// We'll override the functions we care about.
roomserver.RoomserverInternalAPITrace
joinEvents []*gomatrixserverlib.HeaderedEvent
events map[string]*gomatrixserverlib.HeaderedEvent
pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string
}
func (r *testRoomserverAPI) QueryServerJoinedToRoom(ctx context.Context, req *roomserver.QueryServerJoinedToRoomRequest, res *roomserver.QueryServerJoinedToRoomResponse) error {
res.IsInRoom = true
res.RoomExists = true
return nil
}
func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error {
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
for _, roomID := range req.RoomIDs {
pubRoomData, ok := r.pubRoomState[roomID]
if ok {
res.Rooms[roomID] = pubRoomData
}
}
return nil
}
func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error {
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
checkEvent := func(he *gomatrixserverlib.HeaderedEvent) {
if he.RoomID() != req.RoomID {
return
}
if he.StateKey() == nil {
return
}
tuple := gomatrixserverlib.StateKeyTuple{
EventType: he.Type(),
StateKey: *he.StateKey(),
}
for _, t := range req.StateTuples {
if t == tuple {
res.StateEvents[t] = he
}
}
}
for _, he := range r.joinEvents {
checkEvent(he)
}
for _, he := range r.events {
checkEvent(he)
}
return nil
}
func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router {
t.Helper()
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.ServerName = "localhost"
cfg.MSCs.Database.ConnectionString = "file:msc2946_test.db"
cfg.MSCs.MSCs = []string{"msc2946"}
base := &base.BaseDendrite{
Cfg: cfg,
PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(),
PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(),
}
err := msc2946.Enable(base, rsAPI, userAPI, nil, nil)
if err != nil {
t.Fatalf("failed to enable MSC2946: %s", err)
}
for _, ev := range events {
hooks.Run(hooks.KindNewEventPersisted, ev)
}
return base.PublicClientAPIMux
}
type fledglingEvent struct {
Type string
StateKey *string
Content interface{}
Sender string
RoomID string
}
func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) {
t.Helper()
seed := make([]byte, ed25519.SeedSize) // zero seed
key := ed25519.NewKeyFromSeed(seed)
eb := gomatrixserverlib.EventBuilder{
Sender: ev.Sender,
Depth: 999,
Type: ev.Type,
StateKey: ev.StateKey,
RoomID: ev.RoomID,
}
err := eb.SetContent(ev.Content)
if err != nil {
t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content)
}
// make sure the origin_server_ts changes so we can test recency
time.Sleep(1 * time.Millisecond)
signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer)
if err != nil {
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
}
h := signedEvent.Headered(roomVer)
return h
}

View file

@ -1,182 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msc2946
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
var (
relTypes = map[string]int{
ConstSpaceChildEventType: 1,
ConstSpaceParentEventType: 2,
}
)
type Database interface {
// StoreReference persists a child or parent space mapping.
StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error
// References returns all events which have the given roomID as a parent or child space.
References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error)
}
type DB struct {
db *sql.DB
writer sqlutil.Writer
insertEdgeStmt *sql.Stmt
selectEdgesStmt *sql.Stmt
}
// NewDatabase loads the database for msc2836
func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
if dbOpts.ConnectionString.IsPostgres() {
return newPostgresDatabase(dbOpts)
}
return newSQLiteDatabase(dbOpts)
}
func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{
writer: sqlutil.NewDummyWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbOpts); err != nil {
return nil, err
}
_, err = d.db.Exec(`
CREATE TABLE IF NOT EXISTS msc2946_edges (
room_version TEXT NOT NULL,
-- the room ID of the event, the source of the arrow
source_room_id TEXT NOT NULL,
-- the target room ID, the arrow destination
dest_room_id TEXT NOT NULL,
-- the kind of relation, either child or parent (1,2)
rel_type SMALLINT NOT NULL,
event_json TEXT NOT NULL,
CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type)
);
`)
if err != nil {
return nil, err
}
if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT ON CONSTRAINT msc2946_edges_uniq DO UPDATE SET event_json = $5
`); err != nil {
return nil, err
}
if d.selectEdgesStmt, err = d.db.Prepare(`
SELECT room_version, event_json FROM msc2946_edges
WHERE source_room_id = $1 OR dest_room_id = $2
`); err != nil {
return nil, err
}
return &d, err
}
func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{
writer: sqlutil.NewExclusiveWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbOpts); err != nil {
return nil, err
}
_, err = d.db.Exec(`
CREATE TABLE IF NOT EXISTS msc2946_edges (
room_version TEXT NOT NULL,
-- the room ID of the event, the source of the arrow
source_room_id TEXT NOT NULL,
-- the target room ID, the arrow destination
dest_room_id TEXT NOT NULL,
-- the kind of relation, either child or parent (1,2)
rel_type SMALLINT NOT NULL,
event_json TEXT NOT NULL,
UNIQUE (source_room_id, dest_room_id, rel_type)
);
`)
if err != nil {
return nil, err
}
if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT (source_room_id, dest_room_id, rel_type) DO UPDATE SET event_json = $5
`); err != nil {
return nil, err
}
if d.selectEdgesStmt, err = d.db.Prepare(`
SELECT room_version, event_json FROM msc2946_edges
WHERE source_room_id = $1 OR dest_room_id = $2
`); err != nil {
return nil, err
}
return &d, err
}
func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error {
target := SpaceTarget(he)
if target == "" {
return nil // malformed event
}
relType := relTypes[he.Type()]
_, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON())
return err
}
func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) {
rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close References")
refs := make([]*gomatrixserverlib.HeaderedEvent, 0)
for rows.Next() {
var roomVer string
var jsonBytes []byte
if err := rows.Scan(&roomVer, &jsonBytes); err != nil {
return nil, err
}
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer))
if err != nil {
return nil, err
}
he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer))
refs = append(refs, he)
}
return refs, nil
}
// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent
// depending on the event type.
func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string {
if he.StateKey() == nil {
return "" // no-op
}
switch he.Type() {
case ConstSpaceParentEventType:
return *he.StateKey()
case ConstSpaceChildEventType:
return *he.StateKey()
}
return ""
}

View file

@ -82,7 +82,16 @@ func DeviceListCatchup(
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
return to, hasNew, nil
}
// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
// Work out which user IDs we care about — that includes those in the original request,
// the response from QueryKeyChanges (which includes ALL users who have changed keys)
// as well as every user who has a join or leave event in the current sync response. We
// will request information about which rooms these users are joined to, so that we can
// see if we still share any rooms with them.
joinUserIDs, leaveUserIDs := membershipEvents(res)
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
var sharedUsersMap map[string]int
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
util.GetLogger(ctx).Debugf(
@ -100,9 +109,8 @@ func DeviceListCatchup(
userSet[userID] = true
}
}
// if the response has any join/leave events, add them now.
// Finally, add in users who have joined or left.
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
joinUserIDs, leaveUserIDs := membershipEvents(res)
for _, userID := range joinUserIDs {
if !userSet[userID] {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
@ -213,7 +221,8 @@ func filterSharedUsers(
var result []string
var sharedUsersRes roomserverAPI.QuerySharedUsersResponse
err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{
UserID: userID,
UserID: userID,
OtherUserIDs: usersWithChangedKeys,
}, &sharedUsersRes)
if err != nil {
// default to all users so we do needless queries rather than miss some important device update

View file

@ -24,6 +24,7 @@ Local device key changes get to remote servers with correct prev_id
# Flakey
Local device key changes appear in /keys/changes
/context/ with lazy_load_members filter works
# we don't support groups
Remove group category

View file

@ -597,3 +597,11 @@ Device list doesn't change if remote server is down
/context/ on non world readable room does not work
/context/ returns correct number of events
/context/ with lazy_load_members filter works
Can query remote device keys using POST after notification
Device deletion propagates over federation
Get left notifs in sync and /keys/changes when other user leaves
Remote banned user is kicked and may not rejoin until unbanned
registration remembers parameters
registration accepts non-ascii passwords
registration with inhibit_login inhibits login